1 /* 2 * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. Oracle designates this 8 * particular file as subject to the "Classpath" exception as provided 9 * by Oracle in the LICENSE file that accompanied this code. 10 * 11 * This code is distributed in the hope that it will be useful, but WITHOUT 12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 14 * version 2 for more details (a copy is included in the LICENSE file that 15 * accompanied this code). 16 * 17 * You should have received a copy of the GNU General Public License version 18 * 2 along with this work; if not, write to the Free Software Foundation, 19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 20 * 21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 22 * or visit www.oracle.com if you need additional information or have any 23 * questions. 24 */ 25 26 package oracle.code.onnx; 27 28 import java.lang.foreign.ValueLayout; 29 import java.util.ArrayList; 30 import java.util.HashMap; 31 import java.util.List; 32 import java.util.Map; 33 import java.util.Optional; 34 import java.util.SequencedMap; 35 import java.util.function.Function; 36 import java.util.stream.IntStream; 37 import jdk.incubator.code.Block; 38 import jdk.incubator.code.CodeItem; 39 import jdk.incubator.code.Op; 40 import jdk.incubator.code.Value; 41 import jdk.incubator.code.op.CoreOp; 42 import jdk.incubator.code.type.JavaType; 43 import jdk.incubator.code.type.TupleType; 44 import jdk.incubator.code.writer.OpWriter; 45 import oracle.code.onnx.ir.OnnxOp; 46 import oracle.code.onnx.ir.OnnxOps; 47 import oracle.code.onnx.ir.OnnxType; 48 import oracle.code.onnx.proto.OnnxBuilder.*; 49 import oracle.code.onnx.proto.OnnxConstants.*; 50 import oracle.code.onnx.proto.OnnxModel; 51 52 public final class OnnxProtoBuilder { 53 54 static final int IR_VERSION = 10; 55 static final int OPSET_VERSION = 21; 56 57 private static final class Indexer { 58 59 private final Function<CodeItem, String> baseNames; 60 private final HashMap<String, String> remap; 61 62 63 Indexer(Op root, Map<Value, String> explicitNames) { 64 this.baseNames = OpWriter.computeGlobalNames(root); 65 this.remap = new HashMap<>(); 66 explicitNames.forEach(this::setName); 67 } 68 69 void setName(Value val, String name) { 70 remap.put(baseName(val), name); 71 if (val instanceof Op.Result or && or.op() instanceof CoreOp.TupleLoadOp tlo) { 72 Value tr = tlo.operands().getFirst(); 73 remap.put(baseName(tr, tlo.index()), name); 74 if (tr instanceof Op.Result tor && tor.op() instanceof CoreOp.TupleOp to) { 75 setName(to.operands().get(tlo.index()), name); 76 } 77 } 78 } 79 80 private String baseName(Value value) { 81 return "%" + baseNames.apply(value); 82 } 83 84 private String baseName(Value value, int elementIndex) { 85 var name = baseName(value); 86 return elementIndex > 0 ? name + '.' + elementIndex : name; 87 } 88 89 String nameOf(Value value) { 90 var name = baseName(value); 91 return remap.getOrDefault(name, name); 92 } 93 94 String nameOf(Value tuple, int elementIndex) { 95 var name = baseName(tuple, elementIndex); 96 return remap.getOrDefault(name, name); 97 } 98 99 void mapTupleLoad(Value tupleLoadResult, Value tuple, int elementIndex) { 100 remap.putIfAbsent(baseName(tupleLoadResult), nameOf(tuple, elementIndex)); 101 } 102 103 void mapTupleElements(Value tuple, List<Value> elements) { 104 for (int i = 0; i < elements.size(); i++) { 105 remap.putIfAbsent(baseName(tuple, i), nameOf(elements.get(i))); 106 } 107 } 108 } 109 110 public static byte[] buildModel(String domain, CoreOp.ModuleOp module, List<Object> initializers) { 111 return buildModel(domain, module, initializers, Map.of(), _ -> null); 112 } 113 114 public record ExternalTensorDataInfo(String location, long offset, long length) { 115 } 116 117 public static byte[] buildModel(String domain, CoreOp.ModuleOp module, List<Object> initializers, Map<Value, String> explicitValueNames, Function<Tensor, ExternalTensorDataInfo> tensorDataExternalizer) { 118 var indexer = new Indexer(module, explicitValueNames); 119 120 var functions = new ArrayList<>(module.functionTable().sequencedValues()); 121 var imports = new ArrayList<String>(); 122 if (functions.size() > 1) imports.add(domain); // self domain import if additional functions 123 for (var f : functions) { 124 for (var op : f.body().entryBlock().ops()) { // auto import of op domains 125 if (op instanceof OnnxOp) { 126 int di = op.opName().lastIndexOf('.'); 127 if (di > 0) { 128 String dn = op.opName().substring(0, di); 129 if (!imports.contains(dn)) imports.add(dn); 130 } 131 } 132 } 133 } 134 var mainFunc = functions.removeLast(); 135 var mainBlock = mainFunc.body().entryBlock(); 136 137 var model = buildModel( 138 graph(domain, mainFunc.funcName(), indexer, mainBlock, initializers, 0, tensorDataExternalizer), 139 imports, 140 functions.stream().map(f -> 141 function(domain, imports, f.funcName(), 142 expandTuples(indexer, f.parameters()), 143 expandTuples(indexer, f.body().entryBlock().terminatingOp().operands()), 144 nodes(domain, indexer, f.body().entryBlock().ops()))).toList()); 145 146 // System.out.println(OnnxModel.readFrom(model).toText()); 147 return model; 148 } 149 150 // @@@ unchecked constraints: 151 // tensor FuncOp parameters and single tensor return type 152 // OnnxOps (with tensor operands and single tensor return value) and ReturnOp (returning single tensor) 153 // entry block only 154 static byte[] buildModel(Block block, List<oracle.code.onnx.Tensor> initializers) { 155 var indexer = new Indexer(block.parentBody().parentOp(), Map.of()); 156 var model = buildModel(graph(null, null, indexer, block, initializers, 0), List.of(), List.of()); 157 // System.out.println(OnnxModel.readFrom(model).toText()); 158 return model; 159 } 160 161 static byte[] buildModel(List<TensorProto> initializers, List<ValueInfoProto> inputs, List<NodeProto> ops, List<String> outputNames) { 162 return buildModel(graph(null, initializers, inputs, ops, outputNames), List.of(), List.of()); 163 } 164 165 static byte[] buildModel(List<TensorProto> initializers, List<ValueInfoProto> inputs, List<NodeProto> ops, List<String> outputNames, List<String> customImportDomains, List<FunctionProto> functions) { 166 return buildModel(graph(null, initializers, inputs, ops, outputNames), customImportDomains, functions); 167 } 168 169 static byte[] buildModel(GraphProto graph, List<String> imports, List<FunctionProto> functions) { 170 return new ModelProto() 171 .irVersion(IR_VERSION) 172 .opsetImport(new OperatorSetIdProto().version(OPSET_VERSION)) 173 .forEach(imports, (m, d) -> m.opsetImport(new OperatorSetIdProto().domain(d).version(1))) 174 .forEach(functions, (m, f) -> m.functions(f)) 175 .graph(graph) 176 .getBytes(); 177 } 178 179 static List<String> expandTuples(Indexer indexer, List<? extends Value> values) { 180 var names = new ArrayList<String>(); 181 expandTuples(indexer, names, values); 182 return names; 183 } 184 185 static void expandTuples(Indexer indexer, List<String> names, List<? extends Value> values) { 186 for (var v : values) { 187 if (v instanceof Op.Result or && or.op() instanceof CoreOp.TupleOp op) { 188 expandTuples(indexer, names, op.operands()); 189 } else if (v.type() instanceof TupleType tt) { 190 var ct = tt.componentTypes(); 191 for (int i = 0; i < ct.size(); i++) { 192 names.add(indexer.nameOf(v, i)); 193 } 194 } else { 195 names.add(indexer.nameOf(v)); 196 } 197 } 198 } 199 200 static GraphProto graph(String domain, String graphName, Indexer indexer, Block block, List<? extends Object> initializers, int scalarArgs) { 201 return graph(domain, graphName, indexer, block, initializers, scalarArgs, _ -> null); 202 } 203 204 static GraphProto graph(String domain, String graphName, Indexer indexer, Block block, List<? extends Object> initializers, int scalarArgs, Function<Tensor, ExternalTensorDataInfo> tensorDataExternalizer) { 205 var params = block.parameters(); 206 params.forEach(indexer::nameOf); 207 int firstInitializer = params.size() - initializers.size(); 208 var args = params.subList(0, firstInitializer); 209 return graph(graphName, 210 IntStream.range(0, initializers.size()).boxed().<TensorProto>mapMulti((i, tps) -> { 211 Object val = initializers.get(i); 212 if (val instanceof Record) { 213 var rcs = val.getClass().getRecordComponents(); 214 for (int rci = 0; rci < rcs.length; rci++) try { 215 tps.accept(tensorProto(indexer.nameOf(params.get(i + firstInitializer), rci), (Tensor)(rcs[rci].getAccessor().invoke(val)), tensorDataExternalizer)); 216 } catch (ReflectiveOperationException e) { 217 throw new IllegalArgumentException(e); 218 } 219 } else { 220 tps.accept(tensorProto(indexer.nameOf(params.get(i + firstInitializer)), (Tensor)val, tensorDataExternalizer)); 221 } 222 }).toList(), 223 tensorInfos(indexer, args, scalarArgs), 224 nodes(domain, indexer, block.ops()), 225 expandTuples(indexer, block.terminatingOp().operands())); 226 } 227 228 static List<String> opInputNames(Indexer indexer, SequencedMap<OnnxOp.OnnxParameter, Object> inputs) { 229 List<String> inputNames = inputs.sequencedValues().stream() 230 .<String>mapMulti((v, dump) -> { 231 switch (v) { 232 case Value val -> dump.accept(indexer.nameOf(val)); 233 case java.util.Optional<?> o when o.isPresent() && o.get() instanceof Value val -> dump.accept(indexer.nameOf(val)); 234 case List l -> l.forEach(val -> dump.accept(indexer.nameOf((Value)val))); 235 default -> dump.accept(""); // empty names for unused optional inputs 236 } 237 }).toList(); 238 // trim trailing empty names 239 return inputNames.reversed().stream().dropWhile(String::isEmpty).toList().reversed(); 240 } 241 242 static List<NodeProto> nodes(String domain, Indexer indexer, List<Op> ops) { 243 return ops.stream().<NodeProto>mapMulti((op, opNodes) -> { 244 switch (op) { 245 case OnnxOps.If ifOp -> 246 opNodes.accept(node( 247 ifOp.opName(), 248 List.of(indexer.nameOf(ifOp.operands().getFirst())), 249 IntStream.range(0, ifOp.resultType() instanceof TupleType tt ? tt.componentTypes().size() : 1).mapToObj(o -> indexer.nameOf(ifOp.result(), o)).toList(), 250 java.util.Map.of( 251 "then_branch", graph(domain, null, indexer, ifOp.thenBranch().entryBlock(), List.of(), 0), 252 "else_branch", graph(domain, null, indexer, ifOp.elseBranch().entryBlock(), List.of(), 0)))); 253 case OnnxOps.Loop loopOp -> { 254 opNodes.accept(node(loopOp.opName(), 255 expandTuples(indexer, loopOp.operands()), 256 IntStream.range(0, loopOp.resultType() instanceof TupleType tt ? tt.componentTypes().size() : 1).mapToObj(o -> indexer.nameOf(loopOp.result(), o)).toList(), 257 java.util.Map.of( 258 "body", graph(domain, null, indexer, loopOp.loopBody().entryBlock(), List.of(), 2)))); 259 } 260 case OnnxOp onnxOp -> 261 opNodes.accept(node( 262 onnxOp.opName(), 263 opInputNames(indexer, onnxOp.onnxInputs()), 264 IntStream.range(0, onnxOp.onnxOutputs().size()).mapToObj(o -> indexer.nameOf(onnxOp.result(), o)).toList(), 265 onnxOp.onnxAttributes())); 266 case CoreOp.FuncCallOp fco -> 267 opNodes.accept(node( 268 domain, 269 fco.funcName(), 270 expandTuples(indexer, fco.operands()), 271 expandTuples(indexer, List.of(fco.result())), 272 java.util.Map.of())); 273 case CoreOp.ReturnOp _, CoreOp.ConstantOp _ -> { // skip 274 } 275 case CoreOp.TupleLoadOp tlo -> 276 indexer.mapTupleLoad(tlo.result(), tlo.operands().getFirst(), tlo.index()); 277 case CoreOp.TupleOp to -> 278 indexer.mapTupleElements(to.result(), to.operands()); 279 case CoreOp.InvokeOp io when io.invokeDescriptor().refType().equals(JavaType.type(List.class)) -> { 280 if (io.invokeDescriptor().name().equals("get") && io.operands().getLast() instanceof Op.Result or && or.op() instanceof CoreOp.ConstantOp co && co.value() instanceof Integer i) { 281 indexer.mapTupleLoad(io.result(), io.operands().getFirst(), i); 282 } else if (io.invokeDescriptor().name().equals("of")) { 283 indexer.mapTupleElements(io.result(), io.operands()); 284 } else { 285 throw new UnsupportedOperationException(op.toText()); 286 } 287 } 288 default -> { 289 throw new UnsupportedOperationException(op.toText()); 290 } 291 } 292 }).toList(); 293 } 294 295 static List<ValueInfoProto> tensorInfos(Indexer indexer, List<Block.Parameter> args, int scalarArgs) { 296 var infos = new ArrayList<ValueInfoProto>(); 297 for (var arg : args) { 298 switch (arg.type()) { 299 case OnnxType.TensorType tt -> 300 infos.add(tensorInfo(indexer.nameOf(arg), tt.eType().id(), infos.size() < scalarArgs)); 301 case TupleType tt -> { 302 var ct = tt.componentTypes(); 303 for (int i = 0; i < ct.size(); i++) { 304 infos.add(tensorInfo(indexer.nameOf(arg, i), ((OnnxType.TensorType)ct.get(i)).eType().id(), infos.size() < scalarArgs)); 305 } 306 } 307 default -> 308 throw new UnsupportedOperationException(arg.type().toString()); 309 } 310 } 311 return infos; 312 } 313 314 static GraphProto graph(String name, List<TensorProto> initializers, List<ValueInfoProto> inputs, List<NodeProto> ops, List<String> outputNames) { 315 return new GraphProto() 316 .name(name) 317 .forEach(initializers, (g, i) -> g.initializer(i)) 318 .forEach(inputs, (g, i) -> g.input(i)) 319 .forEach(ops, (g, op) -> g.node(op)) 320 .forEach(outputNames, (g, oName) -> g.output(new ValueInfoProto().name(oName))); 321 } 322 323 static FunctionProto function(String functionDomain, List<String> imports, String functionName, List<String> inputNames, List<String> outputNames, List<NodeProto> ops) { 324 int di = functionName.lastIndexOf('.'); 325 return new FunctionProto() 326 .domain(functionDomain) 327 .name(functionName) 328 .forEach(inputNames, (f, i) -> f.input(i)) 329 .forEach(ops, (g, op) -> g.node(op)) 330 .forEach(outputNames, (f, o) -> f.output(o)) 331 .opsetImport(new OperatorSetIdProto().version(OPSET_VERSION)) 332 .forEach(imports, (f, d) -> f.opsetImport(new OperatorSetIdProto().domain(d).version(1))); 333 } 334 335 static NodeProto node(String domain, String opName, List<String> inputNames, List<String> outputNames, java.util.Map<String, Object> attributes) { 336 return new NodeProto() 337 .domain(domain) 338 .opType(opName) 339 .forEach(inputNames, (n, iName) -> n.input(iName)) 340 .forEach(attributes.entrySet(), (n, ae) -> n.attribute(attribute(ae.getKey(), ae.getValue()))) 341 .forEach(outputNames, (n, oName) -> n.output(oName)); 342 } 343 344 static NodeProto node(String opName, List<String> inputNames, List<String> outputNames, java.util.Map<String, Object> attributes) { 345 int di = opName.lastIndexOf('.'); 346 return node(di < 0 ? null : opName.substring(0, di), opName.substring(di + 1), inputNames, outputNames, attributes); 347 } 348 349 static ValueInfoProto tensorInfo(String name, int tensorElementType) { 350 return tensorInfo(name, tensorElementType, false); 351 } 352 353 static ValueInfoProto tensorInfo(String name, int tensorElementType, boolean addScalarShape) { 354 var t = new TypeProto.Tensor().elemType(tensorElementType); 355 if (addScalarShape) t.shape(new TensorShapeProto()); 356 return new ValueInfoProto() 357 .name(name) 358 .type(new TypeProto().tensorType(t)); 359 } 360 361 static TensorProto tensorProto(String name, oracle.code.onnx.Tensor tensor, Function<Tensor, ExternalTensorDataInfo> tensorDataExternalizer) { 362 ExternalTensorDataInfo extInfo = tensorDataExternalizer.apply(tensor); 363 TensorProto tp = new TensorProto() 364 .name(name) 365 .dataType(tensor.elementType().id) 366 .dims(tensor.shape()); 367 return extInfo == null 368 ? tp.rawData(tensor.data().toArray(ValueLayout.JAVA_BYTE)) 369 : tp.externalData(new StringStringEntryProto().key("location").value(extInfo.location())) 370 .externalData(new StringStringEntryProto().key("offset").value(String.valueOf(extInfo.offset()))) 371 .externalData(new StringStringEntryProto().key("length").value(String.valueOf(extInfo.length()))) 372 .dataLocation(DataLocation.EXTERNAL); 373 } 374 375 static TensorProto tensorProto(oracle.code.onnx.Tensor tensor) { 376 return new TensorProto() 377 .dataType(tensor.elementType().id) 378 .dims(tensor.shape()) 379 .rawData(tensor.data().toArray(ValueLayout.JAVA_BYTE)); 380 } 381 382 static AttributeProto attribute(String name, Object value) { 383 var attr = new AttributeProto().name(name); 384 switch (value) { 385 case Float f -> { 386 attr.type(AttributeType.FLOAT).f(f); 387 } 388 case Long l -> { 389 attr.type(AttributeType.INT).i(l); 390 } 391 case GraphProto g -> { 392 attr.type(AttributeType.GRAPH).g(g.name(name)); 393 } 394 case float[] floats -> { 395 attr.type(AttributeType.FLOATS); 396 attr.floats(floats); 397 } 398 case long[] longs -> { 399 attr.type(AttributeType.INTS); 400 attr.ints(longs); 401 } 402 case Tensor<?> t -> { 403 attr.type(AttributeType.TENSOR); 404 attr.t(tensorProto(t)); 405 } 406 default -> { 407 throw new UnsupportedOperationException(value.getClass().toString()); // @@@ ToDo 408 } 409 } 410 return attr; 411 } 412 }