1 /*
2 * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25
26 package oracle.code.onnx.opgen;
27
28 import oracle.code.json.*;
29 import oracle.code.onnx.OpSchema;
30
31 import java.io.ByteArrayOutputStream;
32 import java.io.IOException;
33 import java.io.ObjectOutputStream;
34 import java.io.UncheckedIOException;
35 import java.lang.reflect.Constructor;
36 import java.lang.reflect.ParameterizedType;
37 import java.lang.reflect.RecordComponent;
38 import java.lang.reflect.Type;
39 import java.nio.file.Files;
40 import java.nio.file.Path;
41 import java.nio.file.StandardOpenOption;
42 import java.util.ArrayList;
43 import java.util.List;
44 import java.util.stream.Stream;
45
46 public class OpSchemaParser {
47
48 public static void main(String[] args) throws Exception {
49 byte[] serSchemas = serialize(Path.of(
50 "opgen/onnx-schema.json"));
51 Files.write(Path.of("opgen/op-schemas.ser"), serSchemas, StandardOpenOption.CREATE_NEW);
52
53 List<OpSchema> parse = parse(Path.of("opgen/onnx-schema.json"));
54 for (OpSchema opSchema : parse) {
55 for (OpSchema.Attribute attribute : opSchema.attributes()) {
56 if (attribute.default_value() != null) {
57 System.out.println(attribute.name() + " : " + attribute.type() + " = " + attribute.default_value());
58 }
59 }
60
61 }
62
63 }
64
65 static byte[] serialize(Path p) throws IOException {
66 List<OpSchema> parse = parse(p);
67
68 ByteArrayOutputStream baos = new ByteArrayOutputStream();
69 ObjectOutputStream o = new ObjectOutputStream(baos);
70 o.writeObject(parse);
71 o.flush();
72 return baos.toByteArray();
73 }
74
75 static List<OpSchema> parse(Path p) {
76 String schemaString;
77 try {
78 schemaString = Files.readString(p);
79 } catch (IOException e) {
80 throw new UncheckedIOException(e);
81 }
82 JsonValue schemaDoc = Json.parse(schemaString);
83 return mapJsonArray((JsonArray) schemaDoc, OpSchema.class);
84 }
85
86 @SuppressWarnings({"unchecked", "rawtypes"})
87 static <T> T mapJsonValue(JsonValue v, Class<T> c, Type gt) {
88 return switch (v) {
89 case JsonBoolean b when c == boolean.class -> (T) (Boolean) b.value();
90
91 case JsonNull _ when c == Object.class -> null;
92
93 case JsonString s when c.isEnum() -> (T) Enum.valueOf((Class<Enum>) c, s.value());
94 case JsonString s when c == String.class -> (T) s.value();
95 case JsonString s when c == Object.class -> (T) s.value();
96
97 // Coerce to int when int is declared
98 case JsonNumber n when c == int.class -> (T) (Integer) n.toNumber().intValue();
99 // Coerce to int when Object is declared and when integral JSON number
100 case JsonNumber n when n.toNumber() instanceof Long i && c == Object.class -> (T) (Integer) i.intValue();
101 // Coerce to float when Object is declared and when real JSON number
102 case JsonNumber n when n.toNumber() instanceof Double d && c == Object.class -> (T) (Float) d.floatValue();
103
104 case JsonArray a when c == List.class -> switch (gt) {
105 case ParameterizedType pt when pt.getActualTypeArguments()[0] instanceof Class<?> tc ->
106 (T) mapJsonArray(a, tc);
107 default -> throw new IllegalStateException();
108 };
109
110 case JsonObject o when Record.class.isAssignableFrom(c) -> (T) mapJsonObject(o, (Class<Record>) c);
111 case JsonObject o when c == List.class -> switch (gt) {
112 case ParameterizedType pt when pt.getActualTypeArguments()[0] instanceof Class<?> tc ->
113 (T) mapJsonObjectAsIfJsonArray(o, tc);
114 default -> throw new IllegalStateException();
115 };
116
117 default -> throw new IllegalStateException(v + " " + c);
118 };
119 }
120
121 static <T> List<T> mapJsonObjectAsIfJsonArray(JsonObject o, Class<T> ct) {
122 return o.members().values().stream().map(v -> mapJsonValue(v, ct, ct)).toList();
123 }
124
125 static <T> List<T> mapJsonArray(JsonArray a, Class<T> ct) {
126 return a.values().stream().map(v -> mapJsonValue(v, ct, ct)).toList();
127 }
128
129 static <T extends Record> T mapJsonObject(JsonObject o, Class<T> r) {
130 List<Object> rcInstances = new ArrayList<>();
131 for (RecordComponent rc : r.getRecordComponents()) {
132 JsonValue jsonValue = o.members().get(rc.getName());
133 if (jsonValue == null) {
134 throw new IllegalStateException();
135 }
136 Object instance = mapJsonValue(jsonValue, rc.getType(), rc.getGenericType());
137 rcInstances.add(instance);
138 }
139
140 Class<?>[] parameters = Stream.of(r.getRecordComponents())
141 .map(RecordComponent::getType).toArray(Class[]::new);
142 try {
143 Constructor<T> declaredConstructor = r.getDeclaredConstructor(parameters);
144 return declaredConstructor.newInstance(rcInstances.toArray());
145 } catch (ReflectiveOperationException e) {
146 throw new RuntimeException(e);
147 }
148 }
149 }