1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package violajones;
 26 
 27 
import jdk.incubator.code.Op;
import org.w3c.dom.Element;
import org.w3c.dom.NodeList;
import org.xml.sax.SAXException;
import violajones.ifaces.Cascade;

import javax.xml.XMLConstants;
import javax.xml.parsers.DocumentBuilderFactory;
import javax.xml.parsers.ParserConfigurationException;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.NoSuchElementException;
import java.util.Optional;
import java.util.Scanner;
import java.util.function.Consumer;
import java.util.function.Predicate;
 45 
 46 public class XMLHaarCascadeModel implements Cascade {
 47 
 48     private final Element cascadeElement;
 49 
 50     static Optional<Element> selectChild(Element element, Predicate<Element> predicate) {
 51         NodeList nodes = element.getChildNodes();
 52         for (int i = 0; i < nodes.getLength(); i++) {
 53             if (nodes.item(i) instanceof Element childElement && predicate.test(childElement)) {
 54                 return Optional.of(childElement);
 55             }
 56         }
 57         return Optional.empty();
 58     }
 59 
 60     static Optional<Element> selectChild(Element element, final String name) {
 61         return selectChild(element, (e) -> e.getNodeName().equals(name));
 62     }
 63 
 64     static void forEachElement(Element element, Predicate<Element> predicate, Consumer<Element> consumer) {
 65         NodeList nodes = element.getChildNodes();
 66         for (int i = 0; i < nodes.getLength(); i++) {
 67             if (nodes.item(i) instanceof Element xmle && predicate.test(xmle)) {
 68                 consumer.accept(xmle);
 69             }
 70         }
 71     }
 72 
 73     static float getFloat(Element element, final String name) {
 74         return selectChild(element, name).map(value -> Float.parseFloat(value.getTextContent())).orElse(0f);
 75     }
 76 
 77     static short getShort(Element element, final String name) {
 78         return selectChild(element, name).map(value -> Short.parseShort(value.getTextContent())).orElse((short) 0);
 79     }
 80 
 81     final public List<Feature> features = new ArrayList<>();
 82     final public List<Tree> trees = new ArrayList<>();
 83     final public List<Stage> stages = new ArrayList<>();
 84 
 85     @Override
 86     public Cascade.Feature feature(long idx) {
 87         return features.get((int) idx);
 88     }
 89 
 90     @Override
 91     public int featureCount() {
 92         return features.size();
 93     }
 94 
 95     @Override
 96     public Cascade.Stage stage(long idx) {
 97         return stages.get((int) idx);
 98     }
 99 
100     @Override
101     public int stageCount() {
102         return stages.size();
103     }
104 
105     @Override
106     public Cascade.Tree tree(long idx) {
107         return trees.get((int) idx);
108     }
109 
110     @Override
111     public int treeCount() {
112         return trees.size();
113     }
114 
115     @Override
116     public int width() {
117         return width;
118     }
119 
120     @Override
121     public void width(int width) {
122         throw new IllegalStateException("void width(int width) unimplemented ");
123     }
124 
125     @Override
126     public int height() {
127         return height;
128     }
129 
130     @Override
131     public void height(int height) {
132         throw new IllegalStateException("void height(int height) unimplemented ");
133     }
134 
135     static public class Feature implements Cascade.Feature {
136         private final Element featureElement;
137         private final Tree tree;
138 
139         @Override
140         public int id() {
141             return id;
142         }
143 
144         @Override
145         public float threshold() {
146             return threshold;
147         }
148 
149         @Override
150         public void id(int id) {
151             throw new IllegalStateException("void id(int id) unimplemented ");
152         }
153 
154         @Override
155         public void threshold(float threshold) {
156             throw new IllegalStateException("void threshold(float threshold) unimplemented ");
157         }
158 
159         @Override
160         public Cascade.Feature.LinkOrValue left() {
161             return left;
162          }
163 
164         @Override
165         public Cascade.Feature.LinkOrValue right() {
166             return right;
167         }
168 
169         @Override
170         public Cascade.Feature.Rect rect(long idx) {
171             if (rects.length > idx) {
172                 return rects[(int) idx];
173             } else {
174                 return null;
175             }
176         }
177 
178 
179         static public class LinkOrValue implements Cascade.Feature.LinkOrValue, Cascade.Feature.LinkOrValue.Anon {
180             private final boolean hasValue;
181 
182 
183             private short featureId;
184             private float value;
185 
186             public LinkOrValue(short featureId) {
187                 this.featureId = featureId;
188                 hasValue = false;
189             }
190 
191             public LinkOrValue(float value) {
192                 this.value = value;
193                 hasValue = true;
194             }
195 
196             @Override
197             public boolean hasValue() {
198                 return hasValue;
199             }
200 
201             @Override
202             public void hasValue(boolean hasValue) {
203                 throw new IllegalStateException("void LinkOrValue(boolean ) unimplemented ");
204             }
205 
206             @Override
207             public Anon anon() {
208                 return this;
209             }
210 
211             @Override
212             public int featureId() {
213                 return featureId;
214             }
215 
216             @Override
217             public float value() {
218                 return value;
219             }
220 
221             @Override
222             public void featureId(int featureId) {
223                 throw new IllegalStateException("void featureId(int featureId) unimplemented ");
224             }
225 
226             @Override
227             public void value(float value) {
228                 throw new IllegalStateException("void value(float value) unimplemented ");
229             }
230         }
231 
232         static public class Rect implements Cascade.Feature.Rect {
233             private final Feature feature;
234             private final Element rectElement;
235             private final byte x, y, width, height;
236             private final float weight;
237 
238             public Rect(Feature feature, Element rectElement) {
239                 this.feature = feature;
240                 this.rectElement = rectElement;
241 
242                 Scanner rectScanner = new Scanner(this.rectElement.getTextContent());
243                 rectScanner.useLocale(Locale.US);
244                 this.x = rectScanner.nextByte();
245                 this.y = rectScanner.nextByte();
246                 this.width = rectScanner.nextByte();
247                 this.height = rectScanner.nextByte();
248                 this.weight = rectScanner.nextFloat();
249             }
250 
251             @Override
252             public byte x() {
253                 return x;
254             }
255 
256             @Override
257             public byte y() {
258                 return y;
259             }
260 
261             @Override
262             public byte width() {
263                 return width;
264             }
265 
266             @Override
267             public byte height() {
268                 return height;
269             }
270 
271             @Override
272             public float weight() {
273                 return weight;
274             }
275 
276             @Override
277             public void x(byte x) {
278                 throw new IllegalStateException("void x(byte x) unimplemented ");
279             }
280 
281             @Override
282             public void y(byte y) {
283                 throw new IllegalStateException("void y(byte y) unimplemented ");
284             }
285 
286             @Override
287             public void width(byte width) {
288                 throw new IllegalStateException("void width(byte width) unimplemented ");
289             }
290 
291             @Override
292             public void height(byte height) {
293                 throw new IllegalStateException("void height(byte height) unimplemented ");
294             }
295 
296             @Override
297             public void weight(float height) {
298                 throw new IllegalStateException("void weight(float weight) unimplemented ");
299             }
300         }
301 
302         private final short id;
303         int rectCount;
304         public final Rect[] rects;
305         private final float threshold;
306 
307         final public LinkOrValue left;
308         final public LinkOrValue right;
309 
310 
311         public Feature(Tree tree, Element featureElement, int id) {
312             this.tree = tree;
313             this.featureElement = featureElement;
314             this.id = (short) id;
315             this.rectCount = 0;
316             this.rects = new Rect[3];
317             this.threshold = getFloat(this.featureElement, "threshold");
318 
319             left = (selectChild(this.featureElement, "left_val")).isPresent()
320                     ? new Feature.LinkOrValue(getFloat(this.featureElement, "left_val"))
321                     : new Feature.LinkOrValue(getShort(this.featureElement, "left_node"));
322             right = (selectChild(this.featureElement, "right_val")).isPresent()
323                     ? new Feature.LinkOrValue(getFloat(this.featureElement, "right_val"))
324                     : new Feature.LinkOrValue(getShort(this.featureElement, "right_node"));
325 
326             selectChild(this.featureElement, "feature").flatMap(featureXML -> selectChild(featureXML, "rects")).ifPresent(rectsXML -> {
327                 forEachElement(rectsXML, e -> e.getNodeName().equals("_"),
328                         (rectXMLElement) -> rects[this.rectCount++] = new Rect(this, rectXMLElement)
329                 );
330             });
331         }
332     }
333 
334     static public class Tree implements Cascade.Tree {
335         private final Stage stage;
336         final Element treeElement;
337         final int id;
338 
339         @Override
340         public void id(int id) {
341             throw new IllegalStateException("void id(int  id) unimplemented ");
342         }
343 
344         @Override
345         public void firstFeatureId(short firstFeatureId) {
346             throw new IllegalStateException("void firstFeatureId(short  firstFeatureId) unimplemented ");
347         }
348 
349         @Override
350         public void featureCount(short featureCount) {
351             throw new IllegalStateException("void featureCount(short  featureCount) unimplemented ");
352         }
353 
354         @Override
355         public int id() {
356             return id;
357         }
358 
359         @Override
360         public short firstFeatureId() {
361             return firstFeatureId;
362         }
363 
364         @Override
365         public short featureCount() {
366             return featureCount;
367         }
368 
369         public short featureCount;
370         public short firstFeatureId = -1;
371 
372         public Tree(Stage stage, Element treeElement, int id) {
373             this.stage = stage;
374             this.treeElement = treeElement;
375             this.id = id;
376             forEachElement(treeElement, e -> e.getNodeName().equals("_"),
377                     featureXMLElement -> {
378                         Feature feature = new Feature(this, featureXMLElement, stage.haarCascade.features.size());
379                         stage.haarCascade.features.add(feature);
380                         if (firstFeatureId == -1) {
381                             firstFeatureId = feature.id;
382                         }
383                         featureCount = (short) (feature.id - firstFeatureId + 1);
384                     });
385         }
386     }
387 
388     public static class Stage implements Cascade.Stage {
389 
390         private final XMLHaarCascadeModel haarCascade;
391         private final Element stageElement;
392 
393         @Override
394         public float threshold() {
395             return threshold;
396         }
397 
398         @Override
399         public short firstTreeId() {
400             return firstTreeId;
401         }
402 
403         @Override
404         public short treeCount() {
405             return treeCount;
406         }
407 
408 
409         @Override
410         public int id() {
411             return id;
412         }
413 
414         @Override
415         public void id(int id) {
416             throw new IllegalStateException("void id(int  id) unimplemented ");
417         }
418 
419         @Override
420         public void threshold(float threshold) {
421             throw new IllegalStateException("void threshold(float  threshold) unimplemented ");
422         }
423 
424         @Override
425         public void firstTreeId(short firstTreeId) {
426             throw new IllegalStateException("void firstTreeId(short  firstTreeId) unimplemented ");
427         }
428 
429         @Override
430         public void treeCount(short treeCount) {
431             throw new IllegalStateException("void treeCount(short  treeCount) unimplemented ");
432         }
433 
434         final public int id;
435 
436         final public float threshold;
437         public short firstTreeId = -1;
438         public short treeCount;
439 
440         public Stage(XMLHaarCascadeModel haarCascade, Element stageElement, int id) {
441             this.haarCascade = haarCascade;
442             this.stageElement = stageElement;
443             this.id = id;
444             this.threshold = getFloat(this.stageElement, "stage_threshold");
445             selectChild(this.stageElement, "trees").ifPresent(treeXML -> {
446                 forEachElement(treeXML, e -> e.getNodeName().equals("_"), treeXMLElement -> {
447                             Tree tree = new Tree(this, treeXMLElement, haarCascade.trees.size());
448                             haarCascade.trees.add(tree);
449                             if (firstTreeId == -1) {
450                                 firstTreeId = (short) tree.id;
451                             }
452                             treeCount = (short) (tree.id - firstTreeId + 1);
453                         }
454                 );
455             });
456 
457         }
458     }
459 
460     final public int width;
461     final public int height;
462 
463     public static XMLHaarCascadeModel load(InputStream is) throws IOException, SAXException, ParserConfigurationException {
464         if (is == null) {
465             throw new IllegalArgumentException("input == null!");
466         }
467         org.w3c.dom.Document doc = DocumentBuilderFactory.newInstance().newDocumentBuilder().parse(is);
468         doc.getDocumentElement().normalize();
469         Element root = doc.getDocumentElement();
470         Element cascadeElement = selectChild(root, (e) -> e.hasAttribute("type_id")).get();
471 
472         return new XMLHaarCascadeModel(cascadeElement);
473     }
474 
475     XMLHaarCascadeModel(Element cascadeElement) {
476         this.cascadeElement = cascadeElement;
477         Optional<Element> size = selectChild(cascadeElement, "size");
478         if (size.isPresent()) {
479             Scanner sizeScanner = new Scanner(size.get().getTextContent());
480             this.width = sizeScanner.nextInt();
481             this.height = sizeScanner.nextInt();
482         }else{
483             if (selectChild(cascadeElement,"width") instanceof Optional<Element> optWidth && optWidth.isPresent()
484                && selectChild(cascadeElement,"height") instanceof Optional<Element> optHeight && optHeight.isPresent()
485             ){
486                 this.width = Integer.parseInt(optWidth.get().getTextContent());
487                 this.height = Integer.parseInt(optHeight.get().getTextContent());
488                 //System.out.println("height width = "+this.width +  " "+this.height);
489             }else{
490                 throw new IllegalStateException("No width/height or size element in cascade ");
491             }
492         }
493         selectChild(cascadeElement, "stages").ifPresent(stagesXML ->
494                 forEachElement(stagesXML, e -> e.getNodeName().equals("_"),
495                         (stageXMLElement) ->
496                                 stages.add(new Stage(this, stageXMLElement, stages.size()))
497                 )
498         );
499 
500 
501     }
502 }