/*
 * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation. Oracle designates this
 * particular file as subject to the "Classpath" exception as provided
 * by Oracle in the LICENSE file that accompanied this code.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
24 */ 25 26 package oracle.code.onnx; 27 28 import java.lang.foreign.ValueLayout; 29 import java.util.ArrayList; 30 import java.util.HashMap; 31 import java.util.List; 32 import java.util.Map; 33 import java.util.Optional; 34 import java.util.SequencedMap; 35 import java.util.function.Function; 36 import java.util.stream.IntStream; 37 import jdk.incubator.code.Block; 38 import jdk.incubator.code.CodeItem; 39 import jdk.incubator.code.Op; 40 import jdk.incubator.code.Value; 41 import jdk.incubator.code.dialect.core.CoreOp; 42 import jdk.incubator.code.dialect.core.TupleType; 43 import jdk.incubator.code.dialect.java.JavaOp; 44 import jdk.incubator.code.dialect.java.JavaType; 45 import jdk.incubator.code.extern.OpWriter; 46 import oracle.code.onnx.ir.OnnxOp; 47 import oracle.code.onnx.ir.OnnxOps; 48 import oracle.code.onnx.ir.OnnxType; 49 import oracle.code.onnx.proto.OnnxBuilder.*; 50 import oracle.code.onnx.proto.OnnxConstants.*; 51 52 public final class OnnxProtoBuilder { 53 54 static final int IR_VERSION = 10; 55 static final int OPSET_VERSION = 21; 56 57 private static final class Indexer { 58 59 private final Function<CodeItem, String> baseNames; 60 private final HashMap<String, String> remap; 61 62 63 Indexer(Op root, Map<Value, String> explicitNames) { 64 this.baseNames = OpWriter.computeGlobalNames(root); 65 this.remap = new HashMap<>(); 66 explicitNames.forEach(this::setName); 67 } 68 69 void setName(Value val, String name) { 70 switch (val) { 71 case Op.Result or when or.op() instanceof CoreOp.TupleOp to -> { 72 remap.put(baseName(val), name); 73 for (int i = 0; i < to.operands().size(); i++) { 74 setName(to.operands().get(i), name + "." + i); 75 } 76 } 77 case Block.Parameter bp when val.type() instanceof TupleType tt -> { 78 for (int i = 0; i < tt.componentTypes().size(); i++) { 79 remap.put(baseName(val, i), name +"." 
+ i); 80 } 81 } 82 default -> { 83 remap.put(baseName(val), name); 84 if (val instanceof Op.Result or && or.op() instanceof CoreOp.TupleLoadOp tlo) { 85 Value tr = tlo.operands().getFirst(); 86 remap.put(baseName(tr, tlo.index()), name); 87 if (tr instanceof Op.Result tor && tor.op() instanceof CoreOp.TupleOp to) { 88 setName(to.operands().get(tlo.index()), name); 89 } 90 } 91 } 92 } 93 } 94 95 private String baseName(Value value) { 96 return "%" + baseNames.apply(value); 97 } 98 99 private String baseName(Value value, int elementIndex) { 100 var name = baseName(value); 101 return elementIndex > 0 ? name + '.' + elementIndex : name; 102 } 103 104 String nameOf(Value value) { 105 var name = baseName(value); 106 return remap.getOrDefault(name, name); 107 } 108 109 String nameOf(Value tuple, int elementIndex) { 110 var name = baseName(tuple, elementIndex); 111 return remap.getOrDefault(name, name); 112 } 113 114 void mapTupleLoad(Value tupleLoadResult, Value tuple, int elementIndex) { 115 remap.putIfAbsent(baseName(tupleLoadResult), nameOf(tuple, elementIndex)); 116 } 117 118 void mapTupleElements(Value tuple, List<Value> elements) { 119 for (int i = 0; i < elements.size(); i++) { 120 remap.putIfAbsent(baseName(tuple, i), nameOf(elements.get(i))); 121 } 122 } 123 } 124 125 public static byte[] buildModel(String domain, CoreOp.ModuleOp module, List<Object> initializers) { 126 return buildModel(domain, module, initializers, Map.of(), _ -> null); 127 } 128 129 public record ExternalTensorDataInfo(String location, long offset, long length) { 130 } 131 132 public static byte[] buildModel(String domain, CoreOp.ModuleOp module, List<Object> initializers, Map<Value, String> explicitValueNames, Function<Tensor, ExternalTensorDataInfo> tensorDataExternalizer) { 133 var indexer = new Indexer(module, explicitValueNames); 134 135 var functions = new ArrayList<>(module.functionTable().sequencedValues()); 136 var imports = new ArrayList<String>(); 137 if (functions.size() > 1) 
imports.add(domain); // self domain import if additional functions 138 for (var f : functions) { 139 for (var op : f.body().entryBlock().ops()) { // auto import of op domains 140 if (op instanceof OnnxOp) { 141 int di = op.opName().lastIndexOf('.'); 142 if (di > 0) { 143 String dn = op.opName().substring(0, di); 144 if (!imports.contains(dn)) imports.add(dn); 145 } 146 } 147 } 148 } 149 var mainFunc = functions.removeLast(); 150 var mainBlock = mainFunc.body().entryBlock(); 151 152 var model = buildModel( 153 graph(domain, mainFunc.funcName(), indexer, mainBlock, initializers, 0, tensorDataExternalizer), 154 imports, 155 functions.stream().map(f -> 156 function(domain, imports, f.funcName(), 157 expandTuples(indexer, f.parameters()), 158 expandTuples(indexer, f.body().entryBlock().terminatingOp().operands()), 159 nodes(domain, indexer, f.body().entryBlock().ops()))).toList()); 160 161 // System.out.println(OnnxModel.readFrom(model).toText()); 162 return model; 163 } 164 165 // @@@ unchecked constraints: 166 // tensor FuncOp parameters and single tensor return type 167 // OnnxOps (with tensor operands and single tensor return value) and ReturnOp (returning single tensor) 168 // entry block only 169 static byte[] buildModel(Block block, List<oracle.code.onnx.Tensor> initializers) { 170 var indexer = new Indexer(block.parentBody().parentOp(), Map.of()); 171 var model = buildModel(graph(null, null, indexer, block, initializers, 0), List.of(), List.of()); 172 // System.out.println(OnnxModel.readFrom(model).toText()); 173 return model; 174 } 175 176 static byte[] buildModel(List<TensorProto> initializers, List<ValueInfoProto> inputs, List<NodeProto> ops, List<String> outputNames) { 177 return buildModel(graph(null, initializers, inputs, ops, outputNames), List.of(), List.of()); 178 } 179 180 static byte[] buildModel(List<TensorProto> initializers, List<ValueInfoProto> inputs, List<NodeProto> ops, List<String> outputNames, List<String> customImportDomains, 
List<FunctionProto> functions) { 181 return buildModel(graph(null, initializers, inputs, ops, outputNames), customImportDomains, functions); 182 } 183 184 static byte[] buildModel(GraphProto graph, List<String> imports, List<FunctionProto> functions) { 185 return new ModelProto() 186 .irVersion(IR_VERSION) 187 .opsetImport(new OperatorSetIdProto().version(OPSET_VERSION)) 188 .forEach(imports, (m, d) -> m.opsetImport(new OperatorSetIdProto().domain(d).version(1))) 189 .forEach(functions, (m, f) -> m.functions(f)) 190 .graph(graph) 191 .getBytes(); 192 } 193 194 static List<String> expandTuples(Indexer indexer, List<? extends Value> values) { 195 var names = new ArrayList<String>(); 196 expandTuples(indexer, names, values); 197 return names; 198 } 199 200 static void expandTuples(Indexer indexer, List<String> names, List<? extends Value> values) { 201 for (var v : values) { 202 if (v instanceof Op.Result or && or.op() instanceof CoreOp.TupleOp op) { 203 expandTuples(indexer, names, op.operands()); 204 } else if (v.type() instanceof TupleType tt) { 205 var ct = tt.componentTypes(); 206 for (int i = 0; i < ct.size(); i++) { 207 names.add(indexer.nameOf(v, i)); 208 } 209 } else { 210 names.add(indexer.nameOf(v)); 211 } 212 } 213 } 214 215 static GraphProto graph(String domain, String graphName, Indexer indexer, Block block, List<? extends Object> initializers, int scalarArgs) { 216 return graph(domain, graphName, indexer, block, initializers, scalarArgs, _ -> null); 217 } 218 219 static GraphProto graph(String domain, String graphName, Indexer indexer, Block block, List<? 
extends Object> initializers, int scalarArgs, Function<Tensor, ExternalTensorDataInfo> tensorDataExternalizer) { 220 var params = block.parameters(); 221 params.forEach(indexer::nameOf); 222 int firstInitializer = params.size() - initializers.size(); 223 var args = params.subList(0, firstInitializer); 224 return graph(graphName, 225 IntStream.range(0, initializers.size()).boxed().<TensorProto>mapMulti((i, tps) -> { 226 Object val = initializers.get(i); 227 if (val instanceof Record) { 228 var rcs = val.getClass().getRecordComponents(); 229 for (int rci = 0; rci < rcs.length; rci++) try { 230 tps.accept(tensorProto(indexer.nameOf(params.get(i + firstInitializer), rci), (Tensor)(rcs[rci].getAccessor().invoke(val)), tensorDataExternalizer)); 231 } catch (ReflectiveOperationException e) { 232 throw new IllegalArgumentException(e); 233 } 234 } else if (val instanceof Tensor[] tarr) { 235 for (int tai = 0; tai < tarr.length; tai++) { 236 tps.accept(tensorProto(indexer.nameOf(params.get(i + firstInitializer), tai), tarr[tai], tensorDataExternalizer)); 237 } 238 } else { 239 tps.accept(tensorProto(indexer.nameOf(params.get(i + firstInitializer)), (Tensor)val, tensorDataExternalizer)); 240 } 241 }).toList(), 242 tensorInfos(indexer, args, scalarArgs), 243 nodes(domain, indexer, block.ops()), 244 expandTuples(indexer, block.terminatingOp().operands())); 245 } 246 247 static List<String> opInputNames(Indexer indexer, SequencedMap<OnnxOp.OnnxParameter, Object> inputs) { 248 List<String> inputNames = inputs.sequencedValues().stream() 249 .<String>mapMulti((v, dump) -> { 250 switch (v) { 251 case Value val -> dump.accept(indexer.nameOf(val)); 252 case java.util.Optional<?> o when o.isPresent() && o.get() instanceof Value val -> dump.accept(indexer.nameOf(val)); 253 case List l -> l.forEach(val -> dump.accept(indexer.nameOf((Value)val))); 254 default -> dump.accept(""); // empty names for unused optional inputs 255 } 256 }).toList(); 257 // trim trailing empty names 258 return 
inputNames.reversed().stream().dropWhile(String::isEmpty).toList().reversed(); 259 } 260 261 static List<NodeProto> nodes(String domain, Indexer indexer, List<Op> ops) { 262 return ops.stream().<NodeProto>mapMulti((op, opNodes) -> { 263 switch (op) { 264 case OnnxOps.If ifOp -> 265 opNodes.accept(node( 266 ifOp.opName(), 267 List.of(indexer.nameOf(ifOp.operands().getFirst())), 268 IntStream.range(0, ifOp.resultType() instanceof TupleType tt ? tt.componentTypes().size() : 1).mapToObj(o -> indexer.nameOf(ifOp.result(), o)).toList(), 269 java.util.Map.of( 270 "then_branch", graph(domain, null, indexer, ifOp.thenBranch().entryBlock(), List.of(), 0), 271 "else_branch", graph(domain, null, indexer, ifOp.elseBranch().entryBlock(), List.of(), 0)))); 272 case OnnxOps.Loop loopOp -> { 273 opNodes.accept(node(loopOp.opName(), 274 expandTuples(indexer, loopOp.operands()), 275 IntStream.range(0, loopOp.resultType() instanceof TupleType tt ? tt.componentTypes().size() : 1).mapToObj(o -> indexer.nameOf(loopOp.result(), o)).toList(), 276 java.util.Map.of( 277 "body", graph(domain, null, indexer, loopOp.loopBody().entryBlock(), List.of(), 2)))); 278 } 279 case OnnxOp onnxOp -> 280 opNodes.accept(node( 281 onnxOp.opName(), 282 opInputNames(indexer, onnxOp.onnxInputs()), 283 IntStream.range(0, onnxOp.onnxOutputs().size()).mapToObj(o -> indexer.nameOf(onnxOp.result(), o)).toList(), 284 onnxOp.onnxAttributes())); 285 case CoreOp.FuncCallOp fco -> 286 opNodes.accept(node( 287 domain, 288 fco.funcName(), 289 expandTuples(indexer, fco.operands()), 290 expandTuples(indexer, List.of(fco.result())), 291 java.util.Map.of())); 292 case CoreOp.ReturnOp _, CoreOp.ConstantOp _ -> { // skip 293 } 294 case CoreOp.TupleLoadOp tlo -> 295 indexer.mapTupleLoad(tlo.result(), tlo.operands().getFirst(), tlo.index()); 296 case CoreOp.TupleOp to -> 297 indexer.mapTupleElements(to.result(), to.operands()); 298 case JavaOp.InvokeOp io when io.invokeDescriptor().refType().equals(JavaType.type(List.class)) -> { 
299 if (io.invokeDescriptor().name().equals("get") && io.operands().getLast() instanceof Op.Result or && or.op() instanceof CoreOp.ConstantOp co && co.value() instanceof Integer i) { 300 indexer.mapTupleLoad(io.result(), io.operands().getFirst(), i); 301 } else if (io.invokeDescriptor().name().equals("of")) { 302 indexer.mapTupleElements(io.result(), io.operands()); 303 } else { 304 throw new UnsupportedOperationException(op.toText()); 305 } 306 } 307 default -> { 308 throw new UnsupportedOperationException(op.toText()); 309 } 310 } 311 }).toList(); 312 } 313 314 static List<ValueInfoProto> tensorInfos(Indexer indexer, List<Block.Parameter> args, int scalarArgs) { 315 var infos = new ArrayList<ValueInfoProto>(); 316 for (var arg : args) { 317 switch (arg.type()) { 318 case OnnxType.TensorType tt -> 319 infos.add(tensorInfo(indexer.nameOf(arg), tt.eType().id(), infos.size() < scalarArgs)); 320 case TupleType tt -> { 321 var ct = tt.componentTypes(); 322 for (int i = 0; i < ct.size(); i++) { 323 infos.add(tensorInfo(indexer.nameOf(arg, i), ((OnnxType.TensorType)ct.get(i)).eType().id(), infos.size() < scalarArgs)); 324 } 325 } 326 default -> 327 throw new UnsupportedOperationException(arg.type().toString()); 328 } 329 } 330 return infos; 331 } 332 333 static GraphProto graph(String name, List<TensorProto> initializers, List<ValueInfoProto> inputs, List<NodeProto> ops, List<String> outputNames) { 334 return new GraphProto() 335 .name(name) 336 .forEach(initializers, (g, i) -> g.initializer(i)) 337 .forEach(inputs, (g, i) -> g.input(i)) 338 .forEach(ops, (g, op) -> g.node(op)) 339 .forEach(outputNames, (g, oName) -> g.output(new ValueInfoProto().name(oName))); 340 } 341 342 static FunctionProto function(String functionDomain, List<String> imports, String functionName, List<String> inputNames, List<String> outputNames, List<NodeProto> ops) { 343 int di = functionName.lastIndexOf('.'); 344 return new FunctionProto() 345 .domain(functionDomain) 346 .name(functionName) 347 
.forEach(inputNames, (f, i) -> f.input(i)) 348 .forEach(ops, (g, op) -> g.node(op)) 349 .forEach(outputNames, (f, o) -> f.output(o)) 350 .opsetImport(new OperatorSetIdProto().version(OPSET_VERSION)) 351 .forEach(imports, (f, d) -> f.opsetImport(new OperatorSetIdProto().domain(d).version(1))); 352 } 353 354 static NodeProto node(String domain, String opName, List<String> inputNames, List<String> outputNames, java.util.Map<String, Object> attributes) { 355 return new NodeProto() 356 .domain(domain) 357 .opType(opName) 358 .forEach(inputNames, (n, iName) -> n.input(iName)) 359 .forEach(attributes.entrySet(), (n, ae) -> n.attribute(attribute(ae.getKey(), ae.getValue()))) 360 .forEach(outputNames, (n, oName) -> n.output(oName)); 361 } 362 363 static NodeProto node(String opName, List<String> inputNames, List<String> outputNames, java.util.Map<String, Object> attributes) { 364 int di = opName.lastIndexOf('.'); 365 return node(di < 0 ? null : opName.substring(0, di), opName.substring(di + 1), inputNames, outputNames, attributes); 366 } 367 368 static ValueInfoProto tensorInfo(String name, int tensorElementType) { 369 return tensorInfo(name, tensorElementType, false); 370 } 371 372 static ValueInfoProto tensorInfo(String name, int tensorElementType, boolean addScalarShape) { 373 var t = new TypeProto.Tensor().elemType(tensorElementType); 374 if (addScalarShape) t.shape(new TensorShapeProto()); 375 return new ValueInfoProto() 376 .name(name) 377 .type(new TypeProto().tensorType(t)); 378 } 379 380 static TensorProto tensorProto(String name, oracle.code.onnx.Tensor tensor, Function<Tensor, ExternalTensorDataInfo> tensorDataExternalizer) { 381 ExternalTensorDataInfo extInfo = tensorDataExternalizer.apply(tensor); 382 TensorProto tp = new TensorProto() 383 .name(name) 384 .dataType(tensor.elementType().id) 385 .dims(tensor.shape()); 386 return extInfo == null 387 ? 
tp.rawData(tensor.data().toArray(ValueLayout.JAVA_BYTE)) 388 : tp.externalData(new StringStringEntryProto().key("location").value(extInfo.location())) 389 .externalData(new StringStringEntryProto().key("offset").value(String.valueOf(extInfo.offset()))) 390 .externalData(new StringStringEntryProto().key("length").value(String.valueOf(extInfo.length()))) 391 .dataLocation(DataLocation.EXTERNAL); 392 } 393 394 static TensorProto tensorProto(oracle.code.onnx.Tensor tensor) { 395 return new TensorProto() 396 .dataType(tensor.elementType().id) 397 .dims(tensor.shape()) 398 .rawData(tensor.data().toArray(ValueLayout.JAVA_BYTE)); 399 } 400 401 static AttributeProto attribute(String name, Object value) { 402 var attr = new AttributeProto().name(name); 403 switch (value) { 404 case Float f -> { 405 attr.type(AttributeType.FLOAT).f(f); 406 } 407 case Long l -> { 408 attr.type(AttributeType.INT).i(l); 409 } 410 case GraphProto g -> { 411 attr.type(AttributeType.GRAPH).g(g.name(name)); 412 } 413 case float[] floats -> { 414 attr.type(AttributeType.FLOATS); 415 attr.floats(floats); 416 } 417 case long[] longs -> { 418 attr.type(AttributeType.INTS); 419 attr.ints(longs); 420 } 421 case Tensor<?> t -> { 422 attr.type(AttributeType.TENSOR); 423 attr.t(tensorProto(t)); 424 } 425 default -> { 426 throw new UnsupportedOperationException(value.getClass().toString()); // @@@ ToDo 427 } 428 } 429 return attr; 430 } 431 }