1 /* 2 * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. Oracle designates this 8 * particular file as subject to the "Classpath" exception as provided 9 * by Oracle in the LICENSE file that accompanied this code. 10 * 11 * This code is distributed in the hope that it will be useful, but WITHOUT 12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 14 * version 2 for more details (a copy is included in the LICENSE file that 15 * accompanied this code). 16 * 17 * You should have received a copy of the GNU General Public License version 18 * 2 along with this work; if not, write to the Free Software Foundation, 19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 20 * 21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 22 * or visit www.oracle.com if you need additional information or have any 23 * questions. 24 */ 25 26 package oracle.code.onnx; 27 28 import java.io.ByteArrayOutputStream; 29 import java.lang.foreign.ValueLayout; 30 import java.nio.charset.StandardCharsets; 31 import java.util.ArrayList; 32 import java.util.HashMap; 33 import java.util.List; 34 import java.util.function.BiConsumer; 35 import java.util.function.Function; 36 import java.util.stream.IntStream; 37 import jdk.incubator.code.Block; 38 import jdk.incubator.code.CodeItem; 39 import jdk.incubator.code.Op; 40 import jdk.incubator.code.Value; 41 import jdk.incubator.code.op.CoreOp; 42 import jdk.incubator.code.type.JavaType; 43 import jdk.incubator.code.type.TupleType; 44 import jdk.incubator.code.writer.OpWriter; 45 import oracle.code.onnx.ir.OnnxOp; 46 import oracle.code.onnx.ir.OnnxOps; 47 import oracle.code.onnx.ir.OnnxType; 48 49 // Generated from onnx.proto3 50 sealed class OnnxProtoBuilder<T extends OnnxProtoBuilder> { 51 52 static final class Attribute extends OnnxProtoBuilder<Attribute> { 53 Attribute name(String name) {return _f(1, name);} 54 Attribute ref_attr_name(String ref_attr_name) {return _f(21, ref_attr_name);} 55 Attribute doc_string(String doc_string) {return _f(13, doc_string);} 56 Attribute type(int type) {return _f(20, type);} 57 Attribute f(float f) {return _f(2, f);} 58 Attribute i(long i) {return _f(3, i);} 59 Attribute s(byte[] s) {return _f(4, s);} 60 Attribute t(TensorProto t) {return _f(5, t);} 61 Attribute g(GraphProto g) {return _f(6, g);} 62 Attribute sparse_tensor(SparseTensorProto sparse_tensor) {return _f(22, sparse_tensor);} 63 Attribute tp(TypeProto tp) {return _f(14, tp);} 64 Attribute floats(float... floats) {return _f(7, floats);} 65 Attribute ints(long... ints) {return _f(8, ints);} 66 Attribute strings(byte[] strings) {return _f(9, strings);} 67 Attribute tensors(TensorProto tensors) {return _f(10, tensors);} 68 Attribute graphs(GraphProto graphs) {return _f(11, graphs);} 69 Attribute sparse_tensors(SparseTensorProto sparse_tensors) {return _f(23, sparse_tensors);} 70 Attribute type_protos(TypeProto type_protos) {return _f(15, type_protos);} 71 } 72 73 static final class ValueInfoProto extends OnnxProtoBuilder<ValueInfoProto> { 74 ValueInfoProto name(String name) {return _f(1, name);} 75 ValueInfoProto type(TypeProto type) {return _f(2, type);} 76 ValueInfoProto doc_string(String doc_string) {return _f(3, doc_string);} 77 ValueInfoProto metadata_props(StringStringEntryProto metadata_props) {return _f(4, metadata_props);} 78 } 79 80 static final class NodeProto extends OnnxProtoBuilder<NodeProto> { 81 NodeProto input(String input) {return _f(1, input);} 82 NodeProto output(String output) {return _f(2, output);} 83 NodeProto name(String name) {return _f(3, name);} 84 NodeProto op_type(String op_type) {return _f(4, op_type);} 85 NodeProto domain(String domain) {return _f(7, domain);} 86 NodeProto overload(String overload) {return _f(8, overload);} 87 NodeProto attribute(Attribute attribute) {return _f(5, attribute);} 88 NodeProto doc_string(String doc_string) {return _f(6, doc_string);} 89 NodeProto metadata_props(StringStringEntryProto metadata_props) {return _f(9, metadata_props);} 90 } 91 92 static final class TrainingInfoProto extends OnnxProtoBuilder<TrainingInfoProto> { 93 TrainingInfoProto initialization(GraphProto initialization) {return _f(1, initialization);} 94 TrainingInfoProto algorithm(GraphProto algorithm) {return _f(2, algorithm);} 95 TrainingInfoProto initialization_binding(StringStringEntryProto initialization_binding) {return _f(3, initialization_binding);} 96 TrainingInfoProto update_binding(StringStringEntryProto update_binding) {return _f(4, update_binding);} 97 } 98 99 static final class ModelProto extends OnnxProtoBuilder<ModelProto> { 100 ModelProto ir_version(long ir_version) {return _f(1, ir_version);} 101 ModelProto opset_import(OperatorSetIdProto opset_import) {return _f(8, opset_import);} 102 ModelProto producer_name(String producer_name) {return _f(2, producer_name);} 103 ModelProto producer_version(String producer_version) {return _f(3, producer_version);} 104 ModelProto domain(String domain) {return _f(4, domain);} 105 ModelProto model_version(long model_version) {return _f(5, model_version);} 106 ModelProto doc_string(String doc_string) {return _f(6, doc_string);} 107 ModelProto graph(GraphProto graph) {return _f(7, graph);} 108 ModelProto metadata_props(StringStringEntryProto metadata_props) {return _f(14, metadata_props);} 109 ModelProto training_info(TrainingInfoProto training_info) {return _f(20, training_info);} 110 ModelProto functions(FunctionProto functions) {return _f(25, functions);} 111 } 112 113 static final class StringStringEntryProto extends OnnxProtoBuilder<StringStringEntryProto> { 114 StringStringEntryProto key(String key) {return _f(1, key);} 115 StringStringEntryProto value(String value) {return _f(2, value);} 116 } 117 118 static final class TensorAnnotation extends OnnxProtoBuilder<TensorAnnotation> { 119 TensorAnnotation tensor_name(String tensor_name) {return _f(1, tensor_name);} 120 TensorAnnotation quant_parameter_tensor_names(StringStringEntryProto quant_parameter_tensor_names) {return _f(2, quant_parameter_tensor_names);} 121 } 122 123 static final class GraphProto extends OnnxProtoBuilder<GraphProto> { 124 GraphProto node(NodeProto node) {return _f(1, node);} 125 GraphProto name(String name) {return _f(2, name);} 126 GraphProto initializer(TensorProto initializer) {return _f(5, initializer);} 127 GraphProto sparse_initializer(SparseTensorProto sparse_initializer) {return _f(15, sparse_initializer);} 128 GraphProto doc_string(String doc_string) {return _f(10, doc_string);} 129 GraphProto input(ValueInfoProto input) {return _f(11, input);} 130 GraphProto output(ValueInfoProto output) {return _f(12, output);} 131 GraphProto value_info(ValueInfoProto value_info) {return _f(13, value_info);} 132 GraphProto quantization_annotation(TensorAnnotation quantization_annotation) {return _f(14, quantization_annotation);} 133 GraphProto metadata_props(StringStringEntryProto metadata_props) {return _f(16, metadata_props);} 134 } 135 136 static final class TensorProto extends OnnxProtoBuilder<TensorProto> { 137 TensorProto dims(long... dims) {return _f(1, dims);} 138 TensorProto data_type(int data_type) {return _f(2, data_type);} 139 TensorProto segment(Segment segment) {return _f(3, segment);} 140 TensorProto float_data(float... float_data) {return _f(4, float_data);} 141 TensorProto int32_data(int... int32_data) {return _f(5, int32_data);} 142 TensorProto string_data(byte[] string_data) {return _f(6, string_data);} 143 TensorProto int64_data(long... int64_data) {return _f(7, int64_data);} 144 TensorProto name(String name) {return _f(8, name);} 145 TensorProto doc_string(String doc_string) {return _f(12, doc_string);} 146 TensorProto raw_data(byte[] raw_data) {return _f(9, raw_data);} 147 TensorProto external_data(StringStringEntryProto external_data) {return _f(13, external_data);} 148 TensorProto data_location(int data_location) {return _f(14, data_location);} 149 TensorProto double_data(double... double_data) {return _f(10, double_data);} 150 TensorProto uint64_data(long... uint64_data) {return _f(11, uint64_data);} 151 TensorProto metadata_props(StringStringEntryProto metadata_props) {return _f(16, metadata_props);} 152 } 153 154 static final class Segment extends OnnxProtoBuilder<Segment> { 155 Segment begin(long begin) {return _f(1, begin);} 156 Segment end(long end) {return _f(2, end);} 157 } 158 159 static final class SparseTensorProto extends OnnxProtoBuilder<SparseTensorProto> { 160 SparseTensorProto values(TensorProto values) {return _f(1, values);} 161 SparseTensorProto indices(TensorProto indices) {return _f(2, indices);} 162 SparseTensorProto dims(long... dims) {return _f(3, dims);} 163 } 164 165 static final class TensorShapeProto extends OnnxProtoBuilder<TensorShapeProto> { 166 TensorShapeProto dim(Dimension dim) {return _f(1, dim);} 167 } 168 169 static final class Dimension extends OnnxProtoBuilder<Dimension> { 170 Dimension dim_value(long dim_value) {return _f(1, dim_value);} 171 Dimension dim_param(String dim_param) {return _f(2, dim_param);} 172 Dimension denotation(String denotation) {return _f(3, denotation);} 173 } 174 175 static final class TypeProto extends OnnxProtoBuilder<TypeProto> { 176 TypeProto tensor_type(Tensor tensor_type) {return _f(1, tensor_type);} 177 TypeProto sequence_type(Sequence sequence_type) {return _f(4, sequence_type);} 178 TypeProto map_type(Map map_type) {return _f(5, map_type);} 179 TypeProto optional_type(Optional optional_type) {return _f(9, optional_type);} 180 TypeProto sparse_tensor_type(SparseTensor sparse_tensor_type) {return _f(8, sparse_tensor_type);} 181 TypeProto denotation(String denotation) {return _f(6, denotation);} 182 } 183 184 static final class Tensor extends OnnxProtoBuilder<Tensor> { 185 Tensor elem_type(int elem_type) {return _f(1, elem_type);} 186 Tensor shape(TensorShapeProto shape) {return _f(2, shape);} 187 } 188 189 static final class Sequence extends OnnxProtoBuilder<Sequence> { 190 Sequence elem_type(TypeProto elem_type) {return _f(1, elem_type);} 191 } 192 193 static final class Map extends OnnxProtoBuilder<Map> { 194 Map key_type(int key_type) {return _f(1, key_type);} 195 Map value_type(TypeProto value_type) {return _f(2, value_type);} 196 } 197 198 static final class Optional extends OnnxProtoBuilder<Optional> { 199 Optional elem_type(TypeProto elem_type) {return _f(1, elem_type);} 200 } 201 202 static final class SparseTensor extends OnnxProtoBuilder<SparseTensor> { 203 SparseTensor elem_type(int elem_type) {return _f(1, elem_type);} 204 SparseTensor shape(TensorShapeProto shape) {return _f(2, shape);} 205 } 206 207 static final class OperatorSetIdProto extends OnnxProtoBuilder<OperatorSetIdProto> { 208 OperatorSetIdProto domain(String domain) {return _f(1, domain);} 209 OperatorSetIdProto version(long version) {return _f(2, version);} 210 } 211 212 static final class FunctionProto extends OnnxProtoBuilder<FunctionProto> { 213 FunctionProto name(String name) {return _f(1, name);} 214 FunctionProto input(String input) {return _f(4, input);} 215 FunctionProto output(String output) {return _f(5, output);} 216 FunctionProto attribute(String attribute) {return _f(6, attribute);} 217 FunctionProto attribute_proto(Attribute attribute_proto) {return _f(11, attribute_proto);} 218 FunctionProto node(NodeProto node) {return _f(7, node);} 219 FunctionProto doc_string(String doc_string) {return _f(8, doc_string);} 220 FunctionProto opset_import(OperatorSetIdProto opset_import) {return _f(9, opset_import);} 221 FunctionProto domain(String domain) {return _f(10, domain);} 222 FunctionProto overload(String overload) {return _f(13, overload);} 223 FunctionProto value_info(ValueInfoProto value_info) {return _f(12, value_info);} 224 FunctionProto metadata_props(StringStringEntryProto metadata_props) {return _f(14, metadata_props);} 225 } 226 227 final ByteArrayOutputStream buf = new ByteArrayOutputStream(); 228 229 void _encode(long number) { 230 for (int i = 64 - Long.numberOfLeadingZeros(number); i > 7; i -= 7) { 231 buf.write(0x80 | (int)number & 0x7f); 232 number >>= 7; 233 } 234 buf.write((int)number & 0x7f); 235 } 236 237 void _encode(float value) { 238 int bits = Float.floatToRawIntBits(value); 239 buf.write((byte)bits); 240 buf.write((byte)(bits >> 8)); 241 buf.write((byte)(bits >> 16)); 242 buf.write((byte)(bits >> 24)); 243 } 244 245 void _encode(double value) { 246 long bits = Double.doubleToRawLongBits(value); 247 buf.write((byte)bits); 248 buf.write((byte)(bits >> 8)); 249 buf.write((byte)(bits >> 16)); 250 buf.write((byte)(bits >> 24)); 251 buf.write((byte)(bits >> 32)); 252 buf.write((byte)(bits >> 40)); 253 buf.write((byte)(bits >> 48)); 254 buf.write((byte)(bits >> 56)); 255 } 256 257 @SuppressWarnings("unchecked") 258 T _f(int fieldIndex, String value) { 259 return value == null ? (T)this : _f(fieldIndex, value.getBytes(StandardCharsets.UTF_8)); 260 } 261 262 @SuppressWarnings("unchecked") 263 T _f(int fieldIndex, byte[] bytes) { 264 _encode(fieldIndex << 3 | 2); 265 _encode(bytes.length); 266 buf.writeBytes(bytes); 267 return (T)this; 268 } 269 270 @SuppressWarnings("unchecked") 271 T _f(int fieldIndex, float value) { 272 _encode(fieldIndex << 3 | 5); 273 _encode(value); 274 return (T)this; 275 } 276 277 @SuppressWarnings("unchecked") 278 T _f(int fieldIndex, float... values) { 279 if (values.length == 1) { 280 return _f(fieldIndex, values[0]); 281 } 282 var b = new OnnxProtoBuilder(); 283 for (var v : values) b._encode(v); 284 _f(fieldIndex, b); 285 return (T)this; 286 } 287 288 @SuppressWarnings("unchecked") 289 T _f(int fieldIndex, double value) { 290 _encode(fieldIndex << 3 | 1); 291 _encode(value); 292 return (T)this; 293 } 294 295 @SuppressWarnings("unchecked") 296 T _f(int fieldIndex, double... values) { 297 if (values.length == 1) { 298 return _f(fieldIndex, values[0]); 299 } 300 var b = new OnnxProtoBuilder(); 301 for (var v : values) b._encode(v); 302 _f(fieldIndex, b); 303 return (T)this; 304 } 305 306 @SuppressWarnings("unchecked") 307 T _f(int fieldIndex, long value) { 308 _encode(fieldIndex << 3); 309 _encode(value); 310 return (T)this; 311 } 312 313 @SuppressWarnings("unchecked") 314 T _f(int fieldIndex, long... values) { 315 if (values.length == 1) { 316 return _f(fieldIndex, values[0]); 317 } 318 var b = new OnnxProtoBuilder(); 319 for (var v : values) b._encode(v); 320 _f(fieldIndex, b); 321 return (T)this; 322 } 323 324 @SuppressWarnings("unchecked") 325 T _f(int fieldIndex, int... values) { 326 if (values.length == 1) { 327 return _f(fieldIndex, values[0]); 328 } 329 var b = new OnnxProtoBuilder(); 330 for (var v : values) b._encode(v); 331 _f(fieldIndex, b); 332 return (T)this; 333 } 334 335 @SuppressWarnings("unchecked") T _f(int fieldIndex, OnnxProtoBuilder value) { 336 return _f(fieldIndex, value.buf.toByteArray()); 337 } 338 339 @SuppressWarnings("unchecked") 340 <P> T forEach(Iterable<P> sup, BiConsumer<T, ? super P> cons) { 341 sup.forEach(p -> cons.accept((T)this, p)); 342 return (T)this; 343 } 344 345 static final int IR_VERSION = 10; 346 static final int OPSET_VERSION = 21; 347 348 private static final class Indexer { 349 350 private final Function<CodeItem, String> baseNames; 351 private final HashMap<String, String> elementsMap; 352 353 354 Indexer(Function<CodeItem, String> baseNames) { 355 this.baseNames = baseNames; 356 this.elementsMap = new HashMap<>(); 357 } 358 359 private String baseName(Value value, int elementIndex) { 360 var name = "%" + baseNames.apply(value); 361 return elementIndex > 0 ? name + '.' + elementIndex : name; 362 } 363 364 String nameOf(Value value) { 365 return nameOf(value, 0); 366 } 367 368 String nameOf(Value tuple, int elementIndex) { 369 var name = baseName(tuple, elementIndex); 370 return elementsMap.getOrDefault(name, name); 371 } 372 373 void mapTupleLoad(Value tupleLoadResult, Value tuple, int elementIndex) { 374 elementsMap.put(baseName(tupleLoadResult, 0), nameOf(tuple, elementIndex)); 375 } 376 377 void mapTupleElements(Value tuple, List<Value> elements) { 378 for (int i = 0; i < elements.size(); i++) { 379 elementsMap.put(baseName(tuple, i), nameOf(elements.get(i))); 380 } 381 } 382 } 383 384 static byte[] build(String domainName, CoreOp.ModuleOp module, List<oracle.code.onnx.Tensor> initializers) { 385 var indexer = new Indexer(OpWriter.computeGlobalNames(module)); 386 387 var functions = new ArrayList<>(module.functionTable().sequencedValues()); 388 var mainFunc = functions.removeLast(); 389 var mainBlock = mainFunc.body().entryBlock(); 390 391 var model = build( 392 graph(mainFunc.funcName(), domainName, indexer, mainBlock, initializers, 0), 393 List.of(domainName), 394 functions.stream().map(f -> 395 function(domainName, 396 f.funcName(), 397 f.parameters().stream().map(indexer::nameOf).toList(), 398 expandTuples(indexer, f.body().entryBlock().terminatingOp().operands()), 399 nodes(domainName, indexer, f.body().entryBlock().ops()))).toList()); 400 401 // OnnxProtoPrinter.printModel(model); 402 return model; 403 } 404 405 // @@@ unchecked constraints: 406 // tensor FuncOp parameters and single tensor return type 407 // OnnxOps (with tensor operands and single tensor return value) and ReturnOp (returning single tensor) 408 // entry block only 409 static byte[] build(Block block, List<oracle.code.onnx.Tensor> initializers) { 410 var indexer = new Indexer(OpWriter.computeGlobalNames(block.parentBody().parentOp())); 411 var model = build(graph(null, null, indexer, block, initializers, 0), List.of(), List.of()); 412 // OnnxProtoPrinter.printModel(model); 413 return model; 414 } 415 416 static byte[] build(List<TensorProto> initializers, List<ValueInfoProto> inputs, List<NodeProto> ops, List<String> outputNames) { 417 return build(graph(null, initializers, inputs, ops, outputNames), List.of(), List.of()); 418 } 419 420 static byte[] build(List<TensorProto> initializers, List<ValueInfoProto> inputs, List<NodeProto> ops, List<String> outputNames, List<String> customImportDomains, List<FunctionProto> functions) { 421 return build(graph(null, initializers, inputs, ops, outputNames), customImportDomains, functions); 422 } 423 424 static byte[] build(GraphProto graph, List<String> customImportDomains, List<FunctionProto> functions) { 425 return new ModelProto() 426 .ir_version(IR_VERSION) 427 .opset_import(new OperatorSetIdProto().version(OPSET_VERSION)) 428 .forEach(customImportDomains, (m, d) -> m.opset_import(new OperatorSetIdProto().domain(d))) 429 .forEach(functions, (m, f) -> m.functions(f)) 430 .graph(graph) 431 .buf.toByteArray(); 432 } 433 434 static List<String> expandTuples(Indexer indexer, List<Value> values) { 435 var names = new ArrayList<String>(); 436 expandTuples(indexer, names, values); 437 return names; 438 } 439 440 static void expandTuples(Indexer indexer, List<String> names, List<Value> values) { 441 for (var v : values) { 442 if (v instanceof Op.Result or && or.op() instanceof CoreOp.TupleOp op) { 443 expandTuples(indexer, names, op.operands()); 444 } else if (v.type() instanceof TupleType tt) { 445 var ct = tt.componentTypes(); 446 for (int i = 0; i < ct.size(); i++) { 447 names.add(indexer.nameOf(v, i)); 448 } 449 } else { 450 names.add(indexer.nameOf(v)); 451 } 452 } 453 } 454 455 static GraphProto graph(String graphName, String domainName, Indexer indexer, Block block, List<oracle.code.onnx.Tensor> initializers, int scalarArgs) { 456 var params = block.parameters(); 457 params.forEach(indexer::nameOf); 458 int firstInitializer = params.size() - initializers.size(); 459 var args = params.subList(0, firstInitializer); 460 return graph(graphName, 461 IntStream.range(0, initializers.size()).mapToObj(i -> tensorProto(indexer.nameOf(params.get(i + firstInitializer)), initializers.get(i))).toList(), 462 tensorInfos(indexer, args, scalarArgs), 463 nodes(domainName, indexer, block.ops()), 464 expandTuples(indexer, block.terminatingOp().operands())); 465 } 466 467 static List<NodeProto> nodes(String domainName, Indexer indexer, List<Op> ops) { 468 return ops.stream().<NodeProto>mapMulti((op, opNodes) -> { 469 switch (op) { 470 case OnnxOps.If ifOp -> 471 opNodes.accept(node( 472 ifOp.opName(), 473 List.of(indexer.nameOf(ifOp.operands().getFirst())), 474 IntStream.range(0, ifOp.resultType() instanceof TupleType tt ? tt.componentTypes().size() : 1).mapToObj(o -> indexer.nameOf(ifOp.result(), o)).toList(), 475 java.util.Map.of( 476 "then_branch", graph(null, domainName, indexer, ifOp.thenBranch().entryBlock(), List.of(), 0), 477 "else_branch", graph(null, domainName, indexer, ifOp.elseBranch().entryBlock(), List.of(), 0)))); 478 case OnnxOps.Loop loopOp -> { 479 opNodes.accept(node(loopOp.opName(), 480 expandTuples(indexer, loopOp.operands()), 481 IntStream.range(0, loopOp.resultType() instanceof TupleType tt ? tt.componentTypes().size() : 1).mapToObj(o -> indexer.nameOf(loopOp.result(), o)).toList(), 482 java.util.Map.of( 483 "body", graph(null, domainName, indexer, loopOp.loopBody().entryBlock(), List.of(), 2)))); 484 } 485 case OnnxOp onnxOp -> 486 opNodes.accept(node( 487 onnxOp.opName(), 488 onnxOp.operands().stream().map(indexer::nameOf).toList(), 489 IntStream.range(0, onnxOp.onnxOutputs().size()).mapToObj(o -> indexer.nameOf(onnxOp.result(), o)).toList(), 490 onnxOp.onnxAttributes())); 491 case CoreOp.FuncCallOp fco -> 492 opNodes.accept(node( 493 domainName, 494 fco.funcName(), 495 fco.operands().stream().map(indexer::nameOf).toList(), 496 expandTuples(indexer, List.of(fco.result())), 497 java.util.Map.of())); 498 case CoreOp.ReturnOp _, CoreOp.ConstantOp _ -> { // skip 499 } 500 case CoreOp.TupleLoadOp tlo -> 501 indexer.mapTupleLoad(tlo.result(), tlo.operands().getFirst(), tlo.index()); 502 case CoreOp.TupleOp to -> 503 indexer.mapTupleElements(to.result(), to.operands()); 504 case CoreOp.InvokeOp io when io.invokeDescriptor().refType().equals(JavaType.type(List.class)) -> { 505 if (io.invokeDescriptor().name().equals("get") && io.operands().getLast() instanceof Op.Result or && or.op() instanceof CoreOp.ConstantOp co && co.value() instanceof Integer i) { 506 indexer.mapTupleLoad(io.result(), io.operands().getFirst(), i); 507 } else if (io.invokeDescriptor().name().equals("of")) { 508 indexer.mapTupleElements(io.result(), io.operands()); 509 } else { 510 throw new UnsupportedOperationException(op.toText()); 511 } 512 } 513 default -> { 514 throw new UnsupportedOperationException(op.toText()); 515 } 516 } 517 }).toList(); 518 } 519 520 static List<ValueInfoProto> tensorInfos(Indexer indexer, List<Block.Parameter> args, int scalarArgs) { 521 var infos = new ArrayList<ValueInfoProto>(); 522 for (var arg : args) { 523 switch (arg.type()) { 524 case OnnxType.TensorType tt -> 525 infos.add(tensorInfo(indexer.nameOf(arg), tt.eType().id(), infos.size() < scalarArgs)); 526 case TupleType tt -> { 527 var ct = tt.componentTypes(); 528 for (int i = 0; i < ct.size(); i++) { 529 infos.add(tensorInfo(indexer.nameOf(arg, i), ((OnnxType.TensorType)ct.get(i)).eType().id(), infos.size() < scalarArgs)); 530 } 531 } 532 default -> 533 throw new UnsupportedOperationException(arg.type().toString()); 534 } 535 } 536 return infos; 537 } 538 539 static GraphProto graph(String name, List<TensorProto> initializers, List<ValueInfoProto> inputs, List<NodeProto> ops, List<String> outputNames) { 540 return new GraphProto() 541 .name(name) 542 .forEach(initializers, (g, i) -> g.initializer(i)) 543 .forEach(inputs, (g, i) -> g.input(i)) 544 .forEach(ops, (g, op) -> g.node(op)) 545 .forEach(outputNames, (g, oName) -> g.output(new ValueInfoProto().name(oName))); 546 } 547 548 static FunctionProto function(String domain, String functionName, List<String> inputNames, List<String> outputNames, List<NodeProto> ops) { 549 return new FunctionProto() 550 .domain(domain) 551 .name(functionName) 552 .forEach(inputNames, (f, i) -> f.input(i)) 553 .forEach(ops, (g, op) -> g.node(op)) 554 .forEach(outputNames, (f, o) -> f.output(o)) 555 .opset_import(new OperatorSetIdProto().version(OPSET_VERSION)); 556 } 557 558 static NodeProto node(String domain, String opName, List<String> inputNames, List<String> outputNames, java.util.Map<String, Object> attributes) { 559 return new NodeProto() 560 .domain(domain) 561 .op_type(opName) 562 .forEach(inputNames, (n, iName) -> n.input(iName)) 563 .forEach(attributes.entrySet(), (n, ae) -> n.attribute(attribute(ae.getKey(), ae.getValue()))) 564 .forEach(outputNames, (n, oName) -> n.output(oName)); 565 } 566 567 static NodeProto node(String opName, List<String> inputNames, List<String> outputNames, java.util.Map<String, Object> attributes) { 568 return new NodeProto() 569 .op_type(opName) 570 .forEach(inputNames, (n, iName) -> n.input(iName)) 571 .forEach(attributes.entrySet(), (n, ae) -> n.attribute(attribute(ae.getKey(), ae.getValue()))) 572 .forEach(outputNames, (n, oName) -> n.output(oName)); 573 } 574 575 static ValueInfoProto tensorInfo(String name, int tensorElementType) { 576 return tensorInfo(name, tensorElementType, false); 577 } 578 579 static ValueInfoProto tensorInfo(String name, int tensorElementType, boolean addScalarShape) { 580 var t = new Tensor().elem_type(tensorElementType); 581 if (addScalarShape) t.shape(new TensorShapeProto()); 582 return new ValueInfoProto() 583 .name(name) 584 .type(new TypeProto().tensor_type(t)); 585 } 586 587 static TensorProto tensorProto(String name, oracle.code.onnx.Tensor tensor) { 588 return new TensorProto() 589 .name(name) 590 .data_type(tensor.elementType().id) 591 .dims(tensor.shape()) 592 .raw_data(tensor.data().toArray(ValueLayout.JAVA_BYTE)); 593 } 594 595 static Attribute attribute(String name, Object value) { 596 var attr = new Attribute().name(name); 597 switch (value) { 598 case Float f -> { 599 attr.type(1).f(f); 600 } 601 case Long l -> { 602 attr.type(2).i(l); 603 } 604 case GraphProto g -> { 605 attr.type(5).g(g.name(name)); 606 } 607 case float[] floats -> { 608 attr.type(6); 609 attr.floats(floats); 610 } 611 case long[] longs -> { 612 attr.type(7); 613 attr.ints(longs); 614 } 615 default -> { 616 throw new UnsupportedOperationException(value.getClass().toString()); // @@@ ToDo 617 } 618 } 619 return attr; 620 } 621 }