1 /*
2 * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25
26 package oracle.code.onnx;
27
28 import java.lang.foreign.ValueLayout;
29 import java.util.ArrayList;
30 import java.util.HashMap;
31 import java.util.List;
32 import java.util.Map;
33 import java.util.SequencedMap;
34 import java.util.function.Function;
35 import java.util.stream.IntStream;
36 import jdk.incubator.code.Block;
37 import jdk.incubator.code.CodeItem;
38 import jdk.incubator.code.Op;
39 import jdk.incubator.code.Value;
40 import jdk.incubator.code.dialect.core.CoreOp;
41 import jdk.incubator.code.dialect.core.TupleType;
42 import jdk.incubator.code.dialect.java.JavaOp;
43 import jdk.incubator.code.dialect.java.JavaType;
44 import jdk.incubator.code.extern.OpWriter;
45 import oracle.code.onnx.ir.OnnxOp;
46 import oracle.code.onnx.ir.OnnxOps;
47 import oracle.code.onnx.ir.OnnxType;
48 import oracle.code.onnx.proto.OnnxBuilder.*;
49 import oracle.code.onnx.proto.OnnxConstants.*;
50
51 public final class OnnxProtoBuilder {
52
53 static final int IR_VERSION = 10;
54 static final int OPSET_VERSION = 21;
55
56 private static final class Indexer {
57
58 private final Function<CodeItem, String> baseNames;
59 private final HashMap<String, String> remap;
60
61
62 Indexer(Op root, Map<Value, String> explicitNames) {
63 this.baseNames = OpWriter.computeGlobalNames(root);
64 this.remap = new HashMap<>();
65 explicitNames.forEach(this::setName);
66 }
67
68 void setName(Value val, String name) {
69 switch (val) {
70 case Op.Result or when or.op() instanceof CoreOp.TupleOp to -> {
71 remap.put(baseName(val), name);
72 for (int i = 0; i < to.operands().size(); i++) {
73 setName(to.operands().get(i), name + "." + i);
74 }
75 }
76 case Block.Parameter bp when val.type() instanceof TupleType tt -> {
77 for (int i = 0; i < tt.componentTypes().size(); i++) {
78 remap.put(baseName(val, i), name +"." + i);
79 }
80 }
81 default -> {
82 remap.put(baseName(val), name);
83 if (val instanceof Op.Result or && or.op() instanceof CoreOp.TupleLoadOp tlo) {
84 Value tr = tlo.operands().getFirst();
85 remap.put(baseName(tr, tlo.index()), name);
86 if (tr instanceof Op.Result tor && tor.op() instanceof CoreOp.TupleOp to) {
87 setName(to.operands().get(tlo.index()), name);
88 }
89 }
90 }
91 }
92 }
93
94 private String baseName(Value value) {
95 return "%" + baseNames.apply(value);
96 }
97
98 private String baseName(Value value, int elementIndex) {
99 var name = baseName(value);
100 return elementIndex > 0 ? name + '.' + elementIndex : name;
101 }
102
103 String nameOf(Value value) {
104 var name = baseName(value);
105 return remap.getOrDefault(name, name);
106 }
107
108 String nameOf(Value tuple, int elementIndex) {
109 var name = baseName(tuple, elementIndex);
110 return remap.getOrDefault(name, name);
111 }
112
113 void mapTupleLoad(Value tupleLoadResult, Value tuple, int elementIndex) {
114 remap.putIfAbsent(baseName(tupleLoadResult), nameOf(tuple, elementIndex));
115 }
116
117 void mapTupleElements(Value tuple, List<Value> elements) {
118 for (int i = 0; i < elements.size(); i++) {
119 remap.putIfAbsent(baseName(tuple, i), nameOf(elements.get(i)));
120 }
121 }
122 }
123
124 public static byte[] buildModel(String domain, CoreOp.ModuleOp module, List<Object> initializers) {
125 return buildModel(domain, module, initializers, Map.of(), _ -> null);
126 }
127
128 public record ExternalTensorDataInfo(String location, long offset, long length) {
129 }
130
131 public static byte[] buildModel(String domain, CoreOp.ModuleOp module, List<Object> initializers, Map<Value, String> explicitValueNames, Function<Tensor, ExternalTensorDataInfo> tensorDataExternalizer) {
132 var indexer = new Indexer(module, explicitValueNames);
133
134 var functions = new ArrayList<>(module.functionTable().sequencedValues());
135 var imports = new ArrayList<String>();
136 if (functions.size() > 1) imports.add(domain); // self domain import if additional functions
137 for (var f : functions) {
138 for (var op : f.body().entryBlock().ops()) { // auto import of op domains
139 if (op instanceof OnnxOp oop) {
140 String name = oop.schema().name();
141 int di = name.lastIndexOf('.');
142 if (di > 0) {
143 String dn = name.substring(0, di);
144 if (!imports.contains(dn)) imports.add(dn);
145 }
146 }
147 }
148 }
149 var mainFunc = functions.removeLast();
150 var mainBlock = mainFunc.body().entryBlock();
151
152 var model = buildModel(
153 graph(domain, mainFunc.funcName(), indexer, mainBlock, initializers, 0, tensorDataExternalizer),
154 imports,
155 functions.stream().map(f ->
156 function(domain, imports, f.funcName(),
157 expandTuples(indexer, f.parameters()),
158 expandTuples(indexer, f.body().entryBlock().terminatingOp().operands()),
159 nodes(domain, indexer, f.body().entryBlock().ops()))).toList());
160
161 // System.out.println(OnnxModel.readFrom(model).toText());
162 return model;
163 }
164
165 // @@@ unchecked constraints:
166 // tensor FuncOp parameters and single tensor return type
167 // OnnxOps (with tensor operands and single tensor return value) and ReturnOp (returning single tensor)
168 // entry block only
169 static byte[] buildModel(Block block, List<oracle.code.onnx.Tensor> initializers) {
170 var indexer = new Indexer(block.ancestorOp(), Map.of());
171 var model = buildModel(graph(null, null, indexer, block, initializers, 0), List.of(), List.of());
172 // System.out.println(OnnxModel.readFrom(model).toText());
173 return model;
174 }
175
176 static byte[] buildModel(List<TensorProto> initializers, List<ValueInfoProto> inputs, List<NodeProto> ops, List<String> outputNames) {
177 return buildModel(graph(null, initializers, inputs, ops, outputNames), List.of(), List.of());
178 }
179
180 static byte[] buildModel(List<TensorProto> initializers, List<ValueInfoProto> inputs, List<NodeProto> ops, List<String> outputNames, List<String> customImportDomains, List<FunctionProto> functions) {
181 return buildModel(graph(null, initializers, inputs, ops, outputNames), customImportDomains, functions);
182 }
183
184 static byte[] buildModel(GraphProto graph, List<String> imports, List<FunctionProto> functions) {
185 return new ModelProto()
186 .irVersion(IR_VERSION)
187 .opsetImport(new OperatorSetIdProto().version(OPSET_VERSION))
188 .forEach(imports, (m, d) -> m.opsetImport(new OperatorSetIdProto().domain(d).version(1)))
189 .forEach(functions, (m, f) -> m.functions(f))
190 .graph(graph)
191 .getBytes();
192 }
193
194 static List<String> expandTuples(Indexer indexer, List<? extends Value> values) {
195 var names = new ArrayList<String>();
196 expandTuples(indexer, names, values);
197 return names;
198 }
199
200 static void expandTuples(Indexer indexer, List<String> names, List<? extends Value> values) {
201 for (var v : values) {
202 if (v instanceof Op.Result or && or.op() instanceof CoreOp.TupleOp op) {
203 expandTuples(indexer, names, op.operands());
204 } else if (v instanceof Op.Result or && or.op() instanceof CoreOp.TupleLoadOp op) {
205 names.add(indexer.nameOf(op.operands().getFirst(), op.index()));
206 } else if (v.type() instanceof TupleType tt) {
207 var ct = tt.componentTypes();
208 for (int i = 0; i < ct.size(); i++) {
209 names.add(indexer.nameOf(v, i));
210 }
211 } else {
212 names.add(indexer.nameOf(v));
213 }
214 }
215 }
216
217 static GraphProto graph(String domain, String graphName, Indexer indexer, Block block, List<? extends Object> initializers, int scalarArgs) {
218 return graph(domain, graphName, indexer, block, initializers, scalarArgs, _ -> null);
219 }
220
221 static GraphProto graph(String domain, String graphName, Indexer indexer, Block block, List<? extends Object> initializers, int scalarArgs, Function<Tensor, ExternalTensorDataInfo> tensorDataExternalizer) {
222 var params = block.parameters();
223 params.forEach(indexer::nameOf);
224 int firstInitializer = params.size() - initializers.size();
225 var args = params.subList(0, firstInitializer);
226 return graph(graphName,
227 IntStream.range(0, initializers.size()).boxed().<TensorProto>mapMulti((i, tps) -> {
228 Object val = initializers.get(i);
229 if (val instanceof Record) {
230 var rcs = val.getClass().getRecordComponents();
231 for (int rci = 0; rci < rcs.length; rci++) try {
232 tps.accept(tensorProto(indexer.nameOf(params.get(i + firstInitializer), rci), (Tensor)(rcs[rci].getAccessor().invoke(val)), tensorDataExternalizer));
233 } catch (ReflectiveOperationException e) {
234 throw new IllegalArgumentException(e);
235 }
236 } else if (val instanceof Tensor[] tarr) {
237 for (int tai = 0; tai < tarr.length; tai++) {
238 tps.accept(tensorProto(indexer.nameOf(params.get(i + firstInitializer), tai), tarr[tai], tensorDataExternalizer));
239 }
240 } else {
241 tps.accept(tensorProto(indexer.nameOf(params.get(i + firstInitializer)), (Tensor)val, tensorDataExternalizer));
242 }
243 }).toList(),
244 tensorInfos(indexer, args, scalarArgs),
245 nodes(domain, indexer, block.ops()),
246 expandTuples(indexer, block.terminatingOp().operands()));
247 }
248
249 static List<String> opInputNames(Indexer indexer, SequencedMap<OnnxOp.OnnxParameter, Object> inputs) {
250 List<String> inputNames = inputs.sequencedValues().stream()
251 .<String>mapMulti((v, dump) -> {
252 switch (v) {
253 case Value val -> dump.accept(indexer.nameOf(val));
254 case java.util.Optional<?> o when o.isPresent() && o.get() instanceof Value val -> dump.accept(indexer.nameOf(val));
255 case List l -> l.forEach(val -> dump.accept(indexer.nameOf((Value)val)));
256 default -> dump.accept(""); // empty names for unused optional inputs
257 }
258 }).toList();
259 // trim trailing empty names
260 return inputNames.reversed().stream().dropWhile(String::isEmpty).toList().reversed();
261 }
262
263 static List<NodeProto> nodes(String domain, Indexer indexer, List<Op> ops) {
264 return ops.stream().<NodeProto>mapMulti((op, opNodes) -> {
265 switch (op) {
266 case OnnxOps.If ifOp ->
267 opNodes.accept(node(
268 ifOp.schema().name(),
269 List.of(indexer.nameOf(ifOp.operands().getFirst())),
270 IntStream.range(0, ifOp.resultType() instanceof TupleType tt ? tt.componentTypes().size() : 1).mapToObj(o -> indexer.nameOf(ifOp.result(), o)).toList(),
271 java.util.Map.of(
272 "then_branch", graph(domain, null, indexer, ifOp.thenBranch().entryBlock(), List.of(), 0),
273 "else_branch", graph(domain, null, indexer, ifOp.elseBranch().entryBlock(), List.of(), 0))));
274 case OnnxOps.Loop loopOp -> {
275 opNodes.accept(node(loopOp.schema().name(),
276 expandTuples(indexer, loopOp.operands()),
277 IntStream.range(0, loopOp.resultType() instanceof TupleType tt ? tt.componentTypes().size() : 1).mapToObj(o -> indexer.nameOf(loopOp.result(), o)).toList(),
278 java.util.Map.of(
279 "body", graph(domain, null, indexer, loopOp.loopBody().entryBlock(), List.of(), 2))));
280 }
281 case OnnxOp onnxOp ->
282 opNodes.accept(node(
283 onnxOp.schema().name(),
284 opInputNames(indexer, onnxOp.onnxInputs()),
285 IntStream.range(0, onnxOp.onnxOutputs().size()).mapToObj(o -> indexer.nameOf(onnxOp.result(), o)).toList(),
286 onnxOp.onnxAttributes()));
287 case CoreOp.FuncCallOp fco ->
288 opNodes.accept(node(
289 domain,
290 fco.funcName(),
291 expandTuples(indexer, fco.operands()),
292 expandTuples(indexer, List.of(fco.result())),
293 java.util.Map.of()));
294 case CoreOp.ReturnOp _, CoreOp.ConstantOp _ -> { // skip
295 }
296 case CoreOp.TupleLoadOp tlo ->
297 indexer.mapTupleLoad(tlo.result(), tlo.operands().getFirst(), tlo.index());
298 case CoreOp.TupleOp to ->
299 indexer.mapTupleElements(to.result(), to.operands());
300 case JavaOp.InvokeOp io when io.invokeDescriptor().refType().equals(JavaType.type(List.class)) -> {
301 if (io.invokeDescriptor().name().equals("get") && io.operands().getLast() instanceof Op.Result or && or.op() instanceof CoreOp.ConstantOp co && co.value() instanceof Integer i) {
302 indexer.mapTupleLoad(io.result(), io.operands().getFirst(), i);
303 } else if (io.invokeDescriptor().name().equals("of")) {
304 indexer.mapTupleElements(io.result(), io.operands());
305 } else {
306 throw new UnsupportedOperationException(op.toText());
307 }
308 }
309 default -> {
310 throw new UnsupportedOperationException(op.toText());
311 }
312 }
313 }).toList();
314 }
315
316 static List<ValueInfoProto> tensorInfos(Indexer indexer, List<Block.Parameter> args, int scalarArgs) {
317 var infos = new ArrayList<ValueInfoProto>();
318 for (var arg : args) {
319 switch (arg.type()) {
320 case OnnxType.TensorType tt ->
321 infos.add(tensorInfo(indexer.nameOf(arg), tt.eType().id(), infos.size() < scalarArgs));
322 case TupleType tt -> {
323 var ct = tt.componentTypes();
324 for (int i = 0; i < ct.size(); i++) {
325 infos.add(tensorInfo(indexer.nameOf(arg, i), ((OnnxType.TensorType)ct.get(i)).eType().id(), infos.size() < scalarArgs));
326 }
327 }
328 default ->
329 throw new UnsupportedOperationException(arg.type().toString());
330 }
331 }
332 return infos;
333 }
334
335 static GraphProto graph(String name, List<TensorProto> initializers, List<ValueInfoProto> inputs, List<NodeProto> ops, List<String> outputNames) {
336 return new GraphProto()
337 .name(name)
338 .forEach(initializers, (g, i) -> g.initializer(i))
339 .forEach(inputs, (g, i) -> g.input(i))
340 .forEach(ops, (g, op) -> g.node(op))
341 .forEach(outputNames, (g, oName) -> g.output(new ValueInfoProto().name(oName)));
342 }
343
344 static FunctionProto function(String functionDomain, List<String> imports, String functionName, List<String> inputNames, List<String> outputNames, List<NodeProto> ops) {
345 int di = functionName.lastIndexOf('.');
346 return new FunctionProto()
347 .domain(functionDomain)
348 .name(functionName)
349 .forEach(inputNames, (f, i) -> f.input(i))
350 .forEach(ops, (g, op) -> g.node(op))
351 .forEach(outputNames, (f, o) -> f.output(o))
352 .opsetImport(new OperatorSetIdProto().version(OPSET_VERSION))
353 .forEach(imports, (f, d) -> f.opsetImport(new OperatorSetIdProto().domain(d).version(1)));
354 }
355
356 static NodeProto node(String domain, String opName, List<String> inputNames, List<String> outputNames, java.util.Map<String, Object> attributes) {
357 return new NodeProto()
358 .domain(domain)
359 .opType(opName)
360 .forEach(inputNames, (n, iName) -> n.input(iName))
361 .forEach(attributes.entrySet(), (n, ae) -> n.attribute(attribute(ae.getKey(), ae.getValue())))
362 .forEach(outputNames, (n, oName) -> n.output(oName));
363 }
364
365 static NodeProto node(String opName, List<String> inputNames, List<String> outputNames, java.util.Map<String, Object> attributes) {
366 int di = opName.lastIndexOf('.');
367 return node(di < 0 ? null : opName.substring(0, di), opName.substring(di + 1), inputNames, outputNames, attributes);
368 }
369
370 static ValueInfoProto tensorInfo(String name, int tensorElementType) {
371 return tensorInfo(name, tensorElementType, false);
372 }
373
374 static ValueInfoProto tensorInfo(String name, int tensorElementType, boolean addScalarShape) {
375 var t = new TypeProto.Tensor().elemType(tensorElementType);
376 if (addScalarShape) t.shape(new TensorShapeProto());
377 return new ValueInfoProto()
378 .name(name)
379 .type(new TypeProto().tensorType(t));
380 }
381
382 static TensorProto tensorProto(String name, oracle.code.onnx.Tensor tensor, Function<Tensor, ExternalTensorDataInfo> tensorDataExternalizer) {
383 ExternalTensorDataInfo extInfo = tensorDataExternalizer.apply(tensor);
384 TensorProto tp = new TensorProto()
385 .name(name)
386 .dataType(tensor.elementType().id)
387 .dims(tensor.shape());
388 return extInfo == null
389 ? tp.rawData(tensor.data().toArray(ValueLayout.JAVA_BYTE))
390 : tp.externalData(new StringStringEntryProto().key("location").value(extInfo.location()))
391 .externalData(new StringStringEntryProto().key("offset").value(String.valueOf(extInfo.offset())))
392 .externalData(new StringStringEntryProto().key("length").value(String.valueOf(extInfo.length())))
393 .dataLocation(DataLocation.EXTERNAL);
394 }
395
396 static TensorProto tensorProto(oracle.code.onnx.Tensor tensor) {
397 return new TensorProto()
398 .dataType(tensor.elementType().id)
399 .dims(tensor.shape())
400 .rawData(tensor.data().toArray(ValueLayout.JAVA_BYTE));
401 }
402
403 static AttributeProto attribute(String name, Object value) {
404 var attr = new AttributeProto().name(name);
405 switch (value) {
406 case Float f -> {
407 attr.type(AttributeType.FLOAT).f(f);
408 }
409 case Long l -> {
410 attr.type(AttributeType.INT).i(l);
411 }
412 case GraphProto g -> {
413 attr.type(AttributeType.GRAPH).g(g.name(name));
414 }
415 case float[] floats -> {
416 attr.type(AttributeType.FLOATS);
417 attr.floats(floats);
418 }
419 case long[] longs -> {
420 attr.type(AttributeType.INTS);
421 attr.ints(longs);
422 }
423 case String s -> {
424 attr.type(AttributeType.STRING);
425 attr.s(s.getBytes());
426 }
427 case Tensor<?> t -> {
428 attr.type(AttributeType.TENSOR);
429 attr.t(tensorProto(t));
430 }
431 default -> {
432 throw new UnsupportedOperationException(value.getClass().toString()); // @@@ ToDo
433 }
434 }
435 return attr;
436 }
437 }