1 /* 2 * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. Oracle designates this 8 * particular file as subject to the "Classpath" exception as provided 9 * by Oracle in the LICENSE file that accompanied this code. 10 * 11 * This code is distributed in the hope that it will be useful, but WITHOUT 12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 14 * version 2 for more details (a copy is included in the LICENSE file that 15 * accompanied this code). 16 * 17 * You should have received a copy of the GNU General Public License version 18 * 2 along with this work; if not, write to the Free Software Foundation, 19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 20 * 21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 22 * or visit www.oracle.com if you need additional information or have any 23 * questions. 24 */ 25 26 package oracle.code.onnx; 27 28 import java.lang.foreign.ValueLayout; 29 import java.util.ArrayList; 30 import java.util.HashMap; 31 import java.util.List; 32 import java.util.function.Function; 33 import java.util.stream.IntStream; 34 import jdk.incubator.code.Block; 35 import jdk.incubator.code.CodeItem; 36 import jdk.incubator.code.Op; 37 import jdk.incubator.code.Value; 38 import jdk.incubator.code.op.CoreOp; 39 import jdk.incubator.code.type.JavaType; 40 import jdk.incubator.code.type.TupleType; 41 import jdk.incubator.code.writer.OpWriter; 42 import oracle.code.onnx.ir.OnnxOp; 43 import oracle.code.onnx.ir.OnnxOps; 44 import oracle.code.onnx.ir.OnnxType; 45 import oracle.code.onnx.proto.OnnxBuilder.*; 46 import oracle.code.onnx.proto.OnnxConstants.*; 47 48 public final class OnnxProtoBuilder { 49 50 static final int IR_VERSION = 10; 51 static final int OPSET_VERSION = 21; 52 53 private static final class Indexer { 54 55 private final Function<CodeItem, String> baseNames; 56 private final HashMap<String, String> elementsMap; 57 58 59 Indexer(Function<CodeItem, String> baseNames) { 60 this.baseNames = baseNames; 61 this.elementsMap = new HashMap<>(); 62 } 63 64 private String baseName(Value value, int elementIndex) { 65 var name = "%" + baseNames.apply(value); 66 return elementIndex > 0 ? name + '.' + elementIndex : name; 67 } 68 69 String nameOf(Value value) { 70 return nameOf(value, 0); 71 } 72 73 String nameOf(Value tuple, int elementIndex) { 74 var name = baseName(tuple, elementIndex); 75 return elementsMap.getOrDefault(name, name); 76 } 77 78 void mapTupleLoad(Value tupleLoadResult, Value tuple, int elementIndex) { 79 elementsMap.put(baseName(tupleLoadResult, 0), nameOf(tuple, elementIndex)); 80 } 81 82 void mapTupleElements(Value tuple, List<Value> elements) { 83 for (int i = 0; i < elements.size(); i++) { 84 elementsMap.put(baseName(tuple, i), nameOf(elements.get(i))); 85 } 86 } 87 } 88 89 static byte[] build(String domainName, CoreOp.ModuleOp module, List<oracle.code.onnx.Tensor> initializers) { 90 var indexer = new Indexer(OpWriter.computeGlobalNames(module)); 91 92 var functions = new ArrayList<>(module.functionTable().sequencedValues()); 93 var mainFunc = functions.removeLast(); 94 var mainBlock = mainFunc.body().entryBlock(); 95 96 var model = build( 97 graph(mainFunc.funcName(), domainName, indexer, mainBlock, initializers, 0), 98 List.of(domainName), 99 functions.stream().map(f -> 100 function(domainName, 101 f.funcName(), 102 f.parameters().stream().map(indexer::nameOf).toList(), 103 expandTuples(indexer, f.body().entryBlock().terminatingOp().operands()), 104 nodes(domainName, indexer, f.body().entryBlock().ops()))).toList()); 105 106 // OnnxProtoPrinter.printModel(model); 107 return model; 108 } 109 110 // @@@ unchecked constraints: 111 // tensor FuncOp parameters and single tensor return type 112 // OnnxOps (with tensor operands and single tensor return value) and ReturnOp (returning single tensor) 113 // entry block only 114 static byte[] build(Block block, List<oracle.code.onnx.Tensor> initializers) { 115 var indexer = new Indexer(OpWriter.computeGlobalNames(block.parentBody().parentOp())); 116 var model = build(graph(null, null, indexer, block, initializers, 0), List.of(), List.of()); 117 // OnnxProtoPrinter.printModel(model); 118 return model; 119 } 120 121 static byte[] build(List<TensorProto> initializers, List<ValueInfoProto> inputs, List<NodeProto> ops, List<String> outputNames) { 122 return build(graph(null, initializers, inputs, ops, outputNames), List.of(), List.of()); 123 } 124 125 static byte[] build(List<TensorProto> initializers, List<ValueInfoProto> inputs, List<NodeProto> ops, List<String> outputNames, List<String> customImportDomains, List<FunctionProto> functions) { 126 return build(graph(null, initializers, inputs, ops, outputNames), customImportDomains, functions); 127 } 128 129 static byte[] build(GraphProto graph, List<String> customImportDomains, List<FunctionProto> functions) { 130 return new ModelProto() 131 .irVersion(IR_VERSION) 132 .opsetImport(new OperatorSetIdProto().version(OPSET_VERSION)) 133 .forEach(customImportDomains, (m, d) -> m.opsetImport(new OperatorSetIdProto().domain(d))) 134 .forEach(functions, (m, f) -> m.functions(f)) 135 .graph(graph) 136 .getBytes(); 137 } 138 139 static List<String> expandTuples(Indexer indexer, List<Value> values) { 140 var names = new ArrayList<String>(); 141 expandTuples(indexer, names, values); 142 return names; 143 } 144 145 static void expandTuples(Indexer indexer, List<String> names, List<Value> values) { 146 for (var v : values) { 147 if (v instanceof Op.Result or && or.op() instanceof CoreOp.TupleOp op) { 148 expandTuples(indexer, names, op.operands()); 149 } else if (v.type() instanceof TupleType tt) { 150 var ct = tt.componentTypes(); 151 for (int i = 0; i < ct.size(); i++) { 152 names.add(indexer.nameOf(v, i)); 153 } 154 } else { 155 names.add(indexer.nameOf(v)); 156 } 157 } 158 } 159 160 static GraphProto graph(String graphName, String domainName, Indexer indexer, Block block, List<oracle.code.onnx.Tensor> initializers, int scalarArgs) { 161 var params = block.parameters(); 162 params.forEach(indexer::nameOf); 163 int firstInitializer = params.size() - initializers.size(); 164 var args = params.subList(0, firstInitializer); 165 return graph(graphName, 166 IntStream.range(0, initializers.size()).mapToObj(i -> tensorProto(indexer.nameOf(params.get(i + firstInitializer)), initializers.get(i))).toList(), 167 tensorInfos(indexer, args, scalarArgs), 168 nodes(domainName, indexer, block.ops()), 169 expandTuples(indexer, block.terminatingOp().operands())); 170 } 171 172 static List<NodeProto> nodes(String domainName, Indexer indexer, List<Op> ops) { 173 return ops.stream().<NodeProto>mapMulti((op, opNodes) -> { 174 switch (op) { 175 case OnnxOps.If ifOp -> 176 opNodes.accept(node( 177 ifOp.opName(), 178 List.of(indexer.nameOf(ifOp.operands().getFirst())), 179 IntStream.range(0, ifOp.resultType() instanceof TupleType tt ? tt.componentTypes().size() : 1).mapToObj(o -> indexer.nameOf(ifOp.result(), o)).toList(), 180 java.util.Map.of( 181 "then_branch", graph(null, domainName, indexer, ifOp.thenBranch().entryBlock(), List.of(), 0), 182 "else_branch", graph(null, domainName, indexer, ifOp.elseBranch().entryBlock(), List.of(), 0)))); 183 case OnnxOps.Loop loopOp -> { 184 opNodes.accept(node(loopOp.opName(), 185 expandTuples(indexer, loopOp.operands()), 186 IntStream.range(0, loopOp.resultType() instanceof TupleType tt ? tt.componentTypes().size() : 1).mapToObj(o -> indexer.nameOf(loopOp.result(), o)).toList(), 187 java.util.Map.of( 188 "body", graph(null, domainName, indexer, loopOp.loopBody().entryBlock(), List.of(), 2)))); 189 } 190 case OnnxOp onnxOp -> 191 opNodes.accept(node( 192 onnxOp.opName(), 193 onnxOp.operands().stream().map(indexer::nameOf).toList(), 194 IntStream.range(0, onnxOp.onnxOutputs().size()).mapToObj(o -> indexer.nameOf(onnxOp.result(), o)).toList(), 195 onnxOp.onnxAttributes())); 196 case CoreOp.FuncCallOp fco -> 197 opNodes.accept(node( 198 domainName, 199 fco.funcName(), 200 fco.operands().stream().map(indexer::nameOf).toList(), 201 expandTuples(indexer, List.of(fco.result())), 202 java.util.Map.of())); 203 case CoreOp.ReturnOp _, CoreOp.ConstantOp _ -> { // skip 204 } 205 case CoreOp.TupleLoadOp tlo -> 206 indexer.mapTupleLoad(tlo.result(), tlo.operands().getFirst(), tlo.index()); 207 case CoreOp.TupleOp to -> 208 indexer.mapTupleElements(to.result(), to.operands()); 209 case CoreOp.InvokeOp io when io.invokeDescriptor().refType().equals(JavaType.type(List.class)) -> { 210 if (io.invokeDescriptor().name().equals("get") && io.operands().getLast() instanceof Op.Result or && or.op() instanceof CoreOp.ConstantOp co && co.value() instanceof Integer i) { 211 indexer.mapTupleLoad(io.result(), io.operands().getFirst(), i); 212 } else if (io.invokeDescriptor().name().equals("of")) { 213 indexer.mapTupleElements(io.result(), io.operands()); 214 } else { 215 throw new UnsupportedOperationException(op.toText()); 216 } 217 } 218 default -> { 219 throw new UnsupportedOperationException(op.toText()); 220 } 221 } 222 }).toList(); 223 } 224 225 static List<ValueInfoProto> tensorInfos(Indexer indexer, List<Block.Parameter> args, int scalarArgs) { 226 var infos = new ArrayList<ValueInfoProto>(); 227 for (var arg : args) { 228 switch (arg.type()) { 229 case OnnxType.TensorType tt -> 230 infos.add(tensorInfo(indexer.nameOf(arg), tt.eType().id(), infos.size() < scalarArgs)); 231 case TupleType tt -> { 232 var ct = tt.componentTypes(); 233 for (int i = 0; i < ct.size(); i++) { 234 infos.add(tensorInfo(indexer.nameOf(arg, i), ((OnnxType.TensorType)ct.get(i)).eType().id(), infos.size() < scalarArgs)); 235 } 236 } 237 default -> 238 throw new UnsupportedOperationException(arg.type().toString()); 239 } 240 } 241 return infos; 242 } 243 244 static GraphProto graph(String name, List<TensorProto> initializers, List<ValueInfoProto> inputs, List<NodeProto> ops, List<String> outputNames) { 245 return new GraphProto() 246 .name(name) 247 .forEach(initializers, (g, i) -> g.initializer(i)) 248 .forEach(inputs, (g, i) -> g.input(i)) 249 .forEach(ops, (g, op) -> g.node(op)) 250 .forEach(outputNames, (g, oName) -> g.output(new ValueInfoProto().name(oName))); 251 } 252 253 static FunctionProto function(String domain, String functionName, List<String> inputNames, List<String> outputNames, List<NodeProto> ops) { 254 return new FunctionProto() 255 .domain(domain) 256 .name(functionName) 257 .forEach(inputNames, (f, i) -> f.input(i)) 258 .forEach(ops, (g, op) -> g.node(op)) 259 .forEach(outputNames, (f, o) -> f.output(o)) 260 .opsetImport(new OperatorSetIdProto().version(OPSET_VERSION)); 261 } 262 263 static NodeProto node(String domain, String opName, List<String> inputNames, List<String> outputNames, java.util.Map<String, Object> attributes) { 264 return new NodeProto() 265 .domain(domain) 266 .opType(opName) 267 .forEach(inputNames, (n, iName) -> n.input(iName)) 268 .forEach(attributes.entrySet(), (n, ae) -> n.attribute(attribute(ae.getKey(), ae.getValue()))) 269 .forEach(outputNames, (n, oName) -> n.output(oName)); 270 } 271 272 static NodeProto node(String opName, List<String> inputNames, List<String> outputNames, java.util.Map<String, Object> attributes) { 273 return new NodeProto() 274 .opType(opName) 275 .forEach(inputNames, (n, iName) -> n.input(iName)) 276 .forEach(attributes.entrySet(), (n, ae) -> n.attribute(attribute(ae.getKey(), ae.getValue()))) 277 .forEach(outputNames, (n, oName) -> n.output(oName)); 278 } 279 280 static ValueInfoProto tensorInfo(String name, int tensorElementType) { 281 return tensorInfo(name, tensorElementType, false); 282 } 283 284 static ValueInfoProto tensorInfo(String name, int tensorElementType, boolean addScalarShape) { 285 var t = new TypeProto.Tensor().elemType(tensorElementType); 286 if (addScalarShape) t.shape(new TensorShapeProto()); 287 return new ValueInfoProto() 288 .name(name) 289 .type(new TypeProto().tensorType(t)); 290 } 291 292 static TensorProto tensorProto(String name, oracle.code.onnx.Tensor tensor) { 293 return new TensorProto() 294 .name(name) 295 .dataType(tensor.elementType().id) 296 .dims(tensor.shape()) 297 .rawData(tensor.data().toArray(ValueLayout.JAVA_BYTE)); 298 } 299 300 static AttributeProto attribute(String name, Object value) { 301 var attr = new AttributeProto().name(name); 302 switch (value) { 303 case Float f -> { 304 attr.type(AttributeType.FLOAT).f(f); 305 } 306 case Long l -> { 307 attr.type(AttributeType.INT).i(l); 308 } 309 case GraphProto g -> { 310 attr.type(AttributeType.GRAPH).g(g.name(name)); 311 } 312 case float[] floats -> { 313 attr.type(AttributeType.FLOATS); 314 attr.floats(floats); 315 } 316 case long[] longs -> { 317 attr.type(AttributeType.INTS); 318 attr.ints(longs); 319 } 320 default -> { 321 throw new UnsupportedOperationException(value.getClass().toString()); // @@@ ToDo 322 } 323 } 324 return attr; 325 } 326 }