1 /*
2 * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25
26 package oracle.code.onnx.opgen;
27
28 import jdk.incubator.code.extern.ExternalizedCodeType;
29 import oracle.code.onnx.OpSchema;
30 import oracle.code.onnx.Tensor;
31
import java.io.*;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.util.*;
import java.util.stream.Collectors;
36
37 import static java.util.Comparator.comparing;
38 import static java.util.stream.Collectors.*;
39
40 public class OperatorGen {
41
42 final SortedMap<String, SortedSet<OpSchema>> schemas;
43
44 OperatorGen(List<OpSchema> schemas) {
45 this.schemas = schemas.stream().collect(groupingBy(
46 OpSchema::name,
47 TreeMap::new,
48 toCollection(() -> new TreeSet<>(comparing(OpSchema::since_version).reversed())
49 )));
50 }
51
52 static final String ONNX_PACKAGE = "oracle.code.onnx";
53 static final String ONNX_OPERATORS_CLASS = "OnnxOperators";
54
55 void genOpClass(Path dir) throws IOException {
56 OutputStreamWriter osw = new OutputStreamWriter(
57 new FileOutputStream(dir.resolve(ONNX_OPERATORS_CLASS + ".java").toFile()));
58 genOpClass(osw);
59 }
60
61 void genOpClass(Writer w_) throws IOException {
62 IndentWriter w = new IndentWriter(w_);
63
64 w.write("""
65 /*
66 * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
67 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
68 *
69 * This code is free software; you can redistribute it and/or modify it
70 * under the terms of the GNU General Public License version 2 only, as
71 * published by the Free Software Foundation. Oracle designates this
72 * particular file as subject to the "Classpath" exception as provided
73 * by Oracle in the LICENSE file that accompanied this code.
74 *
75 * This code is distributed in the hope that it will be useful, but WITHOUT
76 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
77 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
78 * version 2 for more details (a copy is included in the LICENSE file that
79 * accompanied this code).
80 *
81 * You should have received a copy of the GNU General Public License version
82 * 2 along with this work; if not, write to the Free Software Foundation,
83 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
84 *
85 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
86 * or visit www.oracle.com if you need additional information or have any
87 * questions.
88 */
89 """);
90 w.write("// Auto-generated from ONNX op schema\n");
91 w.write("\n");
92 w.write("package " + ONNX_PACKAGE + ";\n");
93 w.write("\n");
94 w.write("""
95 import oracle.code.onnx.ir.OnnxOps;
96
97 import java.util.Optional;
98 import java.util.List;
99 import java.util.Map;
100 """);
101 w.write("\n");
102
103 w.write("@SuppressWarnings({\"unchecked\", \"OptionalUsedAsFieldOrParameterType\"})\n");
104 w.write("public final class " + ONNX_OPERATORS_CLASS + " extends ExplicitOnnxOperators {\n");
105
106 w.in();
107
108 w.write("\n");
109 w.write("private " + ONNX_OPERATORS_CLASS + "() {}\n");
110 w.write("\n");
111
112 for (OpSchema s : schemas.values().stream().map(SortedSet::getFirst).toList()) {
113 if (skip(s)) {
114 System.out.println("Skipping " + s.name());
115 continue;
116 }
117
118 genMethod(w, s);
119 w.write("\n");
120 }
121 w.out();
122
123 w.write("}\n");
124 w.flush();
125 }
126
127 private boolean skip(OpSchema s) {
128 return s.attributes().stream().anyMatch(a ->
129 a.type() == OpSchema.AttributeType.GRAPH ||
130 a.type() == OpSchema.AttributeType.GRAPHS);
131 }
132
133 private void genMethod(IndentWriter w, OpSchema s) throws IOException {
134 Map<String, ExternalizedCodeType> javaTypeConstraints = javaTypes(typeConstraintMap(s));
135 Set<String> javaTypeVariables = javaTypeVariables(javaTypeConstraints);
136
137 boolean twoOrMoreResults = s.max_output() > 1 && s.outputs().size() > 1;
138 List<String> recordResultTypeVariables = new ArrayList<>();
139 if (twoOrMoreResults) {
140 for (OpSchema.TypeConstraintParam typeConstraint : s.type_constraints()) {
141 if (s.outputs().stream().anyMatch(op -> typeConstraint.type_param_str().equals(op.type_str())) &&
142 javaTypeVariables.contains(typeConstraint.type_param_str())) {
143 recordResultTypeVariables.add(typeConstraint.type_param_str());
144 }
145 }
146
147 w.write("public record " + s.name() + "Result");
148 if (!recordResultTypeVariables.isEmpty()) {
149 w.write(recordResultTypeVariables.stream().collect(Collectors.joining(", ", "<", ">")));
150 }
151 w.write("(");
152 boolean first = true;
153 for (OpSchema.FormalParameter outParam : s.outputs()) {
154 if (!first) {
155 w.write(", ");
156 }
157
158 ExternalizedCodeType outputType = javaTypeConstraints
159 .computeIfAbsent(outParam.type_str(),
160 ts -> javaType("?", parseTypeString(ts)));
161
162 w.write(outputType.toString());
163 w.write(" " + outParam.name());
164
165 first = false;
166 }
167
168 w.write(") { }\n");
169 }
170
171 for (String dl : s.doc().split("\\n")) {
172 w.write("/// " + dl + "\n");
173 }
174 w.write("public static ");
175
176 if (!s.type_constraints().isEmpty()) {
177 List<String> typeVariables = new ArrayList<>();
178 for (OpSchema.TypeConstraintParam typeConstraint : s.type_constraints()) {
179 if (javaTypeVariables.contains(typeConstraint.type_param_str())) {
180 typeVariables.add(typeConstraint.type_param_str());
181 }
182 }
183
184 if (!typeVariables.isEmpty()) {
185 w.write(typeVariables.stream().collect(Collectors.joining(", ", "<", ">")));
186 w.write(" ");
187 }
188 }
189
190 // @@@ Multiple output parameters - need to return tuple/record
191 final ExternalizedCodeType outputType;
192 if (s.min_output() == 1 && s.max_output() == 1) {
193 OpSchema.FormalParameter outParam = s.outputs().getFirst();
194
195 outputType = javaTypeConstraints.computeIfAbsent(outParam.type_str(),
196 ts -> javaType("?", parseTypeString(ts)));
197 w.write(outputType.toString());
198 } else if (s.min_output() == 0 && s.max_output() == 1) {
199 // This does not occur
200 throw new UnsupportedOperationException();
201 } else if (s.outputs().size() == 1) {
202 OpSchema.FormalParameter outParam = s.outputs().getFirst();
203 assert outParam.option() == OpSchema.FormalParameterOption.Variadic;
204
205 outputType = new ExternalizedCodeType("List",
206 List.of(javaTypeConstraints.computeIfAbsent(outParam.type_str(),
207 ts -> javaType("?", parseTypeString(ts)))));
208 w.write(outputType.toString());
209 } else {
210 assert twoOrMoreResults;
211
212 outputType = new ExternalizedCodeType(s.name() + "Result", List.of());
213 w.write(outputType.toString());
214 if (!recordResultTypeVariables.isEmpty()) {
215 w.write(recordResultTypeVariables.stream().collect(Collectors.joining(", ", "<", ">")));
216 }
217 }
218 w.write(" ");
219 w.write(s.name() + "(");
220
221 boolean first = true;
222 for (OpSchema.FormalParameter inParam : s.inputs()) {
223 if (!first) {
224 w.write(", ");
225 }
226
227 final ExternalizedCodeType inputType = javaTypeConstraints
228 .computeIfAbsent(inParam.type_str(),
229 ts -> javaType("?", parseTypeString(ts)));
230 switch (inParam.option()) {
231 case Single -> {
232 w.write(inputType.toString());
233 }
234 case Optional -> {
235 w.write("Optional<");
236 w.write(inputType.toString());
237 w.write(">");
238 }
239 case Variadic -> {
240 w.write("List<");
241 w.write(inputType.toString());
242 w.write(">");
243 }
244 }
245 w.write(" ");
246 w.write(inParam.name());
247
248 first = false;
249 }
250
251 for (OpSchema.Attribute attribute : s.attributes()) {
252 if (!first) {
253 w.write(", ");
254 }
255
256 OpSchema.AttributeType aType = attribute.type();
257 String typeString = switch (aType) {
258 default -> {
259 if (attribute.required()) {
260 yield aType.type().getSimpleName();
261 } else {
262 yield toBoxType(aType.type()).getSimpleName();
263 }
264 }
265 };
266 if (attribute.required()) {
267 w.write(typeString);
268 } else {
269 w.write("Optional<");
270 w.write(typeString);
271 w.write(">");
272 }
273 w.write(" ");
274 w.write(attribute.name());
275
276 first = false;
277 }
278
279 w.write(") {\n");
280 w.in();
281
282 w.write("Object result = OnnxInterpreter.interpret(");
283 w.write("OnnxOps." + s.name() + ".class");
284 w.write(", ");
285
286 w.write("List.of(");
287 first = true;
288 for (OpSchema.FormalParameter inParam : s.inputs()) {
289 if (!first) {
290 w.write(", ");
291 }
292
293 w.write(inParam.name());
294 first = false;
295 }
296 w.write(")");
297 w.write(", ");
298
299 w.write("List.of(");
300 first = true;
301 for (OpSchema.Attribute attribute : s.attributes()) {
302 if (!first) {
303 w.write(", ");
304 }
305
306 w.write(attribute.name());
307 first = false;
308 }
309 w.write(")");
310 w.write(");\n");
311
312 if (twoOrMoreResults) {
313 w.write("Object[] resultArray = (Object[]) result;\n");
314 w.write("return new " + s.name() + "Result");
315 if (!recordResultTypeVariables.isEmpty()) {
316 w.write("<>");
317 }
318 w.write("(");
319 first = true;
320 for (int i = 0; i < s.outputs().size(); i++) {
321 if (!first) {
322 w.write(", ");
323 }
324
325 w.write("(");
326 //
327 final ExternalizedCodeType t = javaTypeConstraints
328 .computeIfAbsent(s.outputs().get(i).type_str(),
329 ts -> javaType("?", parseTypeString(ts)));
330 w.write(t.toString());
331 w.write(")");
332 w.write("resultArray[" + i + "]");
333 first = false;
334 }
335 w.write(");\n");
336 } else {
337 w.write("return (" + outputType + ") result;\n");
338 }
339 w.out();
340 w.write("}\n");
341 }
342
343 static Map<String, ExternalizedCodeType> javaTypes(Map<String, ExternalizedCodeType> tcm) {
344 return tcm.entrySet().stream().collect(Collectors.toMap(
345 e -> e.getKey(),
346 e -> javaType(e.getKey(), e.getValue())));
347 }
348
349 static Set<String> javaTypeVariables(Map<String, ExternalizedCodeType> tcm) {
350 return tcm.entrySet().stream()
351 .filter(e -> usesTypeVariable(e.getKey(), e.getValue()))
352 .map(e -> e.getKey())
353 .collect(Collectors.toSet());
354 }
355
356 static boolean usesTypeVariable(String typeVariable, ExternalizedCodeType ete) {
357 if (ete.arguments().isEmpty()) {
358 return typeVariable.equals(ete.identifier());
359 }
360 return ete.arguments().stream().anyMatch(a -> usesTypeVariable(typeVariable, a));
361 }
362
363 static ExternalizedCodeType javaType(String typeVariable, ExternalizedCodeType ete) {
364 String javaIdentifier = switch (ete.identifier()) {
365 case "seq" -> "List";
366 case "sequence" -> "List";
367 case "map" -> "Map";
368 case "optional" -> "Optional";
369 case "tensor" -> "Tensor";
370 case "?" -> typeVariable;
371
372 default -> {
373 Tensor.ElementType elementType = Tensor.ElementType.fromOnnxName(ete.identifier());
374 Class<?> type = elementType.type();
375 if (type.isPrimitive()) {
376 yield toBoxType(type).getSimpleName();
377 } else {
378 yield type.getSimpleName();
379 }
380 }
381 };
382
383 if (ete.identifier().equals("map") &&
384 ete.arguments().stream().allMatch(t -> t.identifier().equals("?"))) {
385 return new ExternalizedCodeType(javaIdentifier,
386 ete.arguments().stream().map(c -> javaType("?", c)).toList());
387 }
388
389 return new ExternalizedCodeType(javaIdentifier,
390 ete.arguments().stream().map(c -> javaType(typeVariable, c)).toList());
391 }
392
393 static Map<String, ExternalizedCodeType> typeConstraintMap(OpSchema s) {
394 return s.type_constraints().stream().collect(toMap(
395 tc -> tc.type_param_str(),
396 tc -> tc.allowed_type_strs().stream().map(OperatorGen::parseTypeString).reduce(OperatorGen::lub).orElseThrow()));
397 }
398
399 static ExternalizedCodeType lub(ExternalizedCodeType a,
400 ExternalizedCodeType b) {
401 if (!a.identifier().equals(b.identifier())) {
402 return new ExternalizedCodeType("?", List.of());
403 }
404
405 assert a.arguments().size() == b.arguments().size();
406
407 List<ExternalizedCodeType> children = new ArrayList<>();
408 for (int i = 0; i < a.arguments().size(); i++) {
409 children.add(lub(a.arguments().get(i), b.arguments().get(i)));
410 }
411
412 return new ExternalizedCodeType(a.identifier(), children);
413 }
414
415 static ExternalizedCodeType parseTypeString(String type_str) {
416 return ExternalizedCodeType.ofString(
417 type_str.replace('(', '<').replace(')', '>'));
418 }
419
420 static Class<?> toBoxType(Class<?> pc) {
421 if (pc == byte.class) {
422 return Byte.class;
423 } else if (pc == short.class) {
424 return Short.class;
425 } else if (pc == int.class) {
426 return Integer.class;
427 } else if (pc == long.class) {
428 return Long.class;
429 } else if (pc == float.class) {
430 return Float.class;
431 } else if (pc == double.class) {
432 return Double.class;
433 } else if (pc == boolean.class) {
434 return Boolean.class;
435 } else {
436 return pc;
437 }
438 }
439
440 public static void main(String[] args) throws Exception {
441 List<OpSchema> schemas = OpSchemaParser.parse(Path.of(
442 "opgen/onnx-schema.json"));
443 OperatorGen oprGen = new OperatorGen(schemas);
444
445 oprGen.genOpClass(Path.of("src/main/java/oracle/code/onnx"));
446 }
447 }