1 /*
  2  * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 
 26 package oracle.code.onnx;
 27 
 28 import java.lang.foreign.ValueLayout;
 29 import java.util.ArrayList;
 30 import java.util.HashMap;
 31 import java.util.List;
 32 import java.util.function.Function;
 33 import java.util.stream.IntStream;
 34 import jdk.incubator.code.Block;
 35 import jdk.incubator.code.CodeItem;
 36 import jdk.incubator.code.Op;
 37 import jdk.incubator.code.Value;
 38 import jdk.incubator.code.op.CoreOp;
 39 import jdk.incubator.code.type.JavaType;
 40 import jdk.incubator.code.type.TupleType;
 41 import jdk.incubator.code.writer.OpWriter;
 42 import oracle.code.onnx.ir.OnnxOp;
 43 import oracle.code.onnx.ir.OnnxOps;
 44 import oracle.code.onnx.ir.OnnxType;
 45 import oracle.code.onnx.proto.OnnxBuilder.*;
 46 import oracle.code.onnx.proto.OnnxConstants.*;
 47 
 48 public final class OnnxProtoBuilder {
 49 
    // IR version written into every ModelProto produced by this builder.
    static final int IR_VERSION = 10;
    // Version declared for the domain-less operator set import of models and functions.
    static final int OPSET_VERSION = 21;
 52 
 53     private static final class Indexer {
 54 
 55         private final Function<CodeItem, String> baseNames;
 56         private final HashMap<String, String> elementsMap;
 57 
 58 
 59         Indexer(Function<CodeItem, String> baseNames) {
 60             this.baseNames = baseNames;
 61             this.elementsMap = new HashMap<>();
 62         }
 63 
 64         private String baseName(Value value, int elementIndex) {
 65             var name = "%" + baseNames.apply(value);
 66             return elementIndex > 0 ? name + '.' + elementIndex : name;
 67         }
 68 
 69         String nameOf(Value value) {
 70             return nameOf(value, 0);
 71         }
 72 
 73         String nameOf(Value tuple, int elementIndex) {
 74             var name = baseName(tuple, elementIndex);
 75             return elementsMap.getOrDefault(name, name);
 76         }
 77 
 78         void mapTupleLoad(Value tupleLoadResult, Value tuple, int elementIndex) {
 79             elementsMap.put(baseName(tupleLoadResult, 0), nameOf(tuple, elementIndex));
 80         }
 81 
 82         void mapTupleElements(Value tuple, List<Value> elements) {
 83             for (int i = 0; i < elements.size(); i++) {
 84                 elementsMap.put(baseName(tuple, i), nameOf(elements.get(i)));
 85             }
 86         }
 87     }
 88 
 89     static byte[] build(String domainName, CoreOp.ModuleOp module, List<oracle.code.onnx.Tensor> initializers) {
 90         var indexer = new Indexer(OpWriter.computeGlobalNames(module));
 91 
 92         var functions = new ArrayList<>(module.functionTable().sequencedValues());
 93         var mainFunc = functions.removeLast();
 94         var mainBlock = mainFunc.body().entryBlock();
 95 
 96         var model = build(
 97                 graph(mainFunc.funcName(), domainName, indexer, mainBlock, initializers, 0),
 98                 List.of(domainName),
 99                 functions.stream().map(f ->
100                         function(domainName,
101                                  f.funcName(),
102                                  f.parameters().stream().map(indexer::nameOf).toList(),
103                                  expandTuples(indexer, f.body().entryBlock().terminatingOp().operands()),
104                                  nodes(domainName, indexer, f.body().entryBlock().ops()))).toList());
105 
106 //        OnnxProtoPrinter.printModel(model);
107         return model;
108     }
109 
110     // @@@ unchecked constraints:
111     //         tensor FuncOp parameters and single tensor return type
112     //         OnnxOps (with tensor operands and single tensor return value) and ReturnOp (returning single tensor)
113     //         entry block only
114     static byte[] build(Block block, List<oracle.code.onnx.Tensor> initializers) {
115         var indexer = new Indexer(OpWriter.computeGlobalNames(block.parentBody().parentOp()));
116         var model = build(graph(null, null, indexer, block, initializers, 0), List.of(), List.of());
117 //        OnnxProtoPrinter.printModel(model);
118         return model;
119     }
120 
121     static byte[] build(List<TensorProto> initializers, List<ValueInfoProto> inputs, List<NodeProto> ops, List<String> outputNames) {
122         return build(graph(null, initializers, inputs, ops, outputNames), List.of(), List.of());
123     }
124 
125     static byte[] build(List<TensorProto> initializers, List<ValueInfoProto> inputs, List<NodeProto> ops, List<String> outputNames, List<String> customImportDomains, List<FunctionProto> functions) {
126         return build(graph(null, initializers, inputs, ops, outputNames), customImportDomains, functions);
127     }
128 
129     static byte[] build(GraphProto graph, List<String> customImportDomains, List<FunctionProto> functions) {
130         return new ModelProto()
131                 .irVersion(IR_VERSION)
132                 .opsetImport(new OperatorSetIdProto().version(OPSET_VERSION))
133                 .forEach(customImportDomains, (m, d) -> m.opsetImport(new OperatorSetIdProto().domain(d)))
134                 .forEach(functions, (m, f) -> m.functions(f))
135                 .graph(graph)
136                 .getBytes();
137     }
138 
139     static List<String> expandTuples(Indexer indexer, List<Value> values) {
140         var names = new ArrayList<String>();
141         expandTuples(indexer, names, values);
142         return names;
143     }
144 
145     static void expandTuples(Indexer indexer, List<String> names, List<Value> values) {
146         for (var v : values) {
147             if (v instanceof Op.Result or && or.op() instanceof CoreOp.TupleOp op) {
148                 expandTuples(indexer, names, op.operands());
149             } else if (v.type() instanceof TupleType tt) {
150                 var ct = tt.componentTypes();
151                 for (int i = 0; i < ct.size(); i++) {
152                     names.add(indexer.nameOf(v, i));
153                 }
154             } else {
155                 names.add(indexer.nameOf(v));
156             }
157         }
158     }
159 
160     static GraphProto graph(String graphName, String domainName, Indexer indexer, Block block, List<oracle.code.onnx.Tensor> initializers, int scalarArgs) {
161         var params = block.parameters();
162         params.forEach(indexer::nameOf);
163         int firstInitializer = params.size() - initializers.size();
164         var args = params.subList(0, firstInitializer);
165         return graph(graphName,
166                 IntStream.range(0, initializers.size()).mapToObj(i -> tensorProto(indexer.nameOf(params.get(i + firstInitializer)), initializers.get(i))).toList(),
167                 tensorInfos(indexer, args, scalarArgs),
168                 nodes(domainName, indexer, block.ops()),
169                 expandTuples(indexer, block.terminatingOp().operands()));
170     }
171 
172     static List<NodeProto> nodes(String domainName, Indexer indexer, List<Op> ops) {
173         return ops.stream().<NodeProto>mapMulti((op, opNodes) -> {
174             switch (op) {
175                 case OnnxOps.If ifOp ->
176                     opNodes.accept(node(
177                             ifOp.opName(),
178                             List.of(indexer.nameOf(ifOp.operands().getFirst())),
179                             IntStream.range(0, ifOp.resultType() instanceof TupleType tt ? tt.componentTypes().size() : 1).mapToObj(o -> indexer.nameOf(ifOp.result(), o)).toList(),
180                             java.util.Map.of(
181                                     "then_branch", graph(null, domainName, indexer, ifOp.thenBranch().entryBlock(), List.of(), 0),
182                                     "else_branch", graph(null, domainName, indexer, ifOp.elseBranch().entryBlock(), List.of(), 0))));
183                 case OnnxOps.Loop loopOp -> {
184                     opNodes.accept(node(loopOp.opName(),
185                             expandTuples(indexer, loopOp.operands()),
186                             IntStream.range(0, loopOp.resultType() instanceof TupleType tt ? tt.componentTypes().size() : 1).mapToObj(o -> indexer.nameOf(loopOp.result(), o)).toList(),
187                             java.util.Map.of(
188                                     "body", graph(null, domainName, indexer, loopOp.loopBody().entryBlock(), List.of(), 2))));
189                 }
190                 case OnnxOp onnxOp ->
191                     opNodes.accept(node(
192                             onnxOp.opName(),
193                             onnxOp.operands().stream().map(indexer::nameOf).toList(),
194                             IntStream.range(0, onnxOp.onnxOutputs().size()).mapToObj(o -> indexer.nameOf(onnxOp.result(), o)).toList(),
195                             onnxOp.onnxAttributes()));
196                 case CoreOp.FuncCallOp fco ->
197                     opNodes.accept(node(
198                             domainName,
199                             fco.funcName(),
200                             fco.operands().stream().map(indexer::nameOf).toList(),
201                             expandTuples(indexer, List.of(fco.result())),
202                             java.util.Map.of()));
203                 case CoreOp.ReturnOp _, CoreOp.ConstantOp _ -> { // skip
204                 }
205                 case CoreOp.TupleLoadOp tlo ->
206                     indexer.mapTupleLoad(tlo.result(), tlo.operands().getFirst(), tlo.index());
207                 case CoreOp.TupleOp to ->
208                     indexer.mapTupleElements(to.result(), to.operands());
209                 case CoreOp.InvokeOp io when io.invokeDescriptor().refType().equals(JavaType.type(List.class)) -> {
210                     if (io.invokeDescriptor().name().equals("get") && io.operands().getLast() instanceof Op.Result or && or.op() instanceof CoreOp.ConstantOp co && co.value() instanceof Integer i) {
211                         indexer.mapTupleLoad(io.result(), io.operands().getFirst(), i);
212                     } else if (io.invokeDescriptor().name().equals("of")) {
213                         indexer.mapTupleElements(io.result(), io.operands());
214                     } else {
215                         throw new UnsupportedOperationException(op.toText());
216                     }
217                 }
218                 default -> {
219                     throw new UnsupportedOperationException(op.toText());
220                 }
221             }
222         }).toList();
223     }
224 
225     static List<ValueInfoProto> tensorInfos(Indexer indexer, List<Block.Parameter> args, int scalarArgs) {
226         var infos = new ArrayList<ValueInfoProto>();
227         for (var arg : args) {
228             switch (arg.type()) {
229                 case OnnxType.TensorType tt ->
230                     infos.add(tensorInfo(indexer.nameOf(arg), tt.eType().id(), infos.size() < scalarArgs));
231                 case TupleType tt -> {
232                     var ct = tt.componentTypes();
233                     for (int i = 0; i < ct.size(); i++) {
234                         infos.add(tensorInfo(indexer.nameOf(arg, i), ((OnnxType.TensorType)ct.get(i)).eType().id(), infos.size() < scalarArgs));
235                     }
236                 }
237                 default ->
238                     throw new UnsupportedOperationException(arg.type().toString());
239             }
240         }
241         return infos;
242     }
243 
244     static GraphProto graph(String name, List<TensorProto> initializers, List<ValueInfoProto> inputs, List<NodeProto> ops, List<String> outputNames) {
245         return new GraphProto()
246                 .name(name)
247                 .forEach(initializers, (g, i) -> g.initializer(i))
248                 .forEach(inputs, (g, i) -> g.input(i))
249                 .forEach(ops, (g, op) -> g.node(op))
250                 .forEach(outputNames, (g, oName) -> g.output(new ValueInfoProto().name(oName)));
251     }
252 
253     static FunctionProto function(String domain, String functionName, List<String> inputNames, List<String> outputNames, List<NodeProto> ops) {
254         return new FunctionProto()
255                 .domain(domain)
256                 .name(functionName)
257                 .forEach(inputNames, (f, i) -> f.input(i))
258                 .forEach(ops, (g, op) -> g.node(op))
259                 .forEach(outputNames, (f, o) -> f.output(o))
260                 .opsetImport(new OperatorSetIdProto().version(OPSET_VERSION));
261     }
262 
263     static NodeProto node(String domain, String opName, List<String> inputNames, List<String> outputNames, java.util.Map<String, Object> attributes) {
264         return new NodeProto()
265                 .domain(domain)
266                 .opType(opName)
267                 .forEach(inputNames, (n, iName) -> n.input(iName))
268                 .forEach(attributes.entrySet(), (n, ae) -> n.attribute(attribute(ae.getKey(), ae.getValue())))
269                 .forEach(outputNames, (n, oName) -> n.output(oName));
270     }
271 
272     static NodeProto node(String opName, List<String> inputNames, List<String> outputNames, java.util.Map<String, Object> attributes) {
273         return new NodeProto()
274                 .opType(opName)
275                 .forEach(inputNames, (n, iName) -> n.input(iName))
276                 .forEach(attributes.entrySet(), (n, ae) -> n.attribute(attribute(ae.getKey(), ae.getValue())))
277                 .forEach(outputNames, (n, oName) -> n.output(oName));
278     }
279 
280     static ValueInfoProto tensorInfo(String name, int tensorElementType) {
281         return tensorInfo(name, tensorElementType, false);
282     }
283 
284     static ValueInfoProto tensorInfo(String name, int tensorElementType, boolean addScalarShape) {
285         var t = new TypeProto.Tensor().elemType(tensorElementType);
286         if (addScalarShape) t.shape(new TensorShapeProto());
287         return new ValueInfoProto()
288                 .name(name)
289                 .type(new TypeProto().tensorType(t));
290     }
291 
292     static TensorProto tensorProto(String name, oracle.code.onnx.Tensor tensor) {
293         return new TensorProto()
294                 .name(name)
295                 .dataType(tensor.elementType().id)
296                 .dims(tensor.shape())
297                 .rawData(tensor.data().toArray(ValueLayout.JAVA_BYTE));
298     }
299 
300     static AttributeProto attribute(String name, Object value) {
301         var attr = new AttributeProto().name(name);
302         switch (value) {
303             case Float f -> {
304                 attr.type(AttributeType.FLOAT).f(f);
305             }
306             case Long l -> {
307                 attr.type(AttributeType.INT).i(l);
308             }
309             case GraphProto g -> {
310                 attr.type(AttributeType.GRAPH).g(g.name(name));
311             }
312             case float[] floats -> {
313                 attr.type(AttributeType.FLOATS);
314                 attr.floats(floats);
315             }
316             case long[] longs -> {
317                 attr.type(AttributeType.INTS);
318                 attr.ints(longs);
319             }
320             default -> {
321                 throw new UnsupportedOperationException(value.getClass().toString()); // @@@ ToDo
322             }
323         }
324         return attr;
325     }
326 }