1 /*
  2  * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 
 26 package oracle.code.onnx.opgen;
 27 
 28 import jdk.incubator.code.extern.ExternalizedCodeType;
 29 import oracle.code.onnx.OpSchema;
 30 import oracle.code.onnx.Tensor;
 31 
 32 import java.io.*;
 33 import java.nio.file.Path;
 34 import java.util.*;
 35 import java.util.stream.Collectors;
 36 
 37 import static java.util.Comparator.comparing;
 38 import static java.util.stream.Collectors.*;
 39 
 40 public class OperatorGen {
 41 
 42     final SortedMap<String, SortedSet<OpSchema>> schemas;
 43 
 44     OperatorGen(List<OpSchema> schemas) {
 45         this.schemas = schemas.stream().collect(groupingBy(
 46                 OpSchema::name,
 47                 TreeMap::new,
 48                 toCollection(() -> new TreeSet<>(comparing(OpSchema::since_version).reversed())
 49                 )));
 50     }
 51 
    /** Package declared at the top of the generated source file. */
    static final String ONNX_PACKAGE = "oracle.code.onnx";
    /** Simple name of the generated class; also used as the base name of the generated {@code .java} file. */
    static final String ONNX_OPERATORS_CLASS = "OnnxOperators";
 54 
 55     void genOpClass(Path dir) throws IOException {
 56         OutputStreamWriter osw = new OutputStreamWriter(
 57                 new FileOutputStream(dir.resolve(ONNX_OPERATORS_CLASS + ".java").toFile()));
 58         genOpClass(osw);
 59     }
 60 
 61     void genOpClass(Writer w_) throws IOException {
 62         IndentWriter w = new IndentWriter(w_);
 63 
 64         w.write("""
 65                 /*
 66                  * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
 67                  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 68                  *
 69                  * This code is free software; you can redistribute it and/or modify it
 70                  * under the terms of the GNU General Public License version 2 only, as
 71                  * published by the Free Software Foundation.  Oracle designates this
 72                  * particular file as subject to the "Classpath" exception as provided
 73                  * by Oracle in the LICENSE file that accompanied this code.
 74                  *
 75                  * This code is distributed in the hope that it will be useful, but WITHOUT
 76                  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 77                  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 78                  * version 2 for more details (a copy is included in the LICENSE file that
 79                  * accompanied this code).
 80                  *
 81                  * You should have received a copy of the GNU General Public License version
 82                  * 2 along with this work; if not, write to the Free Software Foundation,
 83                  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 84                  *
 85                  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 86                  * or visit www.oracle.com if you need additional information or have any
 87                  * questions.
 88                  */
 89                 """);
 90         w.write("// Auto-generated from ONNX op schema\n");
 91         w.write("\n");
 92         w.write("package " + ONNX_PACKAGE + ";\n");
 93         w.write("\n");
 94         w.write("""
 95                 import oracle.code.onnx.ir.OnnxOps;
 96 
 97                 import java.util.Optional;
 98                 import java.util.List;
 99                 import java.util.Map;
100                 """);
101         w.write("\n");
102 
103         w.write("@SuppressWarnings({\"unchecked\", \"OptionalUsedAsFieldOrParameterType\"})\n");
104         w.write("public final class " + ONNX_OPERATORS_CLASS + " extends ExplicitOnnxOperators {\n");
105 
106         w.in();
107 
108         w.write("\n");
109         w.write("private " + ONNX_OPERATORS_CLASS + "() {}\n");
110         w.write("\n");
111 
112         for (OpSchema s : schemas.values().stream().map(SortedSet::getFirst).toList()) {
113             if (skip(s)) {
114                 System.out.println("Skipping " + s.name());
115                 continue;
116             }
117 
118             genMethod(w, s);
119             w.write("\n");
120         }
121         w.out();
122 
123         w.write("}\n");
124         w.flush();
125     }
126 
127     private boolean skip(OpSchema s) {
128         return s.attributes().stream().anyMatch(a ->
129                 a.type() == OpSchema.AttributeType.GRAPH ||
130                         a.type() == OpSchema.AttributeType.GRAPHS);
131     }
132 
133     private void genMethod(IndentWriter w, OpSchema s) throws IOException {
134         Map<String, ExternalizedCodeType> javaTypeConstraints = javaTypes(typeConstraintMap(s));
135         Set<String> javaTypeVariables = javaTypeVariables(javaTypeConstraints);
136 
137         boolean twoOrMoreResults = s.max_output() > 1 && s.outputs().size() > 1;
138         List<String> recordResultTypeVariables = new ArrayList<>();
139         if (twoOrMoreResults) {
140             for (OpSchema.TypeConstraintParam typeConstraint : s.type_constraints()) {
141                 if (s.outputs().stream().anyMatch(op -> typeConstraint.type_param_str().equals(op.type_str())) &&
142                         javaTypeVariables.contains(typeConstraint.type_param_str())) {
143                     recordResultTypeVariables.add(typeConstraint.type_param_str());
144                 }
145             }
146 
147             w.write("public record " + s.name() + "Result");
148             if (!recordResultTypeVariables.isEmpty()) {
149                 w.write(recordResultTypeVariables.stream().collect(Collectors.joining(", ", "<", ">")));
150             }
151             w.write("(");
152             boolean first = true;
153             for (OpSchema.FormalParameter outParam : s.outputs()) {
154                 if (!first) {
155                     w.write(", ");
156                 }
157 
158                 ExternalizedCodeType outputType = javaTypeConstraints
159                         .computeIfAbsent(outParam.type_str(),
160                                 ts -> javaType("?", parseTypeString(ts)));
161 
162                 w.write(outputType.toString());
163                 w.write(" " + outParam.name());
164 
165                 first = false;
166             }
167 
168             w.write(") { }\n");
169         }
170 
171         for (String dl : s.doc().split("\\n")) {
172             w.write("/// " + dl + "\n");
173         }
174         w.write("public static ");
175 
176         if (!s.type_constraints().isEmpty()) {
177             List<String> typeVariables = new ArrayList<>();
178             for (OpSchema.TypeConstraintParam typeConstraint : s.type_constraints()) {
179                 if (javaTypeVariables.contains(typeConstraint.type_param_str())) {
180                     typeVariables.add(typeConstraint.type_param_str());
181                 }
182             }
183 
184             if (!typeVariables.isEmpty()) {
185                 w.write(typeVariables.stream().collect(Collectors.joining(", ", "<", ">")));
186                 w.write(" ");
187             }
188         }
189 
190         // @@@ Multiple output parameters - need to return tuple/record
191         final ExternalizedCodeType outputType;
192         if (s.min_output() == 1 && s.max_output() == 1) {
193             OpSchema.FormalParameter outParam = s.outputs().getFirst();
194 
195             outputType = javaTypeConstraints.computeIfAbsent(outParam.type_str(),
196                     ts -> javaType("?", parseTypeString(ts)));
197             w.write(outputType.toString());
198         } else if (s.min_output() == 0 && s.max_output() == 1) {
199             // This does not occur
200             throw new UnsupportedOperationException();
201         } else if (s.outputs().size() == 1) {
202             OpSchema.FormalParameter outParam = s.outputs().getFirst();
203             assert outParam.option() == OpSchema.FormalParameterOption.Variadic;
204 
205             outputType = new ExternalizedCodeType("List",
206                     List.of(javaTypeConstraints.computeIfAbsent(outParam.type_str(),
207                             ts -> javaType("?", parseTypeString(ts)))));
208             w.write(outputType.toString());
209         } else {
210             assert twoOrMoreResults;
211 
212             outputType = new ExternalizedCodeType(s.name() + "Result", List.of());
213             w.write(outputType.toString());
214             if (!recordResultTypeVariables.isEmpty()) {
215                 w.write(recordResultTypeVariables.stream().collect(Collectors.joining(", ", "<", ">")));
216             }
217         }
218         w.write(" ");
219         w.write(s.name() + "(");
220 
221         boolean first = true;
222         for (OpSchema.FormalParameter inParam : s.inputs()) {
223             if (!first) {
224                 w.write(", ");
225             }
226 
227             final ExternalizedCodeType inputType = javaTypeConstraints
228                     .computeIfAbsent(inParam.type_str(),
229                             ts -> javaType("?", parseTypeString(ts)));
230             switch (inParam.option()) {
231                 case Single -> {
232                     w.write(inputType.toString());
233                 }
234                 case Optional -> {
235                     w.write("Optional<");
236                     w.write(inputType.toString());
237                     w.write(">");
238                 }
239                 case Variadic -> {
240                     w.write("List<");
241                     w.write(inputType.toString());
242                     w.write(">");
243                 }
244             }
245             w.write(" ");
246             w.write(inParam.name());
247 
248             first = false;
249         }
250 
251         for (OpSchema.Attribute attribute : s.attributes()) {
252             if (!first) {
253                 w.write(", ");
254             }
255 
256             OpSchema.AttributeType aType = attribute.type();
257             String typeString = switch (aType) {
258                 default -> {
259                     if (attribute.required()) {
260                         yield aType.type().getSimpleName();
261                     } else {
262                         yield toBoxType(aType.type()).getSimpleName();
263                     }
264                 }
265             };
266             if (attribute.required()) {
267                 w.write(typeString);
268             } else {
269                 w.write("Optional<");
270                 w.write(typeString);
271                 w.write(">");
272             }
273             w.write(" ");
274             w.write(attribute.name());
275 
276             first = false;
277         }
278 
279         w.write(") {\n");
280         w.in();
281 
282         w.write("Object result = OnnxInterpreter.interpret(");
283         w.write("OnnxOps." + s.name() + ".class");
284         w.write(", ");
285 
286         w.write("List.of(");
287         first = true;
288         for (OpSchema.FormalParameter inParam : s.inputs()) {
289             if (!first) {
290                 w.write(", ");
291             }
292 
293             w.write(inParam.name());
294             first = false;
295         }
296         w.write(")");
297         w.write(", ");
298 
299         w.write("List.of(");
300         first = true;
301         for (OpSchema.Attribute attribute : s.attributes()) {
302             if (!first) {
303                 w.write(", ");
304             }
305 
306             w.write(attribute.name());
307             first = false;
308         }
309         w.write(")");
310         w.write(");\n");
311 
312         if (twoOrMoreResults) {
313             w.write("Object[] resultArray = (Object[]) result;\n");
314             w.write("return new " + s.name() + "Result");
315             if (!recordResultTypeVariables.isEmpty()) {
316                 w.write("<>");
317             }
318             w.write("(");
319             first = true;
320             for (int i = 0; i < s.outputs().size(); i++) {
321                 if (!first) {
322                     w.write(", ");
323                 }
324 
325                 w.write("(");
326                 //
327                 final ExternalizedCodeType t = javaTypeConstraints
328                         .computeIfAbsent(s.outputs().get(i).type_str(),
329                                 ts -> javaType("?", parseTypeString(ts)));
330                 w.write(t.toString());
331                 w.write(")");
332                 w.write("resultArray[" + i + "]");
333                 first = false;
334             }
335             w.write(");\n");
336         } else {
337             w.write("return (" + outputType + ") result;\n");
338         }
339         w.out();
340         w.write("}\n");
341     }
342 
343     static Map<String, ExternalizedCodeType> javaTypes(Map<String, ExternalizedCodeType> tcm) {
344         return tcm.entrySet().stream().collect(Collectors.toMap(
345                 e -> e.getKey(),
346                 e -> javaType(e.getKey(), e.getValue())));
347     }
348 
349     static Set<String> javaTypeVariables(Map<String, ExternalizedCodeType> tcm) {
350         return tcm.entrySet().stream()
351                 .filter(e -> usesTypeVariable(e.getKey(), e.getValue()))
352                 .map(e -> e.getKey())
353                 .collect(Collectors.toSet());
354     }
355 
356     static boolean usesTypeVariable(String typeVariable, ExternalizedCodeType ete) {
357         if (ete.arguments().isEmpty()) {
358             return typeVariable.equals(ete.identifier());
359         }
360         return ete.arguments().stream().anyMatch(a -> usesTypeVariable(typeVariable, a));
361     }
362 
363     static ExternalizedCodeType javaType(String typeVariable, ExternalizedCodeType ete) {
364         String javaIdentifier = switch (ete.identifier()) {
365             case "seq" -> "List";
366             case "sequence" -> "List";
367             case "map" -> "Map";
368             case "optional" -> "Optional";
369             case "tensor" -> "Tensor";
370             case "?" -> typeVariable;
371 
372             default -> {
373                 Tensor.ElementType elementType = Tensor.ElementType.fromOnnxName(ete.identifier());
374                 Class<?> type = elementType.type();
375                 if (type.isPrimitive()) {
376                     yield toBoxType(type).getSimpleName();
377                 } else {
378                     yield type.getSimpleName();
379                 }
380             }
381         };
382 
383         if (ete.identifier().equals("map") &&
384                 ete.arguments().stream().allMatch(t -> t.identifier().equals("?"))) {
385             return new ExternalizedCodeType(javaIdentifier,
386                     ete.arguments().stream().map(c -> javaType("?", c)).toList());
387         }
388 
389         return new ExternalizedCodeType(javaIdentifier,
390                 ete.arguments().stream().map(c -> javaType(typeVariable, c)).toList());
391     }
392 
393     static Map<String, ExternalizedCodeType> typeConstraintMap(OpSchema s) {
394         return s.type_constraints().stream().collect(toMap(
395                 tc -> tc.type_param_str(),
396                 tc -> tc.allowed_type_strs().stream().map(OperatorGen::parseTypeString).reduce(OperatorGen::lub).orElseThrow()));
397     }
398 
399     static ExternalizedCodeType lub(ExternalizedCodeType a,
400                                                    ExternalizedCodeType b) {
401         if (!a.identifier().equals(b.identifier())) {
402             return new ExternalizedCodeType("?", List.of());
403         }
404 
405         assert a.arguments().size() == b.arguments().size();
406 
407         List<ExternalizedCodeType> children = new ArrayList<>();
408         for (int i = 0; i < a.arguments().size(); i++) {
409             children.add(lub(a.arguments().get(i), b.arguments().get(i)));
410         }
411 
412         return new ExternalizedCodeType(a.identifier(), children);
413     }
414 
415     static ExternalizedCodeType parseTypeString(String type_str) {
416         return ExternalizedCodeType.ofString(
417                 type_str.replace('(', '<').replace(')', '>'));
418     }
419 
420     static Class<?> toBoxType(Class<?> pc) {
421         if (pc == byte.class) {
422             return Byte.class;
423         } else if (pc == short.class) {
424             return Short.class;
425         } else if (pc == int.class) {
426             return Integer.class;
427         } else if (pc == long.class) {
428             return Long.class;
429         } else if (pc == float.class) {
430             return Float.class;
431         } else if (pc == double.class) {
432             return Double.class;
433         } else if (pc == boolean.class) {
434             return Boolean.class;
435         } else {
436             return pc;
437         }
438     }
439 
440     public static void main(String[] args) throws Exception {
441         List<OpSchema> schemas = OpSchemaParser.parse(Path.of(
442                 "opgen/onnx-schema.json"));
443         OperatorGen oprGen = new OperatorGen(schemas);
444 
445         oprGen.genOpClass(Path.of("src/main/java/oracle/code/onnx"));
446     }
447 }