1 /*
  2  * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 
 26 package oracle.code.onnx.opgen;
 27 
 28 import jdk.incubator.code.extern.ExternalizedTypeElement;
 29 import oracle.code.onnx.OpSchema;
 30 import oracle.code.onnx.Tensor;
 31 
 32 import java.io.*;
 33 import java.nio.file.Path;
 34 import java.util.*;
 35 import java.util.stream.Collectors;
 36 
 37 import static java.util.Comparator.comparing;
 38 import static java.util.stream.Collectors.*;
 39 
 40 public class OperatorGen {
 41 
 42     final SortedMap<String, SortedSet<OpSchema>> schemas;
 43 
 44     OperatorGen(List<OpSchema> schemas) {
 45         this.schemas = schemas.stream().collect(groupingBy(
 46                 OpSchema::name,
 47                 TreeMap::new,
 48                 toCollection(() -> new TreeSet<>(comparing(OpSchema::since_version).reversed())
 49                 )));
 50     }
 51 
 52     static final String ONNX_PACKAGE = "oracle.code.onnx";
 53     static final String ONNX_OPERATORS_CLASS = "OnnxOperators";
 54 
 55     void genOpClass(Path dir) throws IOException {
 56         OutputStreamWriter osw = new OutputStreamWriter(
 57                 new FileOutputStream(dir.resolve(ONNX_OPERATORS_CLASS + ".java").toFile()));
 58         genOpClass(osw);
 59     }
 60 
 61     void genOpClass(Writer w_) throws IOException {
 62         IndentWriter w = new IndentWriter(w_);
 63 
 64         w.write("""
 65                 /*
 66                  * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
 67                  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 68                  *
 69                  * This code is free software; you can redistribute it and/or modify it
 70                  * under the terms of the GNU General Public License version 2 only, as
 71                  * published by the Free Software Foundation.  Oracle designates this
 72                  * particular file as subject to the "Classpath" exception as provided
 73                  * by Oracle in the LICENSE file that accompanied this code.
 74                  *
 75                  * This code is distributed in the hope that it will be useful, but WITHOUT
 76                  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 77                  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 78                  * version 2 for more details (a copy is included in the LICENSE file that
 79                  * accompanied this code).
 80                  *
 81                  * You should have received a copy of the GNU General Public License version
 82                  * 2 along with this work; if not, write to the Free Software Foundation,
 83                  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 84                  *
 85                  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 86                  * or visit www.oracle.com if you need additional information or have any
 87                  * questions.
 88                  */
 89                 """);
 90         w.write("// Auto-generated from ONNX op schema\n");
 91         w.write("\n");
 92         w.write("package " + ONNX_PACKAGE + ";\n");
 93         w.write("\n");
 94         w.write("""
 95                 import oracle.code.onnx.ir.OnnxOps;
 96 
 97                 import java.util.Optional;
 98                 import java.util.List;
 99                 import java.util.Map;
