1 /*
  2  * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 
 26 package oracle.code.onnx.opgen;
 27 
 28 import oracle.code.json.*;
 29 import oracle.code.onnx.OpSchema;
 30 
 31 import java.io.ByteArrayOutputStream;
 32 import java.io.IOException;
 33 import java.io.ObjectOutputStream;
 34 import java.io.UncheckedIOException;
 35 import java.lang.reflect.Constructor;
 36 import java.lang.reflect.ParameterizedType;
 37 import java.lang.reflect.RecordComponent;
 38 import java.lang.reflect.Type;
 39 import java.nio.file.Files;
 40 import java.nio.file.Path;
 41 import java.nio.file.StandardOpenOption;
 42 import java.util.ArrayList;
 43 import java.util.List;
 44 import java.util.stream.Stream;
 45 
 46 public class OpSchemaParser {
 47 
 48     public static void main(String[] args) throws Exception {
 49         byte[] serSchemas = serialize(Path.of(
 50                 "opgen/onnx-schema.json"));
 51         Files.write(Path.of("opgen/op-schemas.ser"), serSchemas, StandardOpenOption.CREATE_NEW);
 52 
 53         List<OpSchema> parse = parse(Path.of("opgen/onnx-schema.json"));
 54         for (OpSchema opSchema : parse) {
 55             for (OpSchema.Attribute attribute : opSchema.attributes()) {
 56                 if (attribute.default_value() != null) {
 57                     System.out.println(attribute.name() + " : " + attribute.type() + " = " + attribute.default_value());
 58                 }
 59             }
 60 
 61         }
 62 
 63     }
 64 
 65     static byte[] serialize(Path p) throws IOException {
 66         List<OpSchema> parse = parse(p);
 67 
 68         ByteArrayOutputStream baos = new ByteArrayOutputStream();
 69         ObjectOutputStream o = new ObjectOutputStream(baos);
 70         o.writeObject(parse);
 71         o.flush();
 72         return baos.toByteArray();
 73     }
 74 
 75     static List<OpSchema> parse(Path p) {
 76         String schemaString;
 77         try {
 78             schemaString = Files.readString(p);
 79         } catch (IOException e) {
 80             throw new UncheckedIOException(e);
 81         }
 82         JsonValue schemaDoc = Json.parse(schemaString);
 83         return mapJsonArray((JsonArray) schemaDoc, OpSchema.class);
 84     }
 85 
 86     @SuppressWarnings({"unchecked", "rawtypes"})
 87     static <T> T mapJsonValue(JsonValue v, Class<T> c, Type gt) {
 88         return switch (v) {
 89             case JsonBoolean b when c == boolean.class -> (T) (Boolean) b.value();
 90 
 91             case JsonNull _ when c == Object.class -> null;
 92 
 93             case JsonString s when c.isEnum() -> (T) Enum.valueOf((Class<Enum>) c, s.value());
 94             case JsonString s when c == String.class -> (T) s.value();
 95             case JsonString s when c == Object.class -> (T) s.value();
 96 
 97             // Coerce to int when int is declared
 98             case JsonNumber n when c == int.class -> (T) (Integer) n.toNumber().intValue();
 99             // Coerce to int when Object is declared and when integral JSON number
100             case JsonNumber n when n.toNumber() instanceof Long i && c == Object.class -> (T) (Integer) i.intValue();
101             // Coerce to float when Object is declared and when real JSON number
102             case JsonNumber n when n.toNumber() instanceof Double d && c == Object.class -> (T) (Float) d.floatValue();
103 
104             case JsonArray a when c == List.class -> switch (gt) {
105                 case ParameterizedType pt when pt.getActualTypeArguments()[0] instanceof Class<?> tc ->
106                         (T) mapJsonArray(a, tc);
107                 default -> throw new IllegalStateException();
108             };
109 
110             case JsonObject o when Record.class.isAssignableFrom(c) -> (T) mapJsonObject(o, (Class<Record>) c);
111             case JsonObject o when c == List.class -> switch (gt) {
112                 case ParameterizedType pt when pt.getActualTypeArguments()[0] instanceof Class<?> tc ->
113                         (T) mapJsonObjectAsIfJsonArray(o, tc);
114                 default -> throw new IllegalStateException();
115             };
116 
117             default -> throw new IllegalStateException(v + " " + c);
118         };
119     }
120 
121     static <T> List<T> mapJsonObjectAsIfJsonArray(JsonObject o, Class<T> ct) {
122         return o.members().values().stream().map(v -> mapJsonValue(v, ct, ct)).toList();
123     }
124 
125     static <T> List<T> mapJsonArray(JsonArray a, Class<T> ct) {
126         return a.values().stream().map(v -> mapJsonValue(v, ct, ct)).toList();
127     }
128 
129     static <T extends Record> T mapJsonObject(JsonObject o, Class<T> r) {
130         List<Object> rcInstances = new ArrayList<>();
131         for (RecordComponent rc : r.getRecordComponents()) {
132             JsonValue jsonValue = o.members().get(rc.getName());
133             if (jsonValue == null) {
134                 throw new IllegalStateException();
135             }
136             Object instance = mapJsonValue(jsonValue, rc.getType(), rc.getGenericType());
137             rcInstances.add(instance);
138         }
139 
140         Class<?>[] parameters = Stream.of(r.getRecordComponents())
141                 .map(RecordComponent::getType).toArray(Class[]::new);
142         try {
143             Constructor<T> declaredConstructor = r.getDeclaredConstructor(parameters);
144             return declaredConstructor.newInstance(rcInstances.toArray());
145         } catch (ReflectiveOperationException e) {
146             throw new RuntimeException(e);
147         }
148     }
149 }