1 /*
2 * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation.
8 *
9 * This code is distributed in the hope that it will be useful, but WITHOUT
10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
12 * version 2 for more details (a copy is included in the LICENSE file that
13 * accompanied this code).
14 *
15 * You should have received a copy of the GNU General Public License version
16 * 2 along with this work; if not, write to the Free Software Foundation,
17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18 *
19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20 * or visit www.oracle.com if you need additional information or have any
21 * questions.
22 */
23
24 package oracle.code.onnx.proto;
25
26 import java.io.InputStream;
27 import java.io.RandomAccessFile;
28 import java.lang.foreign.Arena;
29 import java.lang.foreign.ValueLayout;
30 import java.lang.reflect.AccessFlag;
31 import java.lang.reflect.Field;
32 import java.nio.channels.FileChannel;
33 import java.util.ArrayList;
34 import java.util.Arrays;
35 import java.util.HashMap;
36 import java.util.Iterator;
37 import java.util.LinkedHashMap;
38 import java.util.List;
39 import java.util.Map;
40 import java.util.function.Function;
41 import jdk.incubator.code.Block;
42 import jdk.incubator.code.CodeItem;
43 import jdk.incubator.code.Op;
44 import jdk.incubator.code.TypeElement;
45 import jdk.incubator.code.Value;
46 import jdk.incubator.code.extern.ExternalizedOp;
47 import jdk.incubator.code.extern.OpFactory;
48 import jdk.incubator.code.dialect.core.CoreOp;
49 import jdk.incubator.code.dialect.core.CoreType;
50 import jdk.incubator.code.dialect.core.FunctionType;
51 import jdk.incubator.code.extern.OpWriter;
52 import oracle.code.onnx.CNNTest;
53 import oracle.code.onnx.OnnxRuntime;
54 import oracle.code.onnx.Tensor;
55 import oracle.code.onnx.ir.*;
56 import org.junit.jupiter.api.Test;
57
58
59 public class OnnxModelTest {
60
61 static OnnxType toOnnxType(OnnxModel.TypeProto tp) {
62 if (tp.tensorType() instanceof OnnxModel.TypeProto.Tensor t) {
63 return toTensorType(t.elemType());
64 } else if (tp.optionalType() instanceof OnnxModel.TypeProto.Optional o) {
65 return OnnxType.optional(toOnnxType(o.elemType()));
66 } else if (tp.sequenceType() instanceof OnnxModel.TypeProto.Sequence s) {
67 return OnnxType.seq(toOnnxType(s.elemType()));
68 } else if (tp.mapType() instanceof OnnxModel.TypeProto.Map m) {
69 return OnnxType.map(toKeyType(m.keyType()), toOnnxType(m.valueType()));
70 } else if (tp.sparseTensorType() instanceof OnnxModel.TypeProto.SparseTensor st) {
71 throw new UnsupportedOperationException("Sparse tensors not supported yet."); // @@@
72 }
73 throw new IllegalArgumentException("No type specified.");
74 }
75
76 static FunctionType toFunctionType(OnnxModel.GraphProto g) {
77 var paramTypes = new ArrayList<TypeElement>();
78 for (OnnxModel.ValueInfoProto input : g.input()) {
79 paramTypes.add(toOnnxType(input.type()));
80 }
81 for (OnnxModel.TensorProto init : g.initializer()) {
82 paramTypes.add(toTensorType(init.dataType()));
83 }
84 var returnType = g.output().size() == 1
85 ? toOnnxType(g.output().getFirst().type())
86 : CoreType.tupleType(g.output().stream().map(OnnxModel.ValueInfoProto::type).map(OnnxModelTest::toOnnxType).toList());
87 return CoreType.functionType(returnType, paramTypes);
88 }
89
90 static OnnxType toKeyType(int kt) {
91 return switch (kt) {
92 case 2 -> OnnxType.UINT8;
93 case 3 -> OnnxType.INT8;
94 case 4 -> OnnxType.UINT16;
95 case 5 -> OnnxType.INT16;
96 case 6 -> OnnxType.INT32;
97 case 7 -> OnnxType.INT64;
98 case 8 -> OnnxType.STRING;
99 case 12 -> OnnxType.UINT32;
100 case 13 -> OnnxType.UINT64;
101 default -> throw new IllegalArgumentException("Invalid key type: " + kt);
102 };
103 }
104
    /**
     * Maps an ONNX {@code TensorProto.DataType} id to the corresponding tensor type.
     * Unlike {@link #toKeyType(int)}, an unknown id does not throw; it falls back to
     * a tensor with an unspecified element type.
     */
    static OnnxType.TensorType toTensorType(int tt) {
        return switch (tt) {
            case 1 -> OnnxType.TENSOR_FLOAT32;
            case 2 -> OnnxType.TENSOR_UINT8;
            case 3 -> OnnxType.TENSOR_INT8;
            case 4 -> OnnxType.TENSOR_UINT16;
            case 5 -> OnnxType.TENSOR_INT16;
            case 6 -> OnnxType.TENSOR_INT32;
            case 7 -> OnnxType.TENSOR_INT64;
            case 8 -> OnnxType.TENSOR_STRING;
            case 9 -> OnnxType.TENSOR_BOOL;
            case 10 -> OnnxType.TENSOR_FLOAT16;
            case 11 -> OnnxType.TENSOR_FLOAT64;
            case 12 -> OnnxType.TENSOR_UINT32;
            case 13 -> OnnxType.TENSOR_UINT64;
            case 14 -> OnnxType.TENSOR_COMPLEX64;
            case 15 -> OnnxType.TENSOR_COMPLEX128;
            case 16 -> OnnxType.TENSOR_BFLOAT16;
            case 17 -> OnnxType.TENSOR_FLOAT8E4M3FN;
            case 18 -> OnnxType.TENSOR_FLOAT8E4M3FNUZ;
            case 19 -> OnnxType.TENSOR_FLOAT8E5M2;
            case 20 -> OnnxType.TENSOR_FLOAT8E5M2FNUZ;
            case 21 -> OnnxType.TENSOR_UINT4;
            case 22 -> OnnxType.TENSOR_INT4;
            case 23 -> OnnxType.TENSOR_FLOAT4E2M1;
            default -> OnnxType.tensor(null); // unknown id: tensor of unspecified element type
        };
    }
133
134 static final OpFactory ONNX_OP_FACTORY = OpFactoryHelper.OP_FACTORY.get(ExplicitOnnxOps.class)
135 .andThen(OpFactoryHelper.OP_FACTORY.get(OnnxOps.class));
136
137 static final Map<String, OnnxOp.OnnxSchema> ONNX_SCHEMA_REGISTRY = collectSchemas(ExplicitOnnxOps.class, OnnxOps.class);
138
139 private static Map<String, OnnxOp.OnnxSchema> collectSchemas(Class<?>... cls) {
140 Map<String, OnnxOp.OnnxSchema> reg = new HashMap<>();
141 for (Class<?> c : cls) {
142 for (Class<?> nm : c.getNestMembers()) {
143 for (Field f : nm.getFields()) {
144 if (f.accessFlags().contains(AccessFlag.STATIC) && OnnxOp.OnnxSchema.class.isAssignableFrom(f.getType())) try {
145 OnnxOp.OnnxSchema sch = (OnnxOp.OnnxSchema)f.get(null);
146 reg.put(sch.name(), sch);
147 } catch (ReflectiveOperationException e) {
148 throw new RuntimeException(e);
149 }
150 }
151 }
152 }
153 return reg;
154 }
155
156 record OpWithNames<T extends Op> (T op, List<String> names) {
157
158 private Function<CodeItem, String> namer() {
159 var defNamer = OpWriter.CodeItemNamerOption.defaultValue().namer();
160 var namer = new HashMap<Value, Integer>();
161 return ci -> ci instanceof Value v ? names.get(namer.computeIfAbsent(v, _ -> namer.size())) : defNamer.apply(ci);
162 }
163
164 public String toText() {
165 return OpWriter.toText(op, OpWriter.CodeItemNamerOption.of(namer()));
166 }
167 }
168
    /**
     * Lifts an ONNX protobuf graph into a code-model {@link CoreOp.FuncOp}.
     * <p>
     * Graph inputs and initializers become function parameters (in that order), each
     * node becomes an ONNX dialect op, and the graph outputs become the return value
     * (a tuple when there are several). Returns the func op together with the ONNX
     * value names in definition order, for readable rendering via {@link OpWithNames}.
     */
    static OpWithNames<CoreOp.FuncOp> toFuncOp(OnnxModel.GraphProto g) {
        // maps each ONNX value name to the code-model Value that produces it
        var valueMap = new LinkedHashMap<String, Value>();
        var func = CoreOp.FuncOp.func(g.name(), toFunctionType(g)).body(fb -> {

            { // fill value map for parameters and initializers
                Iterator<Block.Parameter> params = fb.entryBlock().parameters().iterator();
                for (OnnxModel.ValueInfoProto input : g.input()) {
                    valueMap.put(input.name(), params.next());
                }
                for (OnnxModel.TensorProto init : g.initializer()) {
                    valueMap.put(init.name(), params.next());
                }
            }

            for (OnnxModel.NodeProto n : g.node()) {
                String opType = n.opType();

                // @@@ an old alias ? could not find the spec
                if (opType.equals("SimplifiedLayerNormalization")) {
                    opType = "LayerNormalization";
                }

                // ops outside the default domain are registered under a domain-qualified name
                if (n.domain() != null && !n.domain().isEmpty() && !n.domain().equals("ai.onnx")) {
                    opType = n.domain() + "." + opType;
                }

                // computeIfAbsent is used purely for its "throw on missing key" effect;
                // the lambda always throws so the registry is never mutated here
                OnnxOp.OnnxSchema schema = ONNX_SCHEMA_REGISTRY.computeIfAbsent(opType, ot -> {throw new IllegalArgumentException("Unknown op type: " + ot);});
                Map<String, Object> attributes = new LinkedHashMap<>();
                if (n.attribute() != null) {
                    for (OnnxModel.AttributeProto a : n.attribute()) {
                        attributes.put(a.name(), toAttributeValue(a));
                    }
                }

                // map inputs
                List<Value> inputs = new ArrayList<>();
                if (n.input() != null) {
                    // optional inputs that are actually present are recorded in a
                    // synthetic attribute so the op factory can bind them correctly
                    List<OnnxOp.OnnxParameter> optionalInputs = new ArrayList<>();
                    for (int i = 0; i < n.input().size(); i++) {
                        OnnxOp.OnnxParameter param = schema.inputs().get(i);
                        Value v = valueMap.get(n.input().get(i));
                        if (v != null) {
                            switch (param.quantifier()) {
                                case REQUIRED -> {
                                    inputs.add(v);
                                }
                                case OPTIONAL -> {
                                    optionalInputs.add(param);
                                    inputs.add(v);
                                }
                                case VARIADIC -> {
                                    inputs.add(v); // @@@ accumulate variadic inputs into
                                }
                            }
                        }
                    }
                    if (!optionalInputs.isEmpty()) {
                        attributes.put("optional_inputs", optionalInputs);
                    }
                }

                // map outputs
                List<OnnxOp.OnnxParameter> optionalOutputs = new ArrayList<>();
                List<String> outputNames = new ArrayList<>();
                if (n.output() != null) {
                    for (int i = 0; i < n.output().size(); i++) {
                        OnnxOp.OnnxParameter param = schema.outputs().get(i);
                        // an empty output name means the output is omitted
                        if (!n.output().get(i).isEmpty()) {
                            outputNames.add(n.output().get(i));
                            if (param.quantifier() == OnnxOp.OnnxParameter.Quantifier.OPTIONAL) {
                                optionalOutputs.add(param);
                            }
                        }
                    }
                    if (!optionalOutputs.isEmpty()) {
                        attributes.put("optional_outputs", optionalOutputs);
                    }
                }

                // inline Constant op tensor attribute as value
                // NOTE: only scalar and 1-d FLOAT/INT64 constant tensors are handled here
                if (opType.equals("Constant") && attributes.remove(OnnxOps.Constant.Attribute.value.name()) instanceof Tensor t) {
                    switch (t.shape().length) {
                        case 0 -> { // scalar
                            switch (t.elementType()) {
                                case FLOAT -> attributes.put(OnnxOps.Constant.Attribute.value_float.name(), t.data().get(ValueLayout.JAVA_FLOAT, 0));
                                case INT64 -> attributes.put(OnnxOps.Constant.Attribute.value_int.name(), t.data().get(ValueLayout.JAVA_LONG, 0));
                                default -> throw new UnsupportedOperationException();
                            }
                        }
                        case 1 -> { // 1d tensor
                            switch (t.elementType()) {
                                case FLOAT -> attributes.put(OnnxOps.Constant.Attribute.value_floats.name(), t.data().toArray(ValueLayout.JAVA_FLOAT));
                                case INT64 -> attributes.put(OnnxOps.Constant.Attribute.value_ints.name(), t.data().toArray(ValueLayout.JAVA_LONG));
                                default -> throw new UnsupportedOperationException();
                            }
                        }
                        default -> throw new UnsupportedOperationException();
                    }
                }

                // get the op
                // constructed first with a placeholder result type, only to learn the
                // op's declared outputs so the real result type can be inferred below
                ExternalizedOp extOp = new ExternalizedOp(
                        opType,
                        null,
                        inputs,
                        List.of(),
                        new OnnxType.TensorType(null),
                        attributes,
                        List.of());
                OnnxOp rawOp = (OnnxOp)ONNX_OP_FACTORY.constructOpOrFail(extOp);

                // patch the op return type
                TypeElement returnType = rawOp.onnxOutputs().size() == 1
                        ? inferTypeVariableType(rawOp.onnxOutputs().getFirst().type(), rawOp, n)
                        : CoreType.tupleType(rawOp.onnxOutputs().stream().map(o -> inferTypeVariableType(o.type(), rawOp, n)).toList());
                extOp = new ExternalizedOp(
                        extOp.name(),
                        null,
                        extOp.operands(),
                        extOp.successors(),
                        returnType,
                        extOp.attributes(),
                        extOp.bodyDefinitions());
                Op.Result res = fb.op((OnnxOp)ONNX_OP_FACTORY.constructOpOrFail(extOp));

                // map outputs
                if (outputNames.size() == 1) {
                    valueMap.put(n.output().getFirst(), res);
                } else {
                    // multi-output node: the tuple result is named after the node and
                    // each component is extracted with a tuple load
                    valueMap.put(n.name(), res);
                    for (int i = 0; i < outputNames.size(); i++) {
                        valueMap.put(outputNames.get(i), fb.op(CoreOp.tupleLoad(res, i)));
                    }
                }
            }

            if (g.output().size() == 1) {
                fb.op(CoreOp.return_(valueMap.get(g.output().getFirst().name())));
            } else {
                Op.Result ret = fb.op(CoreOp.tuple(g.output().stream().map(OnnxModel.ValueInfoProto::name).map(valueMap::get).toList()));
                valueMap.put(g.name() + "_return", ret);
                fb.op(CoreOp.return_(ret));
            }
        });

        return new OpWithNames<>(func, List.of(valueMap.sequencedKeySet().toArray(String[]::new)));
    }
316
317 static OnnxType inferTypeVariableType(OnnxType type, OnnxOp op, OnnxModel.NodeProto n) {
318 if (type instanceof OnnxType.TypeVariable tv) {
319 if (tv.types().size() == 1) {
320 return tv.types().getFirst();
321 }
322 // search for the same type variable across inputs
323 for (var ie : op.onnxInputs().entrySet()) {
324 if (ie.getKey().type().equals(tv)) {
325 if (ie.getValue() instanceof Value v && v.type() instanceof OnnxType ot) {
326 return ot;
327 } else if (ie.getValue() instanceof List l && !l.isEmpty() && l.getFirst() instanceof Value v && v.type() instanceof OnnxType ot) {
328 return ot;
329 }
330 }
331 }
332
333 // special cases
334 return switch (op) {
335 case OnnxOps.Cast c ->
336 toTensorType((int)c.to());
337 case OnnxOps.ConstantOfShape _, OnnxOps.Constant _-> // get tensor type from tensor attribute
338 n.attribute() != null
339 && !n.attribute().isEmpty()
340 && n.attribute().getFirst().t() instanceof OnnxModel.TensorProto tp
341 ? toTensorType(tp.dataType())
342 : OnnxType.TENSOR_FLOAT32; // default
343 default ->
344 throw new IllegalArgumentException("Could not infer op type for: " + op.toText());
345 };
346 }
347 return type;
348 }
349
350 static Object toAttributeValue(OnnxModel.AttributeProto a) {
351 return switch (a.type()) {
352 case FLOAT -> a.f();
353 case INT -> a.i();
354 case STRING -> a.s();
355 case TENSOR -> toTensor(a.t());
356 // GRAPH = 5;
357 // SPARSE_TENSOR = 11;
358 // TYPE_PROTO = 13;
359 case FLOATS -> joinFloatArray(a.floats());
360 case INTS -> joinLongArray(a.ints());
361 case STRINGS -> a.strings();
362 case TENSORS -> a.tensors().stream().map(OnnxModelTest::toTensor).toArray(Tensor[]::new);
363 // GRAPHS = 10;
364 // SPARSE_TENSORS = 12;
365 // TYPE_PROTOS = 14;
366 default -> throw new UnsupportedOperationException("Unsupported " + a.type());
367 };
368 }
369
    /**
     * Materializes a {@link Tensor} from a protobuf tensor.
     * Only the raw-data encoding is handled; the typed data fields, external data
     * and segmented tensors are still TODO.
     */
    static Tensor toTensor(OnnxModel.TensorProto tensorProto) {
        // @@@ floatData, longData, stringData...
        // @@@ externalData
        // @@@ segments
        return Tensor.ofShape(joinLongArray(tensorProto.dims()), tensorProto.rawData(), Tensor.ElementType.fromOnnxId(tensorProto.dataType()));
    }
376
377 static float[] joinFloatArray(List<float[]> floats) {
378 if (floats == null) return new float[0];
379 float[] join = new float[floats.stream().mapToInt(f -> f.length).sum()];
380 int i = 0;
381 for (float[] f : floats) {
382 System.arraycopy(f, 0, join, i, f.length);
383 i += f.length;
384 }
385 return join;
386 }
387
388 static long[] joinLongArray(List<long[]> longs) {
389 if (longs == null) return new long[0];
390 long[] join = new long[longs.stream().mapToInt(f -> f.length).sum()];
391 int i = 0;
392 for (long[] f : longs) {
393 System.arraycopy(f, 0, join, i, f.length);
394 i += f.length;
395 }
396 return join;
397 }
398
    /**
     * End-to-end check: parses the bundled LeNet ONNX model, lifts it to a func op,
     * and executes the lifted model via the ONNX runtime on a synthetic image.
     * The result is printed rather than asserted.
     */
    @Test
    public void cnnLiftTest() throws Exception {
        try (InputStream in = CNNTest.class.getResourceAsStream("lenet-torchscript.onnx")) {

            // parse onnx protobuf model
            OnnxModel.ModelProto protoModel = OnnxModel.readFrom(in.readAllBytes());

//            System.out.println(model.toText());

            // lift the cnnFuncOp from Onnx protobuf model
            OpWithNames<CoreOp.FuncOp> cnnFuncOp = toFuncOp(protoModel.graph());

            System.out.println(cnnFuncOp.toText());
//            System.out.println(cnnFuncOp.op().toText());

            // test the lifted model
            try (Arena a = Arena.ofConfined()) {
                List<Tensor> inputValues = new ArrayList<>();

                // initializers are extracted from the proto model directly
                for (OnnxModel.TensorProto init : protoModel.graph().initializer()) {
                    inputValues.add(Tensor.ofShape(a, joinLongArray(init.dims()), init.rawData(), Tensor.ElementType.fromOnnxId(init.dataType())));
                }

                // fake image: a single vertical line of 1s in a 28x28 zero image
                float[] image = new float[28 * 28];
                for (int i = 13; i < 28 * 28; i+=28) {
                    image[i] = 1f;
                }
                inputValues.add(Tensor.ofShape(a, new long[] {1, 1, 28, 28}, image));

                // run
                // NOTE(review): the trailing index appears to mark where the user inputs
                // start after the initializers — confirm against OnnxRuntime.run
                List<Tensor> res = OnnxRuntime.getInstance().run(a, cnnFuncOp.op().body().entryBlock(), inputValues, inputValues.size() - 1);

                System.out.println(Arrays.toString(res.getFirst().data().toArray(ValueLayout.JAVA_FLOAT)));
            }
        }
    }
437
438 public static void main(String[] args) throws Exception {
439 for (var fName : args) {
440 try (var in = new RandomAccessFile(fName, "r")) {
441 OnnxModel.ModelProto model = OnnxModel.readFrom(in.getChannel().map(FileChannel.MapMode.READ_ONLY, 0, in.length()));
442 System.out.println(model.toText());
443 var liftedModel = toFuncOp(model.graph());
444 System.out.println(liftedModel.toText());
445 }
446 }
447 }
448 }