1 /*
2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25 package violajones;
26
27
28 import jdk.incubator.code.Op;
29 import org.w3c.dom.Element;
30 import org.w3c.dom.NodeList;
31 import org.xml.sax.SAXException;
32 import violajones.ifaces.Cascade;
33
34 import javax.xml.parsers.DocumentBuilderFactory;
35 import javax.xml.parsers.ParserConfigurationException;
36 import java.io.IOException;
37 import java.io.InputStream;
38 import java.util.ArrayList;
39 import java.util.List;
40 import java.util.Locale;
41 import java.util.Optional;
42 import java.util.Scanner;
43 import java.util.function.Consumer;
44 import java.util.function.Predicate;
45
46 public class XMLHaarCascadeModel implements Cascade {
47
48 private final Element cascadeElement;
49
50 static Optional<Element> selectChild(Element element, Predicate<Element> predicate) {
51 NodeList nodes = element.getChildNodes();
52 for (int i = 0; i < nodes.getLength(); i++) {
53 if (nodes.item(i) instanceof Element childElement && predicate.test(childElement)) {
54 return Optional.of(childElement);
55 }
56 }
57 return Optional.empty();
58 }
59
60 static Optional<Element> selectChild(Element element, final String name) {
61 return selectChild(element, (e) -> e.getNodeName().equals(name));
62 }
63
64 static void forEachElement(Element element, Predicate<Element> predicate, Consumer<Element> consumer) {
65 NodeList nodes = element.getChildNodes();
66 for (int i = 0; i < nodes.getLength(); i++) {
67 if (nodes.item(i) instanceof Element xmle && predicate.test(xmle)) {
68 consumer.accept(xmle);
69 }
70 }
71 }
72
73 static float getFloat(Element element, final String name) {
74 return selectChild(element, name).map(value -> Float.parseFloat(value.getTextContent())).orElse(0f);
75 }
76
77 static short getShort(Element element, final String name) {
78 return selectChild(element, name).map(value -> Short.parseShort(value.getTextContent())).orElse((short) 0);
79 }
80
81 final public List<Feature> features = new ArrayList<>();
82 final public List<Tree> trees = new ArrayList<>();
83 final public List<Stage> stages = new ArrayList<>();
84
85 @Override
86 public Cascade.Feature feature(long idx) {
87 return features.get((int) idx);
88 }
89
90 @Override
91 public int featureCount() {
92 return features.size();
93 }
94
95 @Override
96 public Cascade.Stage stage(long idx) {
97 return stages.get((int) idx);
98 }
99
100 @Override
101 public int stageCount() {
102 return stages.size();
103 }
104
105 @Override
106 public Cascade.Tree tree(long idx) {
107 return trees.get((int) idx);
108 }
109
110 @Override
111 public int treeCount() {
112 return trees.size();
113 }
114
115 @Override
116 public int width() {
117 return width;
118 }
119
120 @Override
121 public void width(int width) {
122 throw new IllegalStateException("void width(int width) unimplemented ");
123 }
124
125 @Override
126 public int height() {
127 return height;
128 }
129
130 @Override
131 public void height(int height) {
132 throw new IllegalStateException("void height(int height) unimplemented ");
133 }
134
135 static public class Feature implements Cascade.Feature {
136 private final Element featureElement;
137 private final Tree tree;
138
139 @Override
140 public int id() {
141 return id;
142 }
143
144 @Override
145 public float threshold() {
146 return threshold;
147 }
148
149 @Override
150 public void id(int id) {
151 throw new IllegalStateException("void id(int id) unimplemented ");
152 }
153
154 @Override
155 public void threshold(float threshold) {
156 throw new IllegalStateException("void threshold(float threshold) unimplemented ");
157 }
158
159 @Override
160 public Cascade.Feature.LinkOrValue left() {
161 return left;
162 }
163
164 @Override
165 public Cascade.Feature.LinkOrValue right() {
166 return right;
167 }
168
169 @Override
170 public Cascade.Feature.Rect rect(long idx) {
171 if (rects.length > idx) {
172 return rects[(int) idx];
173 } else {
174 return null;
175 }
176 }
177
178
179 static public class LinkOrValue implements Cascade.Feature.LinkOrValue, Cascade.Feature.LinkOrValue.Anon {
180 private final boolean hasValue;
181
182
183 private short featureId;
184 private float value;
185
186 public LinkOrValue(short featureId) {
187 this.featureId = featureId;
188 hasValue = false;
189 }
190
191 public LinkOrValue(float value) {
192 this.value = value;
193 hasValue = true;
194 }
195
196 @Override
197 public boolean hasValue() {
198 return hasValue;
199 }
200
201 @Override
202 public void hasValue(boolean hasValue) {
203 throw new IllegalStateException("void LinkOrValue(boolean ) unimplemented ");
204 }
205
206 @Override
207 public Anon anon() {
208 return this;
209 }
210
211 @Override
212 public int featureId() {
213 return featureId;
214 }
215
216 @Override
217 public float value() {
218 return value;
219 }
220
221 @Override
222 public void featureId(int featureId) {
223 throw new IllegalStateException("void featureId(int featureId) unimplemented ");
224 }
225
226 @Override
227 public void value(float value) {
228 throw new IllegalStateException("void value(float value) unimplemented ");
229 }
230 }
231
232 static public class Rect implements Cascade.Feature.Rect {
233 private final Feature feature;
234 private final Element rectElement;
235 private final byte x, y, width, height;
236 private final float weight;
237
238 public Rect(Feature feature, Element rectElement) {
239 this.feature = feature;
240 this.rectElement = rectElement;
241
242 Scanner rectScanner = new Scanner(this.rectElement.getTextContent());
243 rectScanner.useLocale(Locale.US);
244 this.x = rectScanner.nextByte();
245 this.y = rectScanner.nextByte();
246 this.width = rectScanner.nextByte();
247 this.height = rectScanner.nextByte();
248 this.weight = rectScanner.nextFloat();
249 }
250
251 @Override
252 public byte x() {
253 return x;
254 }
255
256 @Override
257 public byte y() {
258 return y;
259 }
260
261 @Override
262 public byte width() {
263 return width;
264 }
265
266 @Override
267 public byte height() {
268 return height;
269 }
270
271 @Override
272 public float weight() {
273 return weight;
274 }
275
276 @Override
277 public void x(byte x) {
278 throw new IllegalStateException("void x(byte x) unimplemented ");
279 }
280
281 @Override
282 public void y(byte y) {
283 throw new IllegalStateException("void y(byte y) unimplemented ");
284 }
285
286 @Override
287 public void width(byte width) {
288 throw new IllegalStateException("void width(byte width) unimplemented ");
289 }
290
291 @Override
292 public void height(byte height) {
293 throw new IllegalStateException("void height(byte height) unimplemented ");
294 }
295
296 @Override
297 public void weight(float height) {
298 throw new IllegalStateException("void weight(float weight) unimplemented ");
299 }
300 }
301
302 private final short id;
303 int rectCount;
304 public final Rect[] rects;
305 private final float threshold;
306
307 final public LinkOrValue left;
308 final public LinkOrValue right;
309
310
311 public Feature(Tree tree, Element featureElement, int id) {
312 this.tree = tree;
313 this.featureElement = featureElement;
314 this.id = (short) id;
315 this.rectCount = 0;
316 this.rects = new Rect[3];
317 this.threshold = getFloat(this.featureElement, "threshold");
318
319 left = (selectChild(this.featureElement, "left_val")).isPresent()
320 ? new Feature.LinkOrValue(getFloat(this.featureElement, "left_val"))
321 : new Feature.LinkOrValue(getShort(this.featureElement, "left_node"));
322 right = (selectChild(this.featureElement, "right_val")).isPresent()
323 ? new Feature.LinkOrValue(getFloat(this.featureElement, "right_val"))
324 : new Feature.LinkOrValue(getShort(this.featureElement, "right_node"));
325
326 selectChild(this.featureElement, "feature").flatMap(featureXML -> selectChild(featureXML, "rects")).ifPresent(rectsXML -> {
327 forEachElement(rectsXML, e -> e.getNodeName().equals("_"),
328 (rectXMLElement) -> rects[this.rectCount++] = new Rect(this, rectXMLElement)
329 );
330 });
331 }
332 }
333
334 static public class Tree implements Cascade.Tree {
335 private final Stage stage;
336 final Element treeElement;
337 final int id;
338
339 @Override
340 public void id(int id) {
341 throw new IllegalStateException("void id(int id) unimplemented ");
342 }
343
344 @Override
345 public void firstFeatureId(short firstFeatureId) {
346 throw new IllegalStateException("void firstFeatureId(short firstFeatureId) unimplemented ");
347 }
348
349 @Override
350 public void featureCount(short featureCount) {
351 throw new IllegalStateException("void featureCount(short featureCount) unimplemented ");
352 }
353
354 @Override
355 public int id() {
356 return id;
357 }
358
359 @Override
360 public short firstFeatureId() {
361 return firstFeatureId;
362 }
363
364 @Override
365 public short featureCount() {
366 return featureCount;
367 }
368
369 public short featureCount;
370 public short firstFeatureId = -1;
371
372 public Tree(Stage stage, Element treeElement, int id) {
373 this.stage = stage;
374 this.treeElement = treeElement;
375 this.id = id;
376 forEachElement(treeElement, e -> e.getNodeName().equals("_"),
377 featureXMLElement -> {
378 Feature feature = new Feature(this, featureXMLElement, stage.haarCascade.features.size());
379 stage.haarCascade.features.add(feature);
380 if (firstFeatureId == -1) {
381 firstFeatureId = feature.id;
382 }
383 featureCount = (short) (feature.id - firstFeatureId + 1);
384 });
385 }
386 }
387
388 public static class Stage implements Cascade.Stage {
389
390 private final XMLHaarCascadeModel haarCascade;
391 private final Element stageElement;
392
393 @Override
394 public float threshold() {
395 return threshold;
396 }
397
398 @Override
399 public short firstTreeId() {
400 return firstTreeId;
401 }
402
403 @Override
404 public short treeCount() {
405 return treeCount;
406 }
407
408
409 @Override
410 public int id() {
411 return id;
412 }
413
414 @Override
415 public void id(int id) {
416 throw new IllegalStateException("void id(int id) unimplemented ");
417 }
418
419 @Override
420 public void threshold(float threshold) {
421 throw new IllegalStateException("void threshold(float threshold) unimplemented ");
422 }
423
424 @Override
425 public void firstTreeId(short firstTreeId) {
426 throw new IllegalStateException("void firstTreeId(short firstTreeId) unimplemented ");
427 }
428
429 @Override
430 public void treeCount(short treeCount) {
431 throw new IllegalStateException("void treeCount(short treeCount) unimplemented ");
432 }
433
434 final public int id;
435
436 final public float threshold;
437 public short firstTreeId = -1;
438 public short treeCount;
439
440 public Stage(XMLHaarCascadeModel haarCascade, Element stageElement, int id) {
441 this.haarCascade = haarCascade;
442 this.stageElement = stageElement;
443 this.id = id;
444 this.threshold = getFloat(this.stageElement, "stage_threshold");
445 selectChild(this.stageElement, "trees").ifPresent(treeXML -> {
446 forEachElement(treeXML, e -> e.getNodeName().equals("_"), treeXMLElement -> {
447 Tree tree = new Tree(this, treeXMLElement, haarCascade.trees.size());
448 haarCascade.trees.add(tree);
449 if (firstTreeId == -1) {
450 firstTreeId = (short) tree.id;
451 }
452 treeCount = (short) (tree.id - firstTreeId + 1);
453 }
454 );
455 });
456
457 }
458 }
459
460 final public int width;
461 final public int height;
462
463 public static XMLHaarCascadeModel load(InputStream is) throws IOException, SAXException, ParserConfigurationException {
464 if (is == null) {
465 throw new IllegalArgumentException("input == null!");
466 }
467 org.w3c.dom.Document doc = DocumentBuilderFactory.newInstance().newDocumentBuilder().parse(is);
468 doc.getDocumentElement().normalize();
469 Element root = doc.getDocumentElement();
470 Element cascadeElement = selectChild(root, (e) -> e.hasAttribute("type_id")).get();
471
472 return new XMLHaarCascadeModel(cascadeElement);
473 }
474
475 XMLHaarCascadeModel(Element cascadeElement) {
476 this.cascadeElement = cascadeElement;
477 Optional<Element> size = selectChild(cascadeElement, "size");
478 if (size.isPresent()) {
479 Scanner sizeScanner = new Scanner(size.get().getTextContent());
480 this.width = sizeScanner.nextInt();
481 this.height = sizeScanner.nextInt();
482 }else{
483 if (selectChild(cascadeElement,"width") instanceof Optional<Element> optWidth && optWidth.isPresent()
484 && selectChild(cascadeElement,"height") instanceof Optional<Element> optHeight && optHeight.isPresent()
485 ){
486 this.width = Integer.parseInt(optWidth.get().getTextContent());
487 this.height = Integer.parseInt(optHeight.get().getTextContent());
488 //System.out.println("height width = "+this.width + " "+this.height);
489 }else{
490 throw new IllegalStateException("No width/height or size element in cascade ");
491 }
492 }
493 selectChild(cascadeElement, "stages").ifPresent(stagesXML ->
494 forEachElement(stagesXML, e -> e.getNodeName().equals("_"),
495 (stageXMLElement) ->
496 stages.add(new Stage(this, stageXMLElement, stages.size()))
497 )
498 );
499
500
501 }
502 }