1 /* 2 * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. Oracle designates this 8 * particular file as subject to the "Classpath" exception as provided 9 * by Oracle in the LICENSE file that accompanied this code. 10 * 11 * This code is distributed in the hope that it will be useful, but WITHOUT 12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 14 * version 2 for more details (a copy is included in the LICENSE file that 15 * accompanied this code). 16 * 17 * You should have received a copy of the GNU General Public License version 18 * 2 along with this work; if not, write to the Free Software Foundation, 19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 20 * 21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 22 * or visit www.oracle.com if you need additional information or have any 23 * questions. 24 */ 25 26 package oracle.code.onnx; 27 28 import java.lang.foreign.ValueLayout; 29 import java.util.ArrayList; 30 import java.util.HashMap; 31 import java.util.List; 32 import java.util.Map; 33 import java.util.Optional; 34 import java.util.SequencedMap; 35 import java.util.function.Function; 36 import java.util.stream.IntStream; 37 import jdk.incubator.code.Block; 38 import jdk.incubator.code.CodeItem; 39 import jdk.incubator.code.Op; 40 import jdk.incubator.code.Value; 41 import jdk.incubator.code.dialect.core.CoreOp; 42 import jdk.incubator.code.dialect.core.TupleType; 43 import jdk.incubator.code.dialect.java.JavaOp; 44 import jdk.incubator.code.dialect.java.JavaType; 45 import jdk.incubator.code.extern.OpWriter; 46 import oracle.code.onnx.ir.OnnxOp; 47 import oracle.code.onnx.ir.OnnxOps; 48 import oracle.code.onnx.ir.OnnxType; 49 import oracle.code.onnx.proto.OnnxBuilder.*; 50 import oracle.code.onnx.proto.OnnxConstants.*; 51 52 public final class OnnxProtoBuilder { 53 54 static final int IR_VERSION = 10; 55 static final int OPSET_VERSION = 21; 56 57 private static final class Indexer { 58 59 private final Function<CodeItem, String> baseNames; 60 private final HashMap<String, String> remap; 61 62 63 Indexer(Op root, Map<Value, String> explicitNames) { 64 this.baseNames = OpWriter.computeGlobalNames(root); 65 this.remap = new HashMap<>(); 66 explicitNames.forEach(this::setName); 67 } 68 69 void setName(Value val, String name) { 70 switch (val) { 71 case Op.Result or when or.op() instanceof CoreOp.TupleOp to -> { 72 remap.put(baseName(val), name); 73 for (int i = 0; i < to.operands().size(); i++) { 74 setName(to.operands().get(i), name + "." + i); 75 } 76 } 77 case Block.Parameter bp when val.type() instanceof TupleType tt -> { 78 for (int i = 0; i < tt.componentTypes().size(); i++) { 79 remap.put(baseName(val, i), name +"." + i); 80 } 81 } 82 default -> { 83 remap.put(baseName(val), name); 84 if (val instanceof Op.Result or && or.op() instanceof CoreOp.TupleLoadOp tlo) { 85 Value tr = tlo.operands().getFirst(); 86 remap.put(baseName(tr, tlo.index()), name); 87 if (tr instanceof Op.Result tor && tor.op() instanceof CoreOp.TupleOp to) { 88 setName(to.operands().get(tlo.index()), name); 89 } 90 } 91 } 92 } 93 } 94 95 private String baseName(Value value) { 96 return "%" + baseNames.apply(value); 97 } 98 99 private String baseName(Value value, int elementIndex) { 100 var name = baseName(value); 101 return elementIndex > 0 ? name + '.' + elementIndex : name; 102 } 103 104 String nameOf(Value value) { 105 var name = baseName(value); 106 return remap.getOrDefault(name, name); 107 } 108 109 String nameOf(Value tuple, int elementIndex) { 110 var name = baseName(tuple, elementIndex); 111 return remap.getOrDefault(name, name); 112 } 113 114 void mapTupleLoad(Value tupleLoadResult, Value tuple, int elementIndex) { 115 remap.putIfAbsent(baseName(tupleLoadResult), nameOf(tuple, elementIndex)); 116 } 117 118 void mapTupleElements(Value tuple, List<Value> elements) { 119 for (int i = 0; i < elements.size(); i++) { 120 remap.putIfAbsent(baseName(tuple, i), nameOf(elements.get(i))); 121 } 122 } 123 } 124 125 public static byte[] buildModel(String domain, CoreOp.ModuleOp module, List<Object> initializers) { 126 return buildModel(domain, module, initializers, Map.of(), _ -> null); 127 } 128 129 public record ExternalTensorDataInfo(String location, long offset, long length) { 130 } 131 132 public static byte[] buildModel(String domain, CoreOp.ModuleOp module, List<Object> initializers, Map<Value, String> explicitValueNames, Function<Tensor, ExternalTensorDataInfo> tensorDataExternalizer) { 133 var indexer = new Indexer(module, explicitValueNames); 134 135 var functions = new ArrayList<>(module.functionTable().sequencedValues()); 136 var imports = new ArrayList<String>(); 137 if (functions.size() > 1) imports.add(domain); // self domain import if additional functions 138 for (var f : functions) { 139 for (var op : f.body().entryBlock().ops()) { // auto import of op domains 140 if (op instanceof OnnxOp oop) { 141 String name = oop.schema().name(); 142 int di = name.lastIndexOf('.'); 143 if (di > 0) { 144 String dn = name.substring(0, di); 145 if (!imports.contains(dn)) imports.add(dn); 146 } 147 } 148 } 149 } 150 var mainFunc = functions.removeLast(); 151 var mainBlock = mainFunc.body().entryBlock(); 152 153 var model = buildModel( 154 graph(domain, mainFunc.funcName(), indexer, mainBlock, initializers, 0, tensorDataExternalizer), 155 imports, 156 functions.stream().map(f -> 157 function(domain, imports, f.funcName(), 158 expandTuples(indexer, f.parameters()), 159 expandTuples(indexer, f.body().entryBlock().terminatingOp().operands()), 160 nodes(domain, indexer, f.body().entryBlock().ops()))).toList()); 161 162 // System.out.println(OnnxModel.readFrom(model).toText()); 163 return model; 164 } 165 166 // @@@ unchecked constraints: 167 // tensor FuncOp parameters and single tensor return type 168 // OnnxOps (with tensor operands and single tensor return value) and ReturnOp (returning single tensor) 169 // entry block only 170 static byte[] buildModel(Block block, List<oracle.code.onnx.Tensor> initializers) { 171 var indexer = new Indexer(block.ancestorOp(), Map.of()); 172 var model = buildModel(graph(null, null, indexer, block, initializers, 0), List.of(), List.of()); 173 // System.out.println(OnnxModel.readFrom(model).toText()); 174 return model; 175 } 176 177 static byte[] buildModel(List<TensorProto> initializers, List<ValueInfoProto> inputs, List<NodeProto> ops, List<String> outputNames) { 178 return buildModel(graph(null, initializers, inputs, ops, outputNames), List.of(), List.of()); 179 } 180 181 static byte[] buildModel(List<TensorProto> initializers, List<ValueInfoProto> inputs, List<NodeProto> ops, List<String> outputNames, List<String> customImportDomains, List<FunctionProto> functions) { 182 return buildModel(graph(null, initializers, inputs, ops, outputNames), customImportDomains, functions); 183 } 184 185 static byte[] buildModel(GraphProto graph, List<String> imports, List<FunctionProto> functions) { 186 return new ModelProto() 187 .irVersion(IR_VERSION) 188 .opsetImport(new OperatorSetIdProto().version(OPSET_VERSION)) 189 .forEach(imports, (m, d) -> m.opsetImport(new OperatorSetIdProto().domain(d).version(1))) 190 .forEach(functions, (m, f) -> m.functions(f)) 191 .graph(graph) 192 .getBytes(); 193 } 194 195 static List<String> expandTuples(Indexer indexer, List<? extends Value> values) { 196 var names = new ArrayList<String>(); 197 expandTuples(indexer, names, values); 198 return names; 199 } 200 201 static void expandTuples(Indexer indexer, List<String> names, List<? extends Value> values) { 202 for (var v : values) { 203 if (v instanceof Op.Result or && or.op() instanceof CoreOp.TupleOp op) { 204 expandTuples(indexer, names, op.operands()); 205 } else if (v.type() instanceof TupleType tt) { 206 var ct = tt.componentTypes(); 207 for (int i = 0; i < ct.size(); i++) { 208 names.add(indexer.nameOf(v, i)); 209 } 210 } else { 211 names.add(indexer.nameOf(v)); 212 } 213 } 214 } 215 216 static GraphProto graph(String domain, String graphName, Indexer indexer, Block block, List<? extends Object> initializers, int scalarArgs) { 217 return graph(domain, graphName, indexer, block, initializers, scalarArgs, _ -> null); 218 } 219 220 static GraphProto graph(String domain, String graphName, Indexer indexer, Block block, List<? extends Object> initializers, int scalarArgs, Function<Tensor, ExternalTensorDataInfo> tensorDataExternalizer) { 221 var params = block.parameters(); 222 params.forEach(indexer::nameOf); 223 int firstInitializer = params.size() - initializers.size(); 224 var args = params.subList(0, firstInitializer); 225 return graph(graphName, 226 IntStream.range(0, initializers.size()).boxed().<TensorProto>mapMulti((i, tps) -> { 227 Object val = initializers.get(i); 228 if (val instanceof Record) { 229 var rcs = val.getClass().getRecordComponents(); 230 for (int rci = 0; rci < rcs.length; rci++) try { 231 tps.accept(tensorProto(indexer.nameOf(params.get(i + firstInitializer), rci), (Tensor)(rcs[rci].getAccessor().invoke(val)), tensorDataExternalizer)); 232 } catch (ReflectiveOperationException e) { 233 throw new IllegalArgumentException(e); 234 } 235 } else if (val instanceof Tensor[] tarr) { 236 for (int tai = 0; tai < tarr.length; tai++) { 237 tps.accept(tensorProto(indexer.nameOf(params.get(i + firstInitializer), tai), tarr[tai], tensorDataExternalizer)); 238 } 239 } else { 240 tps.accept(tensorProto(indexer.nameOf(params.get(i + firstInitializer)), (Tensor)val, tensorDataExternalizer)); 241 } 242 }).toList(), 243 tensorInfos(indexer, args, scalarArgs), 244 nodes(domain, indexer, block.ops()), 245 expandTuples(indexer, block.terminatingOp().operands())); 246 } 247 248 static List<String> opInputNames(Indexer indexer, SequencedMap<OnnxOp.OnnxParameter, Object> inputs) { 249 List<String> inputNames = inputs.sequencedValues().stream() 250 .<String>mapMulti((v, dump) -> { 251 switch (v) { 252 case Value val -> dump.accept(indexer.nameOf(val)); 253 case java.util.Optional<?> o when o.isPresent() && o.get() instanceof Value val -> dump.accept(indexer.nameOf(val)); 254 case List l -> l.forEach(val -> dump.accept(indexer.nameOf((Value)val))); 255 default -> dump.accept(""); // empty names for unused optional inputs 256 } 257 }).toList(); 258 // trim trailing empty names 259 return inputNames.reversed().stream().dropWhile(String::isEmpty).toList().reversed(); 260 } 261 262 static List<NodeProto> nodes(String domain, Indexer indexer, List<Op> ops) { 263 return ops.stream().<NodeProto>mapMulti((op, opNodes) -> { 264 switch (op) { 265 case OnnxOps.If ifOp -> 266 opNodes.accept(node( 267 ifOp.schema().name(), 268 List.of(indexer.nameOf(ifOp.operands().getFirst())), 269 IntStream.range(0, ifOp.resultType() instanceof TupleType tt ? tt.componentTypes().size() : 1).mapToObj(o -> indexer.nameOf(ifOp.result(), o)).toList(), 270 java.util.Map.of( 271 "then_branch", graph(domain, null, indexer, ifOp.thenBranch().entryBlock(), List.of(), 0), 272 "else_branch", graph(domain, null, indexer, ifOp.elseBranch().entryBlock(), List.of(), 0)))); 273 case OnnxOps.Loop loopOp -> { 274 opNodes.accept(node(loopOp.schema().name(), 275 expandTuples(indexer, loopOp.operands()), 276 IntStream.range(0, loopOp.resultType() instanceof TupleType tt ? tt.componentTypes().size() : 1).mapToObj(o -> indexer.nameOf(loopOp.result(), o)).toList(), 277 java.util.Map.of( 278 "body", graph(domain, null, indexer, loopOp.loopBody().entryBlock(), List.of(), 2)))); 279 } 280 case OnnxOp onnxOp -> 281 opNodes.accept(node( 282 onnxOp.schema().name(), 283 opInputNames(indexer, onnxOp.onnxInputs()), 284 IntStream.range(0, onnxOp.onnxOutputs().size()).mapToObj(o -> indexer.nameOf(onnxOp.result(), o)).toList(), 285 onnxOp.onnxAttributes())); 286 case CoreOp.FuncCallOp fco -> 287 opNodes.accept(node( 288 domain, 289 fco.funcName(), 290 expandTuples(indexer, fco.operands()), 291 expandTuples(indexer, List.of(fco.result())), 292 java.util.Map.of())); 293 case CoreOp.ReturnOp _, CoreOp.ConstantOp _ -> { // skip 294 } 295 case CoreOp.TupleLoadOp tlo -> 296 indexer.mapTupleLoad(tlo.result(), tlo.operands().getFirst(), tlo.index()); 297 case CoreOp.TupleOp to -> 298 indexer.mapTupleElements(to.result(), to.operands()); 299 case JavaOp.InvokeOp io when io.invokeDescriptor().refType().equals(JavaType.type(List.class)) -> { 300 if (io.invokeDescriptor().name().equals("get") && io.operands().getLast() instanceof Op.Result or && or.op() instanceof CoreOp.ConstantOp co && co.value() instanceof Integer i) { 301 indexer.mapTupleLoad(io.result(), io.operands().getFirst(), i); 302 } else if (io.invokeDescriptor().name().equals("of")) { 303 indexer.mapTupleElements(io.result(), io.operands()); 304 } else { 305 throw new UnsupportedOperationException(op.toText()); 306 } 307 } 308 default -> { 309 throw new UnsupportedOperationException(op.toText()); 310 } 311 } 312 }).toList(); 313 } 314 315 static List<ValueInfoProto> tensorInfos(Indexer indexer, List<Block.Parameter> args, int scalarArgs) { 316 var infos = new ArrayList<ValueInfoProto>(); 317 for (var arg : args) { 318 switch (arg.type()) { 319 case OnnxType.TensorType tt -> 320 infos.add(tensorInfo(indexer.nameOf(arg), tt.eType().id(), infos.size() < scalarArgs)); 321 case TupleType tt -> { 322 var ct = tt.componentTypes(); 323 for (int i = 0; i < ct.size(); i++) { 324 infos.add(tensorInfo(indexer.nameOf(arg, i), ((OnnxType.TensorType)ct.get(i)).eType().id(), infos.size() < scalarArgs)); 325 } 326 } 327 default -> 328 throw new UnsupportedOperationException(arg.type().toString()); 329 } 330 } 331 return infos; 332 } 333 334 static GraphProto graph(String name, List<TensorProto> initializers, List<ValueInfoProto> inputs, List<NodeProto> ops, List<String> outputNames) { 335 return new GraphProto() 336 .name(name) 337 .forEach(initializers, (g, i) -> g.initializer(i)) 338 .forEach(inputs, (g, i) -> g.input(i)) 339 .forEach(ops, (g, op) -> g.node(op)) 340 .forEach(outputNames, (g, oName) -> g.output(new ValueInfoProto().name(oName))); 341 } 342 343 static FunctionProto function(String functionDomain, List<String> imports, String functionName, List<String> inputNames, List<String> outputNames, List<NodeProto> ops) { 344 int di = functionName.lastIndexOf('.'); 345 return new FunctionProto() 346 .domain(functionDomain) 347 .name(functionName) 348 .forEach(inputNames, (f, i) -> f.input(i)) 349 .forEach(ops, (g, op) -> g.node(op)) 350 .forEach(outputNames, (f, o) -> f.output(o)) 351 .opsetImport(new OperatorSetIdProto().version(OPSET_VERSION)) 352 .forEach(imports, (f, d) -> f.opsetImport(new OperatorSetIdProto().domain(d).version(1))); 353 } 354 355 static NodeProto node(String domain, String opName, List<String> inputNames, List<String> outputNames, java.util.Map<String, Object> attributes) { 356 return new NodeProto() 357 .domain(domain) 358 .opType(opName) 359 .forEach(inputNames, (n, iName) -> n.input(iName)) 360 .forEach(attributes.entrySet(), (n, ae) -> n.attribute(attribute(ae.getKey(), ae.getValue()))) 361 .forEach(outputNames, (n, oName) -> n.output(oName)); 362 } 363 364 static NodeProto node(String opName, List<String> inputNames, List<String> outputNames, java.util.Map<String, Object> attributes) { 365 int di = opName.lastIndexOf('.'); 366 return node(di < 0 ? null : opName.substring(0, di), opName.substring(di + 1), inputNames, outputNames, attributes); 367 } 368 369 static ValueInfoProto tensorInfo(String name, int tensorElementType) { 370 return tensorInfo(name, tensorElementType, false); 371 } 372 373 static ValueInfoProto tensorInfo(String name, int tensorElementType, boolean addScalarShape) { 374 var t = new TypeProto.Tensor().elemType(tensorElementType); 375 if (addScalarShape) t.shape(new TensorShapeProto()); 376 return new ValueInfoProto() 377 .name(name) 378 .type(new TypeProto().tensorType(t)); 379 } 380 381 static TensorProto tensorProto(String name, oracle.code.onnx.Tensor tensor, Function<Tensor, ExternalTensorDataInfo> tensorDataExternalizer) { 382 ExternalTensorDataInfo extInfo = tensorDataExternalizer.apply(tensor); 383 TensorProto tp = new TensorProto() 384 .name(name) 385 .dataType(tensor.elementType().id) 386 .dims(tensor.shape()); 387 return extInfo == null 388 ? tp.rawData(tensor.data().toArray(ValueLayout.JAVA_BYTE)) 389 : tp.externalData(new StringStringEntryProto().key("location").value(extInfo.location())) 390 .externalData(new StringStringEntryProto().key("offset").value(String.valueOf(extInfo.offset()))) 391 .externalData(new StringStringEntryProto().key("length").value(String.valueOf(extInfo.length()))) 392 .dataLocation(DataLocation.EXTERNAL); 393 } 394 395 static TensorProto tensorProto(oracle.code.onnx.Tensor tensor) { 396 return new TensorProto() 397 .dataType(tensor.elementType().id) 398 .dims(tensor.shape()) 399 .rawData(tensor.data().toArray(ValueLayout.JAVA_BYTE)); 400 } 401 402 static AttributeProto attribute(String name, Object value) { 403 var attr = new AttributeProto().name(name); 404 switch (value) { 405 case Float f -> { 406 attr.type(AttributeType.FLOAT).f(f); 407 } 408 case Long l -> { 409 attr.type(AttributeType.INT).i(l); 410 } 411 case GraphProto g -> { 412 attr.type(AttributeType.GRAPH).g(g.name(name)); 413 } 414 case float[] floats -> { 415 attr.type(AttributeType.FLOATS); 416 attr.floats(floats); 417 } 418 case long[] longs -> { 419 attr.type(AttributeType.INTS); 420 attr.ints(longs); 421 } 422 case Tensor<?> t -> { 423 attr.type(AttributeType.TENSOR); 424 attr.t(tensorProto(t)); 425 } 426 default -> { 427 throw new UnsupportedOperationException(value.getClass().toString()); // @@@ ToDo 428 } 429 } 430 return attr; 431 } 432 }