1 /* 2 * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. 8 * 9 * This code is distributed in the hope that it will be useful, but WITHOUT 10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 12 * version 2 for more details (a copy is included in the LICENSE file that 13 * accompanied this code). 14 * 15 * You should have received a copy of the GNU General Public License version 16 * 2 along with this work; if not, write to the Free Software Foundation, 17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 18 * 19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 20 * or visit www.oracle.com if you need additional information or have any 21 * questions. 22 */ 23 24 package oracle.code.onnx.proto; 25 26 import java.io.InputStream; 27 import java.io.RandomAccessFile; 28 import java.lang.foreign.Arena; 29 import java.lang.foreign.ValueLayout; 30 import java.lang.reflect.AccessFlag; 31 import java.lang.reflect.Field; 32 import java.nio.channels.FileChannel; 33 import java.util.ArrayList; 34 import java.util.Arrays; 35 import java.util.HashMap; 36 import java.util.Iterator; 37 import java.util.LinkedHashMap; 38 import java.util.List; 39 import java.util.Map; 40 import java.util.function.Function; 41 import jdk.incubator.code.Block; 42 import jdk.incubator.code.CodeItem; 43 import jdk.incubator.code.Op; 44 import jdk.incubator.code.TypeElement; 45 import jdk.incubator.code.Value; 46 import jdk.incubator.code.extern.ExternalizedOp; 47 import jdk.incubator.code.extern.OpFactory; 48 import jdk.incubator.code.dialect.core.CoreOp; 49 import jdk.incubator.code.dialect.core.CoreType; 50 import jdk.incubator.code.dialect.core.FunctionType; 51 import jdk.incubator.code.extern.OpWriter; 52 import oracle.code.onnx.CNNTest; 53 import oracle.code.onnx.OnnxRuntime; 54 import oracle.code.onnx.Tensor; 55 import oracle.code.onnx.ir.*; 56 import org.junit.jupiter.api.Test; 57 58 59 public class OnnxModelTest { 60 61 static OnnxType toOnnxType(OnnxModel.TypeProto tp) { 62 if (tp.tensorType() instanceof OnnxModel.TypeProto.Tensor t) { 63 return toTensorType(t.elemType()); 64 } else if (tp.optionalType() instanceof OnnxModel.TypeProto.Optional o) { 65 return OnnxType.optional(toOnnxType(o.elemType())); 66 } else if (tp.sequenceType() instanceof OnnxModel.TypeProto.Sequence s) { 67 return OnnxType.seq(toOnnxType(s.elemType())); 68 } else if (tp.mapType() instanceof OnnxModel.TypeProto.Map m) { 69 return OnnxType.map(toKeyType(m.keyType()), toOnnxType(m.valueType())); 70 } else if (tp.sparseTensorType() instanceof OnnxModel.TypeProto.SparseTensor st) { 71 throw new UnsupportedOperationException("Sparse tensors not supported yet."); // @@@ 72 } 73 throw new IllegalArgumentException("No type specified."); 74 } 75 76 static FunctionType toFunctionType(OnnxModel.GraphProto g) { 77 var paramTypes = new ArrayList<TypeElement>(); 78 for (OnnxModel.ValueInfoProto input : g.input()) { 79 paramTypes.add(toOnnxType(input.type())); 80 } 81 for (OnnxModel.TensorProto init : g.initializer()) { 82 paramTypes.add(toTensorType(init.dataType())); 83 } 84 var returnType = g.output().size() == 1 85 ? toOnnxType(g.output().getFirst().type()) 86 : CoreType.tupleType(g.output().stream().map(OnnxModel.ValueInfoProto::type).map(OnnxModelTest::toOnnxType).toList()); 87 return CoreType.functionType(returnType, paramTypes); 88 } 89 90 static OnnxType toKeyType(int kt) { 91 return switch (kt) { 92 case 2 -> OnnxType.UINT8; 93 case 3 -> OnnxType.INT8; 94 case 4 -> OnnxType.UINT16; 95 case 5 -> OnnxType.INT16; 96 case 6 -> OnnxType.INT32; 97 case 7 -> OnnxType.INT64; 98 case 8 -> OnnxType.STRING; 99 case 12 -> OnnxType.UINT32; 100 case 13 -> OnnxType.UINT64; 101 default -> throw new IllegalArgumentException("Invalid key type: " + kt); 102 }; 103 } 104 105 static OnnxType.TensorType toTensorType(int tt) { 106 return switch (tt) { 107 case 1 -> OnnxType.TENSOR_FLOAT32; 108 case 2 -> OnnxType.TENSOR_UINT8; 109 case 3 -> OnnxType.TENSOR_INT8; 110 case 4 -> OnnxType.TENSOR_UINT16; 111 case 5 -> OnnxType.TENSOR_INT16; 112 case 6 -> OnnxType.TENSOR_INT32; 113 case 7 -> OnnxType.TENSOR_INT64; 114 case 8 -> OnnxType.TENSOR_STRING; 115 case 9 -> OnnxType.TENSOR_BOOL; 116 case 10 -> OnnxType.TENSOR_FLOAT16; 117 case 11 -> OnnxType.TENSOR_FLOAT64; 118 case 12 -> OnnxType.TENSOR_UINT32; 119 case 13 -> OnnxType.TENSOR_UINT64; 120 case 14 -> OnnxType.TENSOR_COMPLEX64; 121 case 15 -> OnnxType.TENSOR_COMPLEX128; 122 case 16 -> OnnxType.TENSOR_BFLOAT16; 123 case 17 -> OnnxType.TENSOR_FLOAT8E4M3FN; 124 case 18 -> OnnxType.TENSOR_FLOAT8E4M3FNUZ; 125 case 19 -> OnnxType.TENSOR_FLOAT8E5M2; 126 case 20 -> OnnxType.TENSOR_FLOAT8E5M2FNUZ; 127 case 21 -> OnnxType.TENSOR_UINT4; 128 case 22 -> OnnxType.TENSOR_INT4; 129 case 23 -> OnnxType.TENSOR_FLOAT4E2M1; 130 default -> OnnxType.tensor(null); 131 }; 132 } 133 134 static final OpFactory ONNX_OP_FACTORY = OpFactoryHelper.OP_FACTORY.get(ExplicitOnnxOps.class) 135 .andThen(OpFactoryHelper.OP_FACTORY.get(OnnxOps.class)); 136 137 static final Map<String, OnnxOp.OnnxSchema> ONNX_SCHEMA_REGISTRY = collectSchemas(ExplicitOnnxOps.class, OnnxOps.class); 138 139 private static Map<String, OnnxOp.OnnxSchema> collectSchemas(Class<?>... cls) { 140 Map<String, OnnxOp.OnnxSchema> reg = new HashMap<>(); 141 for (Class<?> c : cls) { 142 for (Class<?> nm : c.getNestMembers()) { 143 for (Field f : nm.getFields()) { 144 if (f.accessFlags().contains(AccessFlag.STATIC) && OnnxOp.OnnxSchema.class.isAssignableFrom(f.getType())) try { 145 OnnxOp.OnnxSchema sch = (OnnxOp.OnnxSchema)f.get(null); 146 reg.put(sch.name(), sch); 147 } catch (ReflectiveOperationException e) { 148 throw new RuntimeException(e); 149 } 150 } 151 } 152 } 153 return reg; 154 } 155 156 record OpWithNames<T extends Op> (T op, List<String> names) { 157 158 private Function<CodeItem, String> namer() { 159 var defNamer = OpWriter.CodeItemNamerOption.defaultValue().namer(); 160 var namer = new HashMap<Value, Integer>(); 161 return ci -> ci instanceof Value v ? names.get(namer.computeIfAbsent(v, _ -> namer.size())) : defNamer.apply(ci); 162 } 163 164 public String toText() { 165 return OpWriter.toText(op, OpWriter.CodeItemNamerOption.of(namer())); 166 } 167 } 168 169 static OpWithNames<CoreOp.FuncOp> toFuncOp(OnnxModel.GraphProto g) { 170 var valueMap = new LinkedHashMap<String, Value>(); 171 var func = CoreOp.FuncOp.func(g.name(), toFunctionType(g)).body(fb -> { 172 173 { // fill value map for parameters and initializers 174 Iterator<Block.Parameter> params = fb.entryBlock().parameters().iterator(); 175 for (OnnxModel.ValueInfoProto input : g.input()) { 176 valueMap.put(input.name(), params.next()); 177 } 178 for (OnnxModel.TensorProto init : g.initializer()) { 179 valueMap.put(init.name(), params.next()); 180 } 181 } 182 183 for (OnnxModel.NodeProto n : g.node()) { 184 String opType = n.opType(); 185 186 // @@@ an old alias ? could not find the spec 187 if (opType.equals("SimplifiedLayerNormalization")) { 188 opType = "LayerNormalization"; 189 } 190 191 if (n.domain() != null && !n.domain().isEmpty() && !n.domain().equals("ai.onnx")) { 192 opType = n.domain() + "." + opType; 193 } 194 195 OnnxOp.OnnxSchema schema = ONNX_SCHEMA_REGISTRY.computeIfAbsent(opType, ot -> {throw new IllegalArgumentException("Unknown op type: " + ot);}); 196 Map<String, Object> attributes = new LinkedHashMap<>(); 197 if (n.attribute() != null) { 198 for (OnnxModel.AttributeProto a : n.attribute()) { 199 attributes.put(a.name(), toAttributeValue(a)); 200 } 201 } 202 203 // map inputs 204 List<Value> inputs = new ArrayList<>(); 205 if (n.input() != null) { 206 List<OnnxOp.OnnxParameter> optionalInputs = new ArrayList<>(); 207 for (int i = 0; i < n.input().size(); i++) { 208 OnnxOp.OnnxParameter param = schema.inputs().get(i); 209 Value v = valueMap.get(n.input().get(i)); 210 if (v != null) { 211 switch (param.quantifier()) { 212 case REQUIRED -> { 213 inputs.add(v); 214 } 215 case OPTIONAL -> { 216 optionalInputs.add(param); 217 inputs.add(v); 218 } 219 case VARIADIC -> { 220 inputs.add(v); // @@@ accumulate variadic inputs into 221 } 222 } 223 } 224 } 225 if (!optionalInputs.isEmpty()) { 226 attributes.put("optional_inputs", optionalInputs); 227 } 228 } 229 230 // map outputs 231 List<OnnxOp.OnnxParameter> optionalOutputs = new ArrayList<>(); 232 List<String> outputNames = new ArrayList<>(); 233 if (n.output() != null) { 234 for (int i = 0; i < n.output().size(); i++) { 235 OnnxOp.OnnxParameter param = schema.outputs().get(i); 236 if (!n.output().get(i).isEmpty()) { 237 outputNames.add(n.output().get(i)); 238 if (param.quantifier() == OnnxOp.OnnxParameter.Quantifier.OPTIONAL) { 239 optionalOutputs.add(param); 240 } 241 } 242 } 243 if (!optionalOutputs.isEmpty()) { 244 attributes.put("optional_outputs", optionalOutputs); 245 } 246 } 247 248 // inline Constant op tensor attribute as value 249 if (opType.equals("Constant") && attributes.remove(OnnxOps.Constant.Attribute.value.name()) instanceof Tensor t) { 250 switch (t.shape().length) { 251 case 0 -> { // scalar 252 switch (t.elementType()) { 253 case FLOAT -> attributes.put(OnnxOps.Constant.Attribute.value_float.name(), t.data().get(ValueLayout.JAVA_FLOAT, 0)); 254 case INT64 -> attributes.put(OnnxOps.Constant.Attribute.value_int.name(), t.data().get(ValueLayout.JAVA_LONG, 0)); 255 default -> throw new UnsupportedOperationException(); 256 } 257 } 258 case 1 -> { // 1d tensor 259 switch (t.elementType()) { 260 case FLOAT -> attributes.put(OnnxOps.Constant.Attribute.value_floats.name(), t.data().toArray(ValueLayout.JAVA_FLOAT)); 261 case INT64 -> attributes.put(OnnxOps.Constant.Attribute.value_ints.name(), t.data().toArray(ValueLayout.JAVA_LONG)); 262 default -> throw new UnsupportedOperationException(); 263 } 264 } 265 default -> throw new UnsupportedOperationException(); 266 } 267 } 268 269 // get the op 270 ExternalizedOp extOp = new ExternalizedOp( 271 opType, 272 null, 273 inputs, 274 List.of(), 275 new OnnxType.TensorType(null), 276 attributes, 277 List.of()); 278 OnnxOp rawOp = (OnnxOp)ONNX_OP_FACTORY.constructOpOrFail(extOp); 279 280 // patch the op return type 281 TypeElement returnType = rawOp.onnxOutputs().size() == 1 282 ? inferTypeVariableType(rawOp.onnxOutputs().getFirst().type(), rawOp, n) 283 : CoreType.tupleType(rawOp.onnxOutputs().stream().map(o -> inferTypeVariableType(o.type(), rawOp, n)).toList()); 284 extOp = new ExternalizedOp( 285 extOp.name(), 286 null, 287 extOp.operands(), 288 extOp.successors(), 289 returnType, 290 extOp.attributes(), 291 extOp.bodyDefinitions()); 292 Op.Result res = fb.op((OnnxOp)ONNX_OP_FACTORY.constructOpOrFail(extOp)); 293 294 // map outputs 295 if (outputNames.size() == 1) { 296 valueMap.put(n.output().getFirst(), res); 297 } else { 298 valueMap.put(n.name(), res); 299 for (int i = 0; i < outputNames.size(); i++) { 300 valueMap.put(outputNames.get(i), fb.op(CoreOp.tupleLoad(res, i))); 301 } 302 } 303 } 304 305 if (g.output().size() == 1) { 306 fb.op(CoreOp.return_(valueMap.get(g.output().getFirst().name()))); 307 } else { 308 Op.Result ret = fb.op(CoreOp.tuple(g.output().stream().map(OnnxModel.ValueInfoProto::name).map(valueMap::get).toList())); 309 valueMap.put(g.name() + "_return", ret); 310 fb.op(CoreOp.return_(ret)); 311 } 312 }); 313 314 return new OpWithNames<>(func, List.of(valueMap.sequencedKeySet().toArray(String[]::new))); 315 } 316 317 static OnnxType inferTypeVariableType(OnnxType type, OnnxOp op, OnnxModel.NodeProto n) { 318 if (type instanceof OnnxType.TypeVariable tv) { 319 if (tv.types().size() == 1) { 320 return tv.types().getFirst(); 321 } 322 // search for the same type variable across inputs 323 for (var ie : op.onnxInputs().entrySet()) { 324 if (ie.getKey().type().equals(tv)) { 325 if (ie.getValue() instanceof Value v && v.type() instanceof OnnxType ot) { 326 return ot; 327 } else if (ie.getValue() instanceof List l && !l.isEmpty() && l.getFirst() instanceof Value v && v.type() instanceof OnnxType ot) { 328 return ot; 329 } 330 } 331 } 332 333 // special cases 334 return switch (op) { 335 case OnnxOps.Cast c -> 336 toTensorType((int)c.to()); 337 case OnnxOps.ConstantOfShape _, OnnxOps.Constant _-> // get tensor type from tensor attribute 338 n.attribute() != null 339 && !n.attribute().isEmpty() 340 && n.attribute().getFirst().t() instanceof OnnxModel.TensorProto tp 341 ? toTensorType(tp.dataType()) 342 : OnnxType.TENSOR_FLOAT32; // default 343 default -> 344 throw new IllegalArgumentException("Could not infer op type for: " + op.toText()); 345 }; 346 } 347 return type; 348 } 349 350 static Object toAttributeValue(OnnxModel.AttributeProto a) { 351 return switch (a.type()) { 352 case FLOAT -> a.f(); 353 case INT -> a.i(); 354 case STRING -> a.s(); 355 case TENSOR -> toTensor(a.t()); 356 // GRAPH = 5; 357 // SPARSE_TENSOR = 11; 358 // TYPE_PROTO = 13; 359 case FLOATS -> joinFloatArray(a.floats()); 360 case INTS -> joinLongArray(a.ints()); 361 case STRINGS -> a.strings(); 362 case TENSORS -> a.tensors().stream().map(OnnxModelTest::toTensor).toArray(Tensor[]::new); 363 // GRAPHS = 10; 364 // SPARSE_TENSORS = 12; 365 // TYPE_PROTOS = 14; 366 default -> throw new UnsupportedOperationException("Unsupported " + a.type()); 367 }; 368 } 369 370 static Tensor toTensor(OnnxModel.TensorProto tensorProto) { 371 // @@@ floatData, longData, stringData... 372 // @@@ externalData 373 // @@@ segments 374 return Tensor.ofShape(joinLongArray(tensorProto.dims()), tensorProto.rawData(), Tensor.ElementType.fromOnnxId(tensorProto.dataType())); 375 } 376 377 static float[] joinFloatArray(List<float[]> floats) { 378 if (floats == null) return new float[0]; 379 float[] join = new float[floats.stream().mapToInt(f -> f.length).sum()]; 380 int i = 0; 381 for (float[] f : floats) { 382 System.arraycopy(f, 0, join, i, f.length); 383 i += f.length; 384 } 385 return join; 386 } 387 388 static long[] joinLongArray(List<long[]> longs) { 389 if (longs == null) return new long[0]; 390 long[] join = new long[longs.stream().mapToInt(f -> f.length).sum()]; 391 int i = 0; 392 for (long[] f : longs) { 393 System.arraycopy(f, 0, join, i, f.length); 394 i += f.length; 395 } 396 return join; 397 } 398 399 @Test 400 public void cnnLiftTest() throws Exception { 401 try (InputStream in = CNNTest.class.getResourceAsStream("lenet-torchscript.onnx")) { 402 403 // parse onnx protobuf model 404 OnnxModel.ModelProto protoModel = OnnxModel.readFrom(in.readAllBytes()); 405 406 // System.out.println(model.toText()); 407 408 // lift the cnnFuncOp from Onnx protobuf model 409 OpWithNames<CoreOp.FuncOp> cnnFuncOp = toFuncOp(protoModel.graph()); 410 411 System.out.println(cnnFuncOp.toText()); 412 // System.out.println(cnnFuncOp.op().toText()); 413 414 // test the lifted model 415 try (Arena a = Arena.ofConfined()) { 416 List<Tensor> inputValues = new ArrayList<>(); 417 418 // initializers are extracted from the proto model directly 419 for (OnnxModel.TensorProto init : protoModel.graph().initializer()) { 420 inputValues.add(Tensor.ofShape(a, joinLongArray(init.dims()), init.rawData(), Tensor.ElementType.fromOnnxId(init.dataType()))); 421 } 422 423 // fake image 424 float[] image = new float[28 * 28]; 425 for (int i = 13; i < 28 * 28; i+=28) { 426 image[i] = 1f; 427 } 428 inputValues.add(Tensor.ofShape(a, new long[] {1, 1, 28, 28}, image)); 429 430 // run 431 List<Tensor> res = OnnxRuntime.getInstance().run(a, cnnFuncOp.op().body().entryBlock(), inputValues, inputValues.size() - 1); 432 433 System.out.println(Arrays.toString(res.getFirst().data().toArray(ValueLayout.JAVA_FLOAT))); 434 } 435 } 436 } 437 438 public static void main(String[] args) throws Exception { 439 for (var fName : args) { 440 try (var in = new RandomAccessFile(fName, "r")) { 441 OnnxModel.ModelProto model = OnnxModel.readFrom(in.getChannel().map(FileChannel.MapMode.READ_ONLY, 0, in.length())); 442 System.out.println(model.toText()); 443 var liftedModel = toFuncOp(model.graph()); 444 System.out.println(liftedModel.toText()); 445 } 446 } 447 } 448 }