100                 """);
101         w.write("\n");
102 
103         w.write("@SuppressWarnings({\"unchecked\", \"OptionalUsedAsFieldOrParameterType\"})\n");
104         w.write("public final class " + ONNX_OPERATORS_CLASS + " extends ExplicitOnnxOperators {\n");
105 
106         w.in();
107 
108         w.write("\n");
109         w.write("private " + ONNX_OPERATORS_CLASS + "() {}\n");
110         w.write("\n");
111 
112         for (OpSchema s : schemas.values().stream().map(SortedSet::getFirst).toList()) {
113             if (skip(s)) {
114                 System.out.println("Skipping " + s.name());
115                 continue;
116             }
117 
118             genMethod(w, s);
119             w.write("\n");
120         }
121         w.out();
122 
123         w.write("}\n");
124         w.flush();
125     }
126 
127     private boolean skip(OpSchema s) {
128         return s.attributes().stream().anyMatch(a ->
129                 a.type() == OpSchema.AttributeType.GRAPH ||
130                         a.type() == OpSchema.AttributeType.GRAPHS);
131     }
132 
133     private void genMethod(IndentWriter w, OpSchema s) throws IOException {
134         Map<String, ExternalizedTypeElement> javaTypeConstraints = javaTypes(typeConstraintMap(s));
135         Set<String> javaTypeVariables = javaTypeVariables(javaTypeConstraints);
136 
137         boolean twoOrMoreResults = s.max_output() > 1 && s.outputs().size() > 1;
138         List<String> recordResultTypeVariables = new ArrayList<>();
139         if (twoOrMoreResults) {
140             for (OpSchema.TypeConstraintParam typeConstraint : s.type_constraints()) {
141                 if (s.outputs().stream().anyMatch(op -> typeConstraint.type_param_str().equals(op.type_str())) &&
142                         javaTypeVariables.contains(typeConstraint.type_param_str())) {
143                     recordResultTypeVariables.add(typeConstraint.type_param_str());
144                 }
145             }
146 
147             w.write("public record " + s.name() + "Result");
148             if (!recordResultTypeVariables.isEmpty()) {
149                 w.write(recordResultTypeVariables.stream().collect(Collectors.joining(", ", "<", ">")));
150             }
151             w.write("(");
152             boolean first = true;
153             for (OpSchema.FormalParameter outParam : s.outputs()) {
154                 if (!first) {
155                     w.write(", ");
156                 }
157 
158                 ExternalizedTypeElement outputType = javaTypeConstraints
159                         .computeIfAbsent(outParam.type_str(),
160                                 ts -> javaType("?", parseTypeString(ts)));
161 
162                 w.write(outputType.toString());
163                 w.write(" " + outParam.name());
164 
165                 first = false;
166             }
167 
168             w.write(") { }\n");
169         }
170 
171         w.write("public static ");
172 
173         if (!s.type_constraints().isEmpty()) {
174             List<String> typeVariables = new ArrayList<>();
175             for (OpSchema.TypeConstraintParam typeConstraint : s.type_constraints()) {
176                 if (javaTypeVariables.contains(typeConstraint.type_param_str())) {
177                     typeVariables.add(typeConstraint.type_param_str());
178                 }
179             }
180 
181             if (!typeVariables.isEmpty()) {
182                 w.write(typeVariables.stream().collect(Collectors.joining(", ", "<", ">")));
183                 w.write(" ");
184             }
185         }
186 
187         // @@@ Multiple output parameters - need to return tuple/record
188         final ExternalizedTypeElement outputType;
189         if (s.min_output() == 1 && s.max_output() == 1) {
190             OpSchema.FormalParameter outParam = s.outputs().getFirst();
191 
192             outputType = javaTypeConstraints.computeIfAbsent(outParam.type_str(),
193                     ts -> javaType("?", parseTypeString(ts)));
194             w.write(outputType.toString());
195         } else if (s.min_output() == 0 && s.max_output() == 1) {
196             // This does not occur
197             throw new UnsupportedOperationException();
198         } else if (s.outputs().size() == 1) {
199             OpSchema.FormalParameter outParam = s.outputs().getFirst();
200             assert outParam.option() == OpSchema.FormalParameterOption.Variadic;
201 
202             outputType = new ExternalizedTypeElement("List",
203                     List.of(javaTypeConstraints.computeIfAbsent(outParam.type_str(),
204                             ts -> javaType("?", parseTypeString(ts)))));
205             w.write(outputType.toString());
206         } else {
207             assert twoOrMoreResults;
208 
209             outputType = new ExternalizedTypeElement(s.name() + "Result", List.of());
210             w.write(outputType.toString());
211             if (!recordResultTypeVariables.isEmpty()) {
212                 w.write(recordResultTypeVariables.stream().collect(Collectors.joining(", ", "<", ">")));
213             }
214         }
215         w.write(" ");
216         w.write(s.name() + "(");
217 
218         boolean first = true;
219         for (OpSchema.FormalParameter inParam : s.inputs()) {
220             if (!first) {
221                 w.write(", ");
222             }
223 
224             final ExternalizedTypeElement inputType = javaTypeConstraints
225                     .computeIfAbsent(inParam.type_str(),
226                             ts -> javaType("?", parseTypeString(ts)));
227             switch (inParam.option()) {
228                 case Single -> {
229                     w.write(inputType.toString());
230                 }
231                 case Optional -> {
232                     w.write("Optional<");
233                     w.write(inputType.toString());
234                     w.write(">");
235                 }
236                 case Variadic -> {
237                     w.write("List<");
238                     w.write(inputType.toString());
239                     w.write(">");
240                 }
241             }
242             w.write(" ");
243             w.write(inParam.name());
244 
245             first = false;
246         }
247 
248         for (OpSchema.Attribute attribute : s.attributes()) {
249             if (!first) {
250                 w.write(", ");
251             }
252 
253             OpSchema.AttributeType aType = attribute.type();
254             String typeString = switch (aType) {
255                 default -> {
256                     if (attribute.required()) {
257                         yield aType.type().getSimpleName();
258                     } else {
259                         yield toBoxType(aType.type()).getSimpleName();
260                     }
261                 }
262             };
263             if (attribute.required()) {
264                 w.write(typeString);
265             } else {
266                 w.write("Optional<");
267                 w.write(typeString);
268                 w.write(">");
269             }
270             w.write(" ");
271             w.write(attribute.name());
272 
273             first = false;
274         }
275 
276         w.write(") {\n");
277         w.in();
278 
279         w.write("Object result = OnnxInterpreter.interpret(");
280         w.write("OnnxOps." + s.name() + ".class");
281         w.write(", ");
282 
283         w.write("List.of(");
284         first = true;
285         for (OpSchema.FormalParameter inParam : s.inputs()) {
286             if (!first) {
287                 w.write(", ");
288             }
289 
290             w.write(inParam.name());
291             first = false;
292         }
293         w.write(")");
294         w.write(", ");
295 
296         w.write("List.of(");
297         first = true;
298         for (OpSchema.Attribute attribute : s.attributes()) {
299             if (!first) {
300                 w.write(", ");
301             }
302 
303             w.write(attribute.name());
304             first = false;
305         }
306         w.write(")");
307         w.write(");\n");
308 
309         if (twoOrMoreResults) {
310             w.write("Object[] resultArray = (Object[]) result;\n");
311             w.write("return new " + s.name() + "Result");
312             if (!recordResultTypeVariables.isEmpty()) {
313                 w.write("<>");
314             }
315             w.write("(");
316             first = true;
317             for (int i = 0; i < s.outputs().size(); i++) {
318                 if (!first) {
319                     w.write(", ");
320                 }
321 
322                 w.write("(");
323                 //
324                 final ExternalizedTypeElement t = javaTypeConstraints
325                         .computeIfAbsent(s.outputs().get(i).type_str(),
326                                 ts -> javaType("?", parseTypeString(ts)));
327                 w.write(t.toString());
328                 w.write(")");
329                 w.write("resultArray[" + i + "]");
330                 first = false;
331             }
332             w.write(");\n");
333         } else {
334             w.write("return (" + outputType + ") result;\n");
335         }
336         w.out();
337         w.write("}\n");
338     }
339 
340     static Map<String, ExternalizedTypeElement> javaTypes(Map<String, ExternalizedTypeElement> tcm) {
341         return tcm.entrySet().stream().collect(Collectors.toMap(
342                 e -> e.getKey(),
343                 e -> javaType(e.getKey(), e.getValue())));
344     }
345 
346     static Set<String> javaTypeVariables(Map<String, ExternalizedTypeElement> tcm) {
347         return tcm.entrySet().stream()
348                 .filter(e -> usesTypeVariable(e.getKey(), e.getValue()))
349                 .map(e -> e.getKey())
350                 .collect(Collectors.toSet());
351     }
352 
353     static boolean usesTypeVariable(String typeVariable, ExternalizedTypeElement ete) {
354         if (ete.arguments().isEmpty()) {
355             return typeVariable.equals(ete.identifier());
356         }
357         return ete.arguments().stream().anyMatch(a -> usesTypeVariable(typeVariable, a));
358     }
359 
360     static ExternalizedTypeElement javaType(String typeVariable, ExternalizedTypeElement ete) {
361         String javaIdentifier = switch (ete.identifier()) {
362             case "seq" -> "List";
363             case "sequence" -> "List";
364             case "map" -> "Map";
365             case "optional" -> "Optional";
366             case "tensor" -> "Tensor";
367             case "?" -> typeVariable;
368 
369             default -> {
370                 Tensor.ElementType elementType = Tensor.ElementType.fromOnnxName(ete.identifier());
371                 Class<?> type = elementType.type();
372                 if (type.isPrimitive()) {
373                     yield toBoxType(type).getSimpleName();
374                 } else {
375                     yield type.getSimpleName();
376                 }
377             }
378         };
379 
380         if (ete.identifier().equals("map") &&
381                 ete.arguments().stream().allMatch(t -> t.identifier().equals("?"))) {
382             return new ExternalizedTypeElement(javaIdentifier,
383                     ete.arguments().stream().map(c -> javaType("?", c)).toList());
384         }
385 
386         return new ExternalizedTypeElement(javaIdentifier,
387                 ete.arguments().stream().map(c -> javaType(typeVariable, c)).toList());
388     }
389 
390     static Map<String, ExternalizedTypeElement> typeConstraintMap(OpSchema s) {
391         return s.type_constraints().stream().collect(toMap(
392                 tc -> tc.type_param_str(),
393                 tc -> tc.allowed_type_strs().stream().map(OperatorGen::parseTypeString).reduce(OperatorGen::lub).orElseThrow()));
394     }
395 
396     static ExternalizedTypeElement lub(ExternalizedTypeElement a,
397                                                    ExternalizedTypeElement b) {
398         if (!a.identifier().equals(b.identifier())) {
399             return new ExternalizedTypeElement("?", List.of());
400         }
401 
402         assert a.arguments().size() == b.arguments().size();
403 
404         List<ExternalizedTypeElement> children = new ArrayList<>();
405         for (int i = 0; i < a.arguments().size(); i++) {
406             children.add(lub(a.arguments().get(i), b.arguments().get(i)));
407         }
408 
409         return new ExternalizedTypeElement(a.identifier(), children);
410     }
411 
412     static ExternalizedTypeElement parseTypeString(String type_str) {
413         return ExternalizedTypeElement.ofString(
414                 type_str.replace('(', '<').replace(')', '>'));
415     }
416 
417     static Class<?> toBoxType(Class<?> pc) {
418         if (pc == byte.class) {
419             return Byte.class;
420         } else if (pc == short.class) {
421             return Short.class;
422         } else if (pc == int.class) {
423             return Integer.class;
424         } else if (pc == long.class) {
425             return Long.class;
426         } else if (pc == float.class) {
427             return Float.class;
428         } else if (pc == double.class) {
429             return Double.class;
430         } else if (pc == boolean.class) {
431             return Boolean.class;
432         } else {
433             return pc;
434         }
435     }
436 
437     public static void main(String[] args) throws Exception {
438         List<OpSchema> schemas = OpSchemaParser.parse(Path.of(
439                 "opgen/onnx-schema.json"));
440         OperatorGen oprGen = new OperatorGen(schemas);
441 
442         oprGen.genOpClass(Path.of("src/main/java/oracle/code/onnx"));
443     }
444 }