1 /*
2 * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25
26 package oracle.code.onnx.opgen;
27
28 import jdk.incubator.code.extern.ExternalizedTypeElement;
29 import oracle.code.onnx.OpSchema;
30 import oracle.code.onnx.Tensor;
31
import java.io.*;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.*;
import java.util.stream.Collectors;

import static java.util.Comparator.comparing;
import static java.util.stream.Collectors.*;
39
40 public class OperatorGen {
41
42 final SortedMap<String, SortedSet<OpSchema>> schemas;
43
44 OperatorGen(List<OpSchema> schemas) {
45 this.schemas = schemas.stream().collect(groupingBy(
46 OpSchema::name,
47 TreeMap::new,
48 toCollection(() -> new TreeSet<>(comparing(OpSchema::since_version).reversed())
49 )));
50 }
51
52 static final String ONNX_PACKAGE = "oracle.code.onnx";
53 static final String ONNX_OPERATORS_CLASS = "OnnxOperators";
54
55 void genOpClass(Path dir) throws IOException {
56 OutputStreamWriter osw = new OutputStreamWriter(
57 new FileOutputStream(dir.resolve(ONNX_OPERATORS_CLASS + ".java").toFile()));
58 genOpClass(osw);
59 }
60
61 void genOpClass(Writer w_) throws IOException {
62 IndentWriter w = new IndentWriter(w_);
63
64 w.write("""
65 /*
66 * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
67 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
68 *
69 * This code is free software; you can redistribute it and/or modify it
70 * under the terms of the GNU General Public License version 2 only, as
71 * published by the Free Software Foundation. Oracle designates this
72 * particular file as subject to the "Classpath" exception as provided
73 * by Oracle in the LICENSE file that accompanied this code.
74 *
75 * This code is distributed in the hope that it will be useful, but WITHOUT
76 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
77 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
78 * version 2 for more details (a copy is included in the LICENSE file that
79 * accompanied this code).
80 *
81 * You should have received a copy of the GNU General Public License version
82 * 2 along with this work; if not, write to the Free Software Foundation,
83 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
84 *
85 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
86 * or visit www.oracle.com if you need additional information or have any
87 * questions.
88 */
89 """);
90 w.write("// Auto-generated from ONNX op schema\n");
91 w.write("\n");
92 w.write("package " + ONNX_PACKAGE + ";\n");
93 w.write("\n");
94 w.write("""
95 import oracle.code.onnx.ir.OnnxOps;
96
97 import java.util.Optional;
98 import java.util.List;
99 import java.util.Map;
100 """);
101 w.write("\n");
102
103 w.write("@SuppressWarnings({\"unchecked\", \"OptionalUsedAsFieldOrParameterType\"})\n");
104 w.write("public final class " + ONNX_OPERATORS_CLASS + " extends ExplicitOnnxOperators {\n");
105
106 w.in();
107
108 w.write("\n");
109 w.write("private " + ONNX_OPERATORS_CLASS + "() {}\n");
110 w.write("\n");
111
112 for (OpSchema s : schemas.values().stream().map(SortedSet::getFirst).toList()) {
113 if (skip(s)) {
114 System.out.println("Skipping " + s.name());
115 continue;
116 }
117
118 genMethod(w, s);
119 w.write("\n");
120 }
121 w.out();
122
123 w.write("}\n");
124 w.flush();
125 }
126
127 private boolean skip(OpSchema s) {
128 return s.attributes().stream().anyMatch(a ->
129 a.type() == OpSchema.AttributeType.GRAPH ||
130 a.type() == OpSchema.AttributeType.GRAPHS);
131 }
132
133 private void genMethod(IndentWriter w, OpSchema s) throws IOException {
134 Map<String, ExternalizedTypeElement> javaTypeConstraints = javaTypes(typeConstraintMap(s));
135 Set<String> javaTypeVariables = javaTypeVariables(javaTypeConstraints);
136
137 boolean twoOrMoreResults = s.max_output() > 1 && s.outputs().size() > 1;
138 List<String> recordResultTypeVariables = new ArrayList<>();
139 if (twoOrMoreResults) {
140 for (OpSchema.TypeConstraintParam typeConstraint : s.type_constraints()) {
141 if (s.outputs().stream().anyMatch(op -> typeConstraint.type_param_str().equals(op.type_str())) &&
142 javaTypeVariables.contains(typeConstraint.type_param_str())) {
143 recordResultTypeVariables.add(typeConstraint.type_param_str());
144 }
145 }
146
147 w.write("public record " + s.name() + "Result");
148 if (!recordResultTypeVariables.isEmpty()) {
149 w.write(recordResultTypeVariables.stream().collect(Collectors.joining(", ", "<", ">")));
150 }
151 w.write("(");
152 boolean first = true;
153 for (OpSchema.FormalParameter outParam : s.outputs()) {
154 if (!first) {
155 w.write(", ");
156 }
157
158 ExternalizedTypeElement outputType = javaTypeConstraints
159 .computeIfAbsent(outParam.type_str(),
160 ts -> javaType("?", parseTypeString(ts)));
161
162 w.write(outputType.toString());
163 w.write(" " + outParam.name());
164
165 first = false;
166 }
167
168 w.write(") { }\n");
169 }
170
171 w.write("public static ");
172
173 if (!s.type_constraints().isEmpty()) {
174 List<String> typeVariables = new ArrayList<>();
175 for (OpSchema.TypeConstraintParam typeConstraint : s.type_constraints()) {
176 if (javaTypeVariables.contains(typeConstraint.type_param_str())) {
177 typeVariables.add(typeConstraint.type_param_str());
178 }
179 }
180
181 if (!typeVariables.isEmpty()) {
182 w.write(typeVariables.stream().collect(Collectors.joining(", ", "<", ">")));
183 w.write(" ");
184 }
185 }
186
187 // @@@ Multiple output parameters - need to return tuple/record
188 final ExternalizedTypeElement outputType;
189 if (s.min_output() == 1 && s.max_output() == 1) {
190 OpSchema.FormalParameter outParam = s.outputs().getFirst();
191
192 outputType = javaTypeConstraints.computeIfAbsent(outParam.type_str(),
193 ts -> javaType("?", parseTypeString(ts)));
194 w.write(outputType.toString());
195 } else if (s.min_output() == 0 && s.max_output() == 1) {
196 // This does not occur
197 throw new UnsupportedOperationException();
198 } else if (s.outputs().size() == 1) {
199 OpSchema.FormalParameter outParam = s.outputs().getFirst();
200 assert outParam.option() == OpSchema.FormalParameterOption.Variadic;
201
202 outputType = new ExternalizedTypeElement("List",
203 List.of(javaTypeConstraints.computeIfAbsent(outParam.type_str(),
204 ts -> javaType("?", parseTypeString(ts)))));
205 w.write(outputType.toString());
206 } else {
207 assert twoOrMoreResults;
208
209 outputType = new ExternalizedTypeElement(s.name() + "Result", List.of());
210 w.write(outputType.toString());
211 if (!recordResultTypeVariables.isEmpty()) {
212 w.write(recordResultTypeVariables.stream().collect(Collectors.joining(", ", "<", ">")));
213 }
214 }
215 w.write(" ");
216 w.write(s.name() + "(");
217
218 boolean first = true;
219 for (OpSchema.FormalParameter inParam : s.inputs()) {
220 if (!first) {
221 w.write(", ");
222 }
223
224 final ExternalizedTypeElement inputType = javaTypeConstraints
225 .computeIfAbsent(inParam.type_str(),
226 ts -> javaType("?", parseTypeString(ts)));
227 switch (inParam.option()) {
228 case Single -> {
229 w.write(inputType.toString());
230 }
231 case Optional -> {
232 w.write("Optional<");
233 w.write(inputType.toString());
234 w.write(">");
235 }
236 case Variadic -> {
237 w.write("List<");
238 w.write(inputType.toString());
239 w.write(">");
240 }
241 }
242 w.write(" ");
243 w.write(inParam.name());
244
245 first = false;
246 }
247
248 for (OpSchema.Attribute attribute : s.attributes()) {
249 if (!first) {
250 w.write(", ");
251 }
252
253 OpSchema.AttributeType aType = attribute.type();
254 String typeString = switch (aType) {
255 default -> {
256 if (attribute.required()) {
257 yield aType.type().getSimpleName();
258 } else {
259 yield toBoxType(aType.type()).getSimpleName();
260 }
261 }
262 };
263 if (attribute.required()) {
264 w.write(typeString);
265 } else {
266 w.write("Optional<");
267 w.write(typeString);
268 w.write(">");
269 }
270 w.write(" ");
271 w.write(attribute.name());
272
273 first = false;
274 }
275
276 w.write(") {\n");
277 w.in();
278
279 w.write("Object result = OnnxInterpreter.interpret(");
280 w.write("OnnxOps." + s.name() + ".class");
281 w.write(", ");
282
283 w.write("List.of(");
284 first = true;
285 for (OpSchema.FormalParameter inParam : s.inputs()) {
286 if (!first) {
287 w.write(", ");
288 }
289
290 w.write(inParam.name());
291 first = false;
292 }
293 w.write(")");
294 w.write(", ");
295
296 w.write("List.of(");
297 first = true;
298 for (OpSchema.Attribute attribute : s.attributes()) {
299 if (!first) {
300 w.write(", ");
301 }
302
303 w.write(attribute.name());
304 first = false;
305 }
306 w.write(")");
307 w.write(");\n");
308
309 if (twoOrMoreResults) {
310 w.write("Object[] resultArray = (Object[]) result;\n");
311 w.write("return new " + s.name() + "Result");
312 if (!recordResultTypeVariables.isEmpty()) {
313 w.write("<>");
314 }
315 w.write("(");
316 first = true;
317 for (int i = 0; i < s.outputs().size(); i++) {
318 if (!first) {
319 w.write(", ");
320 }
321
322 w.write("(");
323 //
324 final ExternalizedTypeElement t = javaTypeConstraints
325 .computeIfAbsent(s.outputs().get(i).type_str(),
326 ts -> javaType("?", parseTypeString(ts)));
327 w.write(t.toString());
328 w.write(")");
329 w.write("resultArray[" + i + "]");
330 first = false;
331 }
332 w.write(");\n");
333 } else {
334 w.write("return (" + outputType + ") result;\n");
335 }
336 w.out();
337 w.write("}\n");
338 }
339
340 static Map<String, ExternalizedTypeElement> javaTypes(Map<String, ExternalizedTypeElement> tcm) {
341 return tcm.entrySet().stream().collect(Collectors.toMap(
342 e -> e.getKey(),
343 e -> javaType(e.getKey(), e.getValue())));
344 }
345
346 static Set<String> javaTypeVariables(Map<String, ExternalizedTypeElement> tcm) {
347 return tcm.entrySet().stream()
348 .filter(e -> usesTypeVariable(e.getKey(), e.getValue()))
349 .map(e -> e.getKey())
350 .collect(Collectors.toSet());
351 }
352
353 static boolean usesTypeVariable(String typeVariable, ExternalizedTypeElement ete) {
354 if (ete.arguments().isEmpty()) {
355 return typeVariable.equals(ete.identifier());
356 }
357 return ete.arguments().stream().anyMatch(a -> usesTypeVariable(typeVariable, a));
358 }
359
360 static ExternalizedTypeElement javaType(String typeVariable, ExternalizedTypeElement ete) {
361 String javaIdentifier = switch (ete.identifier()) {
362 case "seq" -> "List";
363 case "sequence" -> "List";
364 case "map" -> "Map";
365 case "optional" -> "Optional";
366 case "tensor" -> "Tensor";
367 case "?" -> typeVariable;
368
369 default -> {
370 Tensor.ElementType elementType = Tensor.ElementType.fromOnnxName(ete.identifier());
371 Class<?> type = elementType.type();
372 if (type.isPrimitive()) {
373 yield toBoxType(type).getSimpleName();
374 } else {
375 yield type.getSimpleName();
376 }
377 }
378 };
379
380 if (ete.identifier().equals("map") &&
381 ete.arguments().stream().allMatch(t -> t.identifier().equals("?"))) {
382 return new ExternalizedTypeElement(javaIdentifier,
383 ete.arguments().stream().map(c -> javaType("?", c)).toList());
384 }
385
386 return new ExternalizedTypeElement(javaIdentifier,
387 ete.arguments().stream().map(c -> javaType(typeVariable, c)).toList());
388 }
389
390 static Map<String, ExternalizedTypeElement> typeConstraintMap(OpSchema s) {
391 return s.type_constraints().stream().collect(toMap(
392 tc -> tc.type_param_str(),
393 tc -> tc.allowed_type_strs().stream().map(OperatorGen::parseTypeString).reduce(OperatorGen::lub).orElseThrow()));
394 }
395
396 static ExternalizedTypeElement lub(ExternalizedTypeElement a,
397 ExternalizedTypeElement b) {
398 if (!a.identifier().equals(b.identifier())) {
399 return new ExternalizedTypeElement("?", List.of());
400 }
401
402 assert a.arguments().size() == b.arguments().size();
403
404 List<ExternalizedTypeElement> children = new ArrayList<>();
405 for (int i = 0; i < a.arguments().size(); i++) {
406 children.add(lub(a.arguments().get(i), b.arguments().get(i)));
407 }
408
409 return new ExternalizedTypeElement(a.identifier(), children);
410 }
411
412 static ExternalizedTypeElement parseTypeString(String type_str) {
413 return ExternalizedTypeElement.ofString(
414 type_str.replace('(', '<').replace(')', '>'));
415 }
416
417 static Class<?> toBoxType(Class<?> pc) {
418 if (pc == byte.class) {
419 return Byte.class;
420 } else if (pc == short.class) {
421 return Short.class;
422 } else if (pc == int.class) {
423 return Integer.class;
424 } else if (pc == long.class) {
425 return Long.class;
426 } else if (pc == float.class) {
427 return Float.class;
428 } else if (pc == double.class) {
429 return Double.class;
430 } else if (pc == boolean.class) {
431 return Boolean.class;
432 } else {
433 return pc;
434 }
435 }
436
437 public static void main(String[] args) throws Exception {
438 List<OpSchema> schemas = OpSchemaParser.parse(Path.of(
439 "opgen/onnx-schema.json"));
440 OperatorGen oprGen = new OperatorGen(schemas);
441
442 oprGen.genOpClass(Path.of("src/main/java/oracle/code/onnx"));
443 }
444 }