1 /*
  2  * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 package oracle.code.onnx.proto;
 25 
 26 import java.io.InputStream;
 27 import java.io.RandomAccessFile;
 28 import java.lang.foreign.Arena;
 29 import java.lang.foreign.ValueLayout;
 30 import java.lang.reflect.AccessFlag;
 31 import java.lang.reflect.Field;
 32 import java.nio.channels.FileChannel;
 33 import java.util.ArrayList;
 34 import java.util.Arrays;
 35 import java.util.HashMap;
 36 import java.util.Iterator;
 37 import java.util.LinkedHashMap;
 38 import java.util.List;
 39 import java.util.Map;
 40 import java.util.function.Function;
 41 import jdk.incubator.code.Block;
 42 import jdk.incubator.code.CodeItem;
 43 import jdk.incubator.code.Op;
 44 import jdk.incubator.code.TypeElement;
 45 import jdk.incubator.code.Value;
 46 import jdk.incubator.code.extern.ExternalizedOp;
 47 import jdk.incubator.code.extern.OpFactory;
 48 import jdk.incubator.code.dialect.core.CoreOp;
 49 import jdk.incubator.code.dialect.core.CoreType;
 50 import jdk.incubator.code.dialect.core.FunctionType;
 51 import jdk.incubator.code.extern.OpWriter;
 52 import oracle.code.onnx.CNNTest;
 53 import oracle.code.onnx.OnnxRuntime;
 54 import oracle.code.onnx.Tensor;
 55 import oracle.code.onnx.ir.*;
 56 import org.junit.jupiter.api.Test;
 57 
 58 
 59 public class OnnxModelTest {
 60 
 61     static OnnxType toOnnxType(OnnxModel.TypeProto tp) {
 62         if (tp.tensorType() instanceof OnnxModel.TypeProto.Tensor t) {
 63             return toTensorType(t.elemType());
 64         } else if (tp.optionalType() instanceof OnnxModel.TypeProto.Optional o) {
 65             return OnnxType.optional(toOnnxType(o.elemType()));
 66         } else if (tp.sequenceType()  instanceof OnnxModel.TypeProto.Sequence s) {
 67             return OnnxType.seq(toOnnxType(s.elemType()));
 68         } else if (tp.mapType() instanceof OnnxModel.TypeProto.Map m) {
 69             return OnnxType.map(toKeyType(m.keyType()), toOnnxType(m.valueType()));
 70         } else if (tp.sparseTensorType() instanceof OnnxModel.TypeProto.SparseTensor st) {
 71             throw new UnsupportedOperationException("Sparse tensors not supported yet."); // @@@
 72         }
 73         throw new IllegalArgumentException("No type specified.");
 74     }
 75 
 76     static FunctionType toFunctionType(OnnxModel.GraphProto g) {
 77         var paramTypes = new ArrayList<TypeElement>();
 78         for (OnnxModel.ValueInfoProto input : g.input()) {
 79             paramTypes.add(toOnnxType(input.type()));
 80         }
 81         for (OnnxModel.TensorProto init : g.initializer()) {
 82             paramTypes.add(toTensorType(init.dataType()));
 83         }
 84         var returnType = g.output().size() == 1
 85                 ? toOnnxType(g.output().getFirst().type())
 86                 : CoreType.tupleType(g.output().stream().map(OnnxModel.ValueInfoProto::type).map(OnnxModelTest::toOnnxType).toList());
 87         return CoreType.functionType(returnType, paramTypes);
 88     }
 89 
 90     static OnnxType toKeyType(int kt) {
 91         return switch (kt) {
 92             case 2 -> OnnxType.UINT8;
 93             case 3 -> OnnxType.INT8;
 94             case 4 -> OnnxType.UINT16;
 95             case 5 -> OnnxType.INT16;
 96             case 6 -> OnnxType.INT32;
 97             case 7 -> OnnxType.INT64;
 98             case 8 -> OnnxType.STRING;
 99             case 12 -> OnnxType.UINT32;
100             case 13 -> OnnxType.UINT64;
101             default -> throw new IllegalArgumentException("Invalid key type: " + kt);
102         };
103     }
104 
105     static OnnxType.TensorType toTensorType(int tt) {
106         return switch (tt) {
107             case 1 -> OnnxType.TENSOR_FLOAT32;
108             case 2 -> OnnxType.TENSOR_UINT8;
109             case 3 -> OnnxType.TENSOR_INT8;
110             case 4 -> OnnxType.TENSOR_UINT16;
111             case 5 -> OnnxType.TENSOR_INT16;
112             case 6 -> OnnxType.TENSOR_INT32;
113             case 7 -> OnnxType.TENSOR_INT64;
114             case 8 -> OnnxType.TENSOR_STRING;
115             case 9 -> OnnxType.TENSOR_BOOL;
116             case 10 -> OnnxType.TENSOR_FLOAT16;
117             case 11 -> OnnxType.TENSOR_FLOAT64;
118             case 12 -> OnnxType.TENSOR_UINT32;
119             case 13 -> OnnxType.TENSOR_UINT64;
120             case 14 -> OnnxType.TENSOR_COMPLEX64;
121             case 15 -> OnnxType.TENSOR_COMPLEX128;
122             case 16 -> OnnxType.TENSOR_BFLOAT16;
123             case 17 -> OnnxType.TENSOR_FLOAT8E4M3FN;
124             case 18 -> OnnxType.TENSOR_FLOAT8E4M3FNUZ;
125             case 19 -> OnnxType.TENSOR_FLOAT8E5M2;
126             case 20 -> OnnxType.TENSOR_FLOAT8E5M2FNUZ;
127             case 21 -> OnnxType.TENSOR_UINT4;
128             case 22 -> OnnxType.TENSOR_INT4;
129             case 23 -> OnnxType.TENSOR_FLOAT4E2M1;
130             default -> OnnxType.tensor(null);
131         };
132     }
133 
    // Constructs ONNX dialect ops from their externalized form; the explicit
    // (hand-written) op factory is consulted before the generated one.
    static final OpFactory ONNX_OP_FACTORY = OpFactoryHelper.OP_FACTORY.get(ExplicitOnnxOps.class)
            .andThen(OpFactoryHelper.OP_FACTORY.get(OnnxOps.class));

    // Op schema per ONNX operation name, reflectively collected from the op classes.
    static final Map<String, OnnxOp.OnnxSchema> ONNX_SCHEMA_REGISTRY = collectSchemas(ExplicitOnnxOps.class, OnnxOps.class);
138 
139     private static Map<String, OnnxOp.OnnxSchema> collectSchemas(Class<?>... cls) {
140         Map<String, OnnxOp.OnnxSchema> reg = new HashMap<>();
141         for (Class<?> c : cls) {
142             for (Class<?> nm : c.getNestMembers()) {
143                 for (Field f : nm.getFields()) {
144                     if (f.accessFlags().contains(AccessFlag.STATIC) && OnnxOp.OnnxSchema.class.isAssignableFrom(f.getType())) try {
145                         OnnxOp.OnnxSchema sch = (OnnxOp.OnnxSchema)f.get(null);
146                         reg.put(sch.name(), sch);
147                     } catch (ReflectiveOperationException e) {
148                         throw new RuntimeException(e);
149                     }
150                 }
151             }
152         }
153         return reg;
154     }
155 
156     record OpWithNames<T extends Op> (T op, List<String> names) {
157 
158         private Function<CodeItem, String> namer() {
159             var defNamer = OpWriter.CodeItemNamerOption.defaultValue().namer();
160             var namer = new HashMap<Value, Integer>();
161             return ci -> ci instanceof Value v ? names.get(namer.computeIfAbsent(v, _ -> namer.size())) : defNamer.apply(ci);
162         }
163 
164         public String toText() {
165             return OpWriter.toText(op, OpWriter.CodeItemNamerOption.of(namer()));
166         }
167     }
168 
    /**
     * Lifts an ONNX {@code GraphProto} into a code-model {@link CoreOp.FuncOp}.
     * Graph inputs and initializers become function parameters (in that order),
     * each node becomes an ONNX dialect op, and the graph outputs become the
     * return value (a tuple when there is more than one output). The returned
     * {@link OpWithNames} also carries the ONNX value names, in value definition
     * order, for printing.
     */
    static OpWithNames<CoreOp.FuncOp> toFuncOp(OnnxModel.GraphProto g) {
        // Maps ONNX value names to code-model values; insertion order doubles as
        // the value numbering consumed by OpWithNames.namer().
        var valueMap = new LinkedHashMap<String, Value>();
        var func = CoreOp.FuncOp.func(g.name(), toFunctionType(g)).body(fb -> {

            { // fill value map for parameters and initializers
                Iterator<Block.Parameter> params = fb.entryBlock().parameters().iterator();
                for (OnnxModel.ValueInfoProto input : g.input()) {
                    valueMap.put(input.name(), params.next());
                }
                for (OnnxModel.TensorProto init : g.initializer()) {
                    valueMap.put(init.name(), params.next());
                }
            }

            for (OnnxModel.NodeProto n : g.node()) {
                String opType = n.opType();

                 // @@@ an old alias ? could not find the spec
                if (opType.equals("SimplifiedLayerNormalization")) {
                    opType = "LayerNormalization";
                }

                // ops from non-default domains are registered as "<domain>.<opType>"
                if (n.domain() != null && !n.domain().isEmpty() && !n.domain().equals("ai.onnx")) {
                    opType = n.domain() + "." + opType;
                }

                // computeIfAbsent is used only for its throwing mapping function:
                // unknown op types fail fast with a descriptive message
                OnnxOp.OnnxSchema schema = ONNX_SCHEMA_REGISTRY.computeIfAbsent(opType, ot -> {throw new IllegalArgumentException("Unknown op type: " + ot);});
                Map<String, Object> attributes = new LinkedHashMap<>();
                if (n.attribute() != null) {
                    for (OnnxModel.AttributeProto a : n.attribute()) {
                        attributes.put(a.name(), toAttributeValue(a));
                    }
                }

                // map inputs
                List<Value> inputs = new ArrayList<>();
                if (n.input() != null) {
                    // optional parameters actually present are recorded so the op
                    // constructor can tell which optionals the inputs bind to
                    List<OnnxOp.OnnxParameter> optionalInputs = new ArrayList<>();
                    for (int i = 0; i < n.input().size(); i++) {
                        OnnxOp.OnnxParameter param = schema.inputs().get(i);
                        Value v = valueMap.get(n.input().get(i));
                        if (v != null) {
                            switch (param.quantifier()) {
                                case REQUIRED -> {
                                    inputs.add(v);
                                }
                                case OPTIONAL -> {
                                    optionalInputs.add(param);
                                    inputs.add(v);
                                }
                                case VARIADIC -> {
                                    inputs.add(v); // @@@ accumulate variadic inputs into
                                }
                            }
                        }
                    }
                    if (!optionalInputs.isEmpty()) {
                        attributes.put("optional_inputs", optionalInputs);
                    }
                }

                // map outputs; an empty output name means the output is not produced
                List<OnnxOp.OnnxParameter> optionalOutputs = new ArrayList<>();
                List<String> outputNames = new ArrayList<>();
                if (n.output() != null) {
                    for (int i = 0; i < n.output().size(); i++) {
                        OnnxOp.OnnxParameter param = schema.outputs().get(i);
                        if (!n.output().get(i).isEmpty()) {
                            outputNames.add(n.output().get(i));
                            if (param.quantifier() == OnnxOp.OnnxParameter.Quantifier.OPTIONAL) {
                                optionalOutputs.add(param);
                            }
                        }
                    }
                    if (!optionalOutputs.isEmpty()) {
                        attributes.put("optional_outputs", optionalOutputs);
                    }
                }

                // inline Constant op tensor attribute as value
                // (FLOAT/INT64 scalars become value_float/value_int,
                //  1-d tensors become value_floats/value_ints)
                if (opType.equals("Constant") && attributes.remove(OnnxOps.Constant.Attribute.value.name()) instanceof Tensor t) {
                    switch (t.shape().length) {
                        case 0 -> { // scalar
                            switch (t.elementType()) {
                                case FLOAT -> attributes.put(OnnxOps.Constant.Attribute.value_float.name(), t.data().get(ValueLayout.JAVA_FLOAT, 0));
                                case INT64 -> attributes.put(OnnxOps.Constant.Attribute.value_int.name(), t.data().get(ValueLayout.JAVA_LONG, 0));
                                default -> throw new UnsupportedOperationException();
                            }
                        }
                        case 1 -> { // 1d tensor
                            switch (t.elementType()) {
                                case FLOAT -> attributes.put(OnnxOps.Constant.Attribute.value_floats.name(), t.data().toArray(ValueLayout.JAVA_FLOAT));
                                case INT64 -> attributes.put(OnnxOps.Constant.Attribute.value_ints.name(), t.data().toArray(ValueLayout.JAVA_LONG));
                                default -> throw new UnsupportedOperationException();
                            }
                        }
                        default -> throw new UnsupportedOperationException();
                    }
                }

                // get the op, constructed first with a placeholder result type
                ExternalizedOp extOp = new ExternalizedOp(
                        opType,
                        null,
                        inputs,
                        List.of(),
                        new OnnxType.TensorType(null),
                        attributes,
                        List.of());
                OnnxOp rawOp = (OnnxOp)ONNX_OP_FACTORY.constructOpOrFail(extOp);

                // patch the op return type, now inferable from the constructed op,
                // and re-construct the op with the concrete type
                TypeElement returnType = rawOp.onnxOutputs().size() == 1
                        ? inferTypeVariableType(rawOp.onnxOutputs().getFirst().type(), rawOp, n)
                        : CoreType.tupleType(rawOp.onnxOutputs().stream().map(o -> inferTypeVariableType(o.type(), rawOp, n)).toList());
                extOp = new ExternalizedOp(
                        extOp.name(),
                        null,
                        extOp.operands(),
                        extOp.successors(),
                        returnType,
                        extOp.attributes(),
                        extOp.bodyDefinitions());
                Op.Result res = fb.op((OnnxOp)ONNX_OP_FACTORY.constructOpOrFail(extOp));

                // map outputs; multi-output ops get tuple component loads per output
                if (outputNames.size() == 1) {
                    valueMap.put(n.output().getFirst(), res);
                } else {
                    valueMap.put(n.name(), res);
                    for (int i = 0; i < outputNames.size(); i++) {
                        valueMap.put(outputNames.get(i), fb.op(CoreOp.tupleLoad(res, i)));
                    }
                }
            }

            // return the single graph output, or a tuple of all outputs
            if (g.output().size() == 1) {
                fb.op(CoreOp.return_(valueMap.get(g.output().getFirst().name())));
            } else {
                Op.Result ret = fb.op(CoreOp.tuple(g.output().stream().map(OnnxModel.ValueInfoProto::name).map(valueMap::get).toList()));
                valueMap.put(g.name() + "_return", ret);
                fb.op(CoreOp.return_(ret));
            }
        });

        return new OpWithNames<>(func, List.of(valueMap.sequencedKeySet().toArray(String[]::new)));
    }
316 
317     static OnnxType inferTypeVariableType(OnnxType type, OnnxOp op, OnnxModel.NodeProto n) {
318         if (type instanceof OnnxType.TypeVariable tv) {
319             if (tv.types().size() == 1) {
320                 return tv.types().getFirst();
321             }
322             // search for the same type variable across inputs
323             for (var ie : op.onnxInputs().entrySet()) {
324                 if (ie.getKey().type().equals(tv)) {
325                     if (ie.getValue() instanceof Value v && v.type() instanceof OnnxType ot) {
326                         return ot;
327                     } else if (ie.getValue() instanceof List l && !l.isEmpty() && l.getFirst() instanceof Value v && v.type() instanceof OnnxType ot) {
328                         return ot;
329                     }
330                 }
331             }
332 
333             // special cases
334             return switch (op) {
335                 case OnnxOps.Cast c ->
336                     toTensorType((int)c.to());
337                 case OnnxOps.ConstantOfShape _, OnnxOps.Constant _-> // get tensor type from tensor attribute
338                     n.attribute() != null
339                     && !n.attribute().isEmpty()
340                     && n.attribute().getFirst().t() instanceof OnnxModel.TensorProto tp
341                             ? toTensorType(tp.dataType())
342                             : OnnxType.TENSOR_FLOAT32; // default
343                 default ->
344                     throw new IllegalArgumentException("Could not infer op type for: " + op.toText());
345             };
346         }
347         return type;
348     }
349 
350     static Object toAttributeValue(OnnxModel.AttributeProto a) {
351         return switch (a.type()) {
352             case FLOAT -> a.f();
353             case INT -> a.i();
354             case STRING -> a.s();
355             case TENSOR -> toTensor(a.t());
356 //    GRAPH = 5;
357 //    SPARSE_TENSOR = 11;
358 //    TYPE_PROTO = 13;
359             case FLOATS -> joinFloatArray(a.floats());
360             case INTS -> joinLongArray(a.ints());
361             case STRINGS -> a.strings();
362             case TENSORS -> a.tensors().stream().map(OnnxModelTest::toTensor).toArray(Tensor[]::new);
363 //    GRAPHS = 10;
364 //    SPARSE_TENSORS = 12;
365 //    TYPE_PROTOS = 14;
366             default -> throw new UnsupportedOperationException("Unsupported " + a.type());
367         };
368     }
369 
    // Materializes a proto tensor from its raw bytes, dims, and element-type id.
    // Only rawData-backed tensors are handled here; see the notes below.
    static Tensor toTensor(OnnxModel.TensorProto tensorProto) {
        // @@@ floatData, longData, stringData...
        // @@@ externalData
        // @@@ segments
        return Tensor.ofShape(joinLongArray(tensorProto.dims()), tensorProto.rawData(), Tensor.ElementType.fromOnnxId(tensorProto.dataType()));
    }
376 
377     static float[] joinFloatArray(List<float[]> floats) {
378         if (floats == null) return new float[0];
379         float[] join = new float[floats.stream().mapToInt(f -> f.length).sum()];
380         int i = 0;
381         for (float[] f : floats) {
382             System.arraycopy(f, 0, join, i, f.length);
383             i += f.length;
384         }
385         return join;
386     }
387 
388     static long[] joinLongArray(List<long[]> longs) {
389         if (longs == null) return new long[0];
390         long[] join = new long[longs.stream().mapToInt(f -> f.length).sum()];
391         int i = 0;
392         for (long[] f : longs) {
393             System.arraycopy(f, 0, join, i, f.length);
394             i += f.length;
395         }
396         return join;
397     }
398 
399     @Test
400     public void cnnLiftTest() throws Exception {
401         try (InputStream in = CNNTest.class.getResourceAsStream("lenet-torchscript.onnx")) {
402 
403             // parse onnx protobuf model
404             OnnxModel.ModelProto protoModel = OnnxModel.readFrom(in.readAllBytes());
405 
406 //            System.out.println(model.toText());
407 
408             // lift the cnnFuncOp from Onnx protobuf model
409             OpWithNames<CoreOp.FuncOp> cnnFuncOp = toFuncOp(protoModel.graph());
410 
411             System.out.println(cnnFuncOp.toText());
412 //            System.out.println(cnnFuncOp.op().toText());
413 
414             // test the lifted model
415             try (Arena a = Arena.ofConfined()) {
416                 List<Tensor> inputValues = new ArrayList<>();
417 
418                 // initializers are extracted from the proto model directly
419                 for (OnnxModel.TensorProto init : protoModel.graph().initializer()) {
420                     inputValues.add(Tensor.ofShape(a, joinLongArray(init.dims()), init.rawData(), Tensor.ElementType.fromOnnxId(init.dataType())));
421                 }
422 
423                 // fake image
424                 float[] image = new float[28 * 28];
425                 for (int i = 13; i < 28 * 28; i+=28) {
426                     image[i] = 1f;
427                 }
428                 inputValues.add(Tensor.ofShape(a, new long[] {1, 1, 28, 28}, image));
429 
430                 // run
431                 List<Tensor> res = OnnxRuntime.getInstance().run(a, cnnFuncOp.op().body().entryBlock(), inputValues, inputValues.size() - 1);
432 
433                 System.out.println(Arrays.toString(res.getFirst().data().toArray(ValueLayout.JAVA_FLOAT)));
434             }
435         }
436     }
437 
438     public static void main(String[] args) throws Exception {
439         for (var fName : args) {
440             try (var in = new RandomAccessFile(fName, "r")) {
441                 OnnxModel.ModelProto model = OnnxModel.readFrom(in.getChannel().map(FileChannel.MapMode.READ_ONLY, 0, in.length()));
442                 System.out.println(model.toText());
443                 var liftedModel = toFuncOp(model.graph());
444                 System.out.println(liftedModel.toText());
445            }
446         }
447     }
448 }