1 /*
  2  * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 
 26 package oracle.code.onnx;
 27 
 28 import java.lang.foreign.ValueLayout;
 29 import java.util.ArrayList;
 30 import java.util.HashMap;
 31 import java.util.List;
 32 import java.util.Map;
 33 import java.util.Optional;
 34 import java.util.SequencedMap;
 35 import java.util.function.Function;
 36 import java.util.stream.IntStream;
 37 import jdk.incubator.code.Block;
 38 import jdk.incubator.code.CodeItem;
 39 import jdk.incubator.code.Op;
 40 import jdk.incubator.code.Value;
 41 import jdk.incubator.code.op.CoreOp;
 42 import jdk.incubator.code.type.JavaType;
 43 import jdk.incubator.code.type.TupleType;
 44 import jdk.incubator.code.writer.OpWriter;
 45 import oracle.code.onnx.ir.OnnxOp;
 46 import oracle.code.onnx.ir.OnnxOps;
 47 import oracle.code.onnx.ir.OnnxType;
 48 import oracle.code.onnx.proto.OnnxBuilder.*;
 49 import oracle.code.onnx.proto.OnnxConstants.*;
 50 import oracle.code.onnx.proto.OnnxModel;
 51 
 52 public final class OnnxProtoBuilder {
 53 
 54     static final int IR_VERSION = 10;
 55     static final int OPSET_VERSION = 21;
 56 
 57     private static final class Indexer {
 58 
 59         private final Function<CodeItem, String> baseNames;
 60         private final HashMap<String, String> remap;
 61 
 62 
 63         Indexer(Op root, Map<Value, String> explicitNames) {
 64             this.baseNames = OpWriter.computeGlobalNames(root);
 65             this.remap = new HashMap<>();
 66             explicitNames.forEach(this::setName);
 67         }
 68 
 69         void setName(Value val, String name) {
 70             remap.put(baseName(val), name);
 71             if (val instanceof Op.Result or && or.op() instanceof CoreOp.TupleLoadOp tlo) {
 72                 Value tr = tlo.operands().getFirst();
 73                 remap.put(baseName(tr, tlo.index()), name);
 74                 if (tr instanceof Op.Result tor && tor.op() instanceof CoreOp.TupleOp to) {
 75                     setName(to.operands().get(tlo.index()), name);
 76                 }
 77             }
 78         }
 79 
 80         private String baseName(Value value) {
 81             return "%" + baseNames.apply(value);
 82         }
 83 
 84         private String baseName(Value value, int elementIndex) {
 85             var name = baseName(value);
 86             return elementIndex > 0 ? name + '.' + elementIndex : name;
 87         }
 88 
 89         String nameOf(Value value) {
 90             var name = baseName(value);
 91             return remap.getOrDefault(name, name);
 92         }
 93 
 94         String nameOf(Value tuple, int elementIndex) {
 95             var name = baseName(tuple, elementIndex);
 96             return remap.getOrDefault(name, name);
 97         }
 98 
 99         void mapTupleLoad(Value tupleLoadResult, Value tuple, int elementIndex) {
100             remap.putIfAbsent(baseName(tupleLoadResult), nameOf(tuple, elementIndex));
101         }
102 
103         void mapTupleElements(Value tuple, List<Value> elements) {
104             for (int i = 0; i < elements.size(); i++) {
105                 remap.putIfAbsent(baseName(tuple, i), nameOf(elements.get(i)));
106             }
107         }
108     }
109 
110     public static byte[] buildModel(String domain, CoreOp.ModuleOp module, List<Object> initializers) {
111         return buildModel(domain, module, initializers, Map.of(), _ -> null);
112     }
113 
114     public record ExternalTensorDataInfo(String location, long offset, long length) {
115     }
116 
117     public static byte[] buildModel(String domain, CoreOp.ModuleOp module, List<Object> initializers, Map<Value, String> explicitValueNames, Function<Tensor, ExternalTensorDataInfo> tensorDataExternalizer) {
118         var indexer = new Indexer(module, explicitValueNames);
119 
120         var functions = new ArrayList<>(module.functionTable().sequencedValues());
121         var imports = new ArrayList<String>();
122         if (functions.size() > 1) imports.add(domain); // self domain import if additional functions
123         for (var f : functions) {
124             for (var op : f.body().entryBlock().ops()) { // auto import of op domains
125                 if (op instanceof OnnxOp) {
126                     int di = op.opName().lastIndexOf('.');
127                     if (di > 0) {
128                         String dn = op.opName().substring(0, di);
129                         if (!imports.contains(dn)) imports.add(dn);
130                     }
131                 }
132             }
133         }
134         var mainFunc = functions.removeLast();
135         var mainBlock = mainFunc.body().entryBlock();
136 
137         var model = buildModel(
138                 graph(domain, mainFunc.funcName(), indexer, mainBlock, initializers, 0, tensorDataExternalizer),
139                 imports,
140                 functions.stream().map(f ->
141                         function(domain, imports, f.funcName(),
142                                  expandTuples(indexer, f.parameters()),
143                                  expandTuples(indexer, f.body().entryBlock().terminatingOp().operands()),
144                                  nodes(domain, indexer, f.body().entryBlock().ops()))).toList());
145 
146 //        System.out.println(OnnxModel.readFrom(model).toText());
147         return model;
148     }
149 
150     // @@@ unchecked constraints:
151     //         tensor FuncOp parameters and single tensor return type
152     //         OnnxOps (with tensor operands and single tensor return value) and ReturnOp (returning single tensor)
153     //         entry block only
154     static byte[] buildModel(Block block, List<oracle.code.onnx.Tensor> initializers) {
155         var indexer = new Indexer(block.parentBody().parentOp(), Map.of());
156         var model = buildModel(graph(null, null, indexer, block, initializers, 0), List.of(), List.of());
157 //        System.out.println(OnnxModel.readFrom(model).toText());
158         return model;
159     }
160 
161     static byte[] buildModel(List<TensorProto> initializers, List<ValueInfoProto> inputs, List<NodeProto> ops, List<String> outputNames) {
162         return buildModel(graph(null, initializers, inputs, ops, outputNames), List.of(), List.of());
163     }
164 
165     static byte[] buildModel(List<TensorProto> initializers, List<ValueInfoProto> inputs, List<NodeProto> ops, List<String> outputNames, List<String> customImportDomains, List<FunctionProto> functions) {
166         return buildModel(graph(null, initializers, inputs, ops, outputNames), customImportDomains, functions);
167     }
168 
169     static byte[] buildModel(GraphProto graph, List<String> imports, List<FunctionProto> functions) {
170         return new ModelProto()
171                 .irVersion(IR_VERSION)
172                 .opsetImport(new OperatorSetIdProto().version(OPSET_VERSION))
173                 .forEach(imports, (m, d) -> m.opsetImport(new OperatorSetIdProto().domain(d).version(1)))
174                 .forEach(functions, (m, f) -> m.functions(f))
175                 .graph(graph)
176                 .getBytes();
177     }
178 
179     static List<String> expandTuples(Indexer indexer, List<? extends Value> values) {
180         var names = new ArrayList<String>();
181         expandTuples(indexer, names, values);
182         return names;
183     }
184 
185     static void expandTuples(Indexer indexer, List<String> names, List<? extends Value> values) {
186         for (var v : values) {
187             if (v instanceof Op.Result or && or.op() instanceof CoreOp.TupleOp op) {
188                 expandTuples(indexer, names, op.operands());
189             } else if (v.type() instanceof TupleType tt) {
190                 var ct = tt.componentTypes();
191                 for (int i = 0; i < ct.size(); i++) {
192                     names.add(indexer.nameOf(v, i));
193                 }
194             } else {
195                 names.add(indexer.nameOf(v));
196             }
197         }
198     }
199 
200     static GraphProto graph(String domain, String graphName, Indexer indexer, Block block, List<? extends Object> initializers, int scalarArgs) {
201         return graph(domain, graphName, indexer, block, initializers, scalarArgs, _ -> null);
202     }
203 
204     static GraphProto graph(String domain, String graphName, Indexer indexer, Block block, List<? extends Object> initializers, int scalarArgs, Function<Tensor, ExternalTensorDataInfo> tensorDataExternalizer) {
205         var params = block.parameters();
206         params.forEach(indexer::nameOf);
207         int firstInitializer = params.size() - initializers.size();
208         var args = params.subList(0, firstInitializer);
209         return graph(graphName,
210                 IntStream.range(0, initializers.size()).boxed().<TensorProto>mapMulti((i, tps) -> {
211                     Object val = initializers.get(i);
212                     if (val instanceof Record) {
213                         var rcs = val.getClass().getRecordComponents();
214                         for (int rci = 0; rci < rcs.length; rci++) try {
215                             tps.accept(tensorProto(indexer.nameOf(params.get(i + firstInitializer), rci), (Tensor)(rcs[rci].getAccessor().invoke(val)), tensorDataExternalizer));
216                         } catch (ReflectiveOperationException e) {
217                             throw new IllegalArgumentException(e);
218                         }
219                     } else {
220                         tps.accept(tensorProto(indexer.nameOf(params.get(i + firstInitializer)), (Tensor)val, tensorDataExternalizer));
221                     }
222                 }).toList(),
223                 tensorInfos(indexer, args, scalarArgs),
224                 nodes(domain, indexer, block.ops()),
225                 expandTuples(indexer, block.terminatingOp().operands()));
226     }
227 
228     static List<String> opInputNames(Indexer indexer, SequencedMap<OnnxOp.OnnxParameter, Object> inputs) {
229         List<String> inputNames = inputs.sequencedValues().stream()
230                 .<String>mapMulti((v, dump) -> {
231                     switch (v) {
232                         case Value val -> dump.accept(indexer.nameOf(val));
233                         case java.util.Optional<?> o when o.isPresent() && o.get() instanceof Value val -> dump.accept(indexer.nameOf(val));
234                         case List l -> l.forEach(val -> dump.accept(indexer.nameOf((Value)val)));
235                         default -> dump.accept(""); // empty names for unused optional inputs
236                     }
237                 }).toList();
238         // trim trailing empty names
239         return inputNames.reversed().stream().dropWhile(String::isEmpty).toList().reversed();
240     }
241 
242     static List<NodeProto> nodes(String domain, Indexer indexer, List<Op> ops) {
243         return ops.stream().<NodeProto>mapMulti((op, opNodes) -> {
244             switch (op) {
245                 case OnnxOps.If ifOp ->
246                     opNodes.accept(node(
247                             ifOp.opName(),
248                             List.of(indexer.nameOf(ifOp.operands().getFirst())),
249                             IntStream.range(0, ifOp.resultType() instanceof TupleType tt ? tt.componentTypes().size() : 1).mapToObj(o -> indexer.nameOf(ifOp.result(), o)).toList(),
250                             java.util.Map.of(
251                                     "then_branch", graph(domain, null, indexer, ifOp.thenBranch().entryBlock(), List.of(), 0),
252                                     "else_branch", graph(domain, null, indexer, ifOp.elseBranch().entryBlock(), List.of(), 0))));
253                 case OnnxOps.Loop loopOp -> {
254                     opNodes.accept(node(loopOp.opName(),
255                             expandTuples(indexer, loopOp.operands()),
256                             IntStream.range(0, loopOp.resultType() instanceof TupleType tt ? tt.componentTypes().size() : 1).mapToObj(o -> indexer.nameOf(loopOp.result(), o)).toList(),
257                             java.util.Map.of(
258                                     "body", graph(domain, null, indexer, loopOp.loopBody().entryBlock(), List.of(), 2))));
259                 }
260                 case OnnxOp onnxOp ->
261                     opNodes.accept(node(
262                             onnxOp.opName(),
263                             opInputNames(indexer, onnxOp.onnxInputs()),
264                             IntStream.range(0, onnxOp.onnxOutputs().size()).mapToObj(o -> indexer.nameOf(onnxOp.result(), o)).toList(),
265                             onnxOp.onnxAttributes()));
266                 case CoreOp.FuncCallOp fco ->
267                     opNodes.accept(node(
268                             domain,
269                             fco.funcName(),
270                             expandTuples(indexer, fco.operands()),
271                             expandTuples(indexer, List.of(fco.result())),
272                             java.util.Map.of()));
273                 case CoreOp.ReturnOp _, CoreOp.ConstantOp _ -> { // skip
274                 }
275                 case CoreOp.TupleLoadOp tlo ->
276                     indexer.mapTupleLoad(tlo.result(), tlo.operands().getFirst(), tlo.index());
277                 case CoreOp.TupleOp to ->
278                     indexer.mapTupleElements(to.result(), to.operands());
279                 case CoreOp.InvokeOp io when io.invokeDescriptor().refType().equals(JavaType.type(List.class)) -> {
280                     if (io.invokeDescriptor().name().equals("get") && io.operands().getLast() instanceof Op.Result or && or.op() instanceof CoreOp.ConstantOp co && co.value() instanceof Integer i) {
281                         indexer.mapTupleLoad(io.result(), io.operands().getFirst(), i);
282                     } else if (io.invokeDescriptor().name().equals("of")) {
283                         indexer.mapTupleElements(io.result(), io.operands());
284                     } else {
285                         throw new UnsupportedOperationException(op.toText());
286                     }
287                 }
288                 default -> {
289                     throw new UnsupportedOperationException(op.toText());
290                 }
291             }
292         }).toList();
293     }
294 
295     static List<ValueInfoProto> tensorInfos(Indexer indexer, List<Block.Parameter> args, int scalarArgs) {
296         var infos = new ArrayList<ValueInfoProto>();
297         for (var arg : args) {
298             switch (arg.type()) {
299                 case OnnxType.TensorType tt ->
300                     infos.add(tensorInfo(indexer.nameOf(arg), tt.eType().id(), infos.size() < scalarArgs));
301                 case TupleType tt -> {
302                     var ct = tt.componentTypes();
303                     for (int i = 0; i < ct.size(); i++) {
304                         infos.add(tensorInfo(indexer.nameOf(arg, i), ((OnnxType.TensorType)ct.get(i)).eType().id(), infos.size() < scalarArgs));
305                     }
306                 }
307                 default ->
308                     throw new UnsupportedOperationException(arg.type().toString());
309             }
310         }
311         return infos;
312     }
313 
314     static GraphProto graph(String name, List<TensorProto> initializers, List<ValueInfoProto> inputs, List<NodeProto> ops, List<String> outputNames) {
315         return new GraphProto()
316                 .name(name)
317                 .forEach(initializers, (g, i) -> g.initializer(i))
318                 .forEach(inputs, (g, i) -> g.input(i))
319                 .forEach(ops, (g, op) -> g.node(op))
320                 .forEach(outputNames, (g, oName) -> g.output(new ValueInfoProto().name(oName)));
321     }
322 
323     static FunctionProto function(String functionDomain, List<String> imports, String functionName, List<String> inputNames, List<String> outputNames, List<NodeProto> ops) {
324         int di = functionName.lastIndexOf('.');
325         return new FunctionProto()
326                 .domain(functionDomain)
327                 .name(functionName)
328                 .forEach(inputNames, (f, i) -> f.input(i))
329                 .forEach(ops, (g, op) -> g.node(op))
330                 .forEach(outputNames, (f, o) -> f.output(o))
331                 .opsetImport(new OperatorSetIdProto().version(OPSET_VERSION))
332                 .forEach(imports, (f, d) -> f.opsetImport(new OperatorSetIdProto().domain(d).version(1)));
333     }
334 
335     static NodeProto node(String domain, String opName, List<String> inputNames, List<String> outputNames, java.util.Map<String, Object> attributes) {
336         return new NodeProto()
337                 .domain(domain)
338                 .opType(opName)
339                 .forEach(inputNames, (n, iName) -> n.input(iName))
340                 .forEach(attributes.entrySet(), (n, ae) -> n.attribute(attribute(ae.getKey(), ae.getValue())))
341                 .forEach(outputNames, (n, oName) -> n.output(oName));
342     }
343 
344     static NodeProto node(String opName, List<String> inputNames, List<String> outputNames, java.util.Map<String, Object> attributes) {
345         int di = opName.lastIndexOf('.');
346         return node(di < 0 ? null : opName.substring(0, di), opName.substring(di + 1), inputNames, outputNames, attributes);
347     }
348 
349     static ValueInfoProto tensorInfo(String name, int tensorElementType) {
350         return tensorInfo(name, tensorElementType, false);
351     }
352 
353     static ValueInfoProto tensorInfo(String name, int tensorElementType, boolean addScalarShape) {
354         var t = new TypeProto.Tensor().elemType(tensorElementType);
355         if (addScalarShape) t.shape(new TensorShapeProto());
356         return new ValueInfoProto()
357                 .name(name)
358                 .type(new TypeProto().tensorType(t));
359     }
360 
361     static TensorProto tensorProto(String name, oracle.code.onnx.Tensor tensor, Function<Tensor, ExternalTensorDataInfo> tensorDataExternalizer) {
362         ExternalTensorDataInfo extInfo = tensorDataExternalizer.apply(tensor);
363         TensorProto tp = new TensorProto()
364                 .name(name)
365                 .dataType(tensor.elementType().id)
366                 .dims(tensor.shape());
367         return extInfo == null
368                 ? tp.rawData(tensor.data().toArray(ValueLayout.JAVA_BYTE))
369                 : tp.externalData(new StringStringEntryProto().key("location").value(extInfo.location()))
370                     .externalData(new StringStringEntryProto().key("offset").value(String.valueOf(extInfo.offset())))
371                     .externalData(new StringStringEntryProto().key("length").value(String.valueOf(extInfo.length())))
372                     .dataLocation(DataLocation.EXTERNAL);
373     }
374 
375     static TensorProto tensorProto(oracle.code.onnx.Tensor tensor) {
376         return new TensorProto()
377                 .dataType(tensor.elementType().id)
378                 .dims(tensor.shape())
379                 .rawData(tensor.data().toArray(ValueLayout.JAVA_BYTE));
380     }
381 
382     static AttributeProto attribute(String name, Object value) {
383         var attr = new AttributeProto().name(name);
384         switch (value) {
385             case Float f -> {
386                 attr.type(AttributeType.FLOAT).f(f);
387             }
388             case Long l -> {
389                 attr.type(AttributeType.INT).i(l);
390             }
391             case GraphProto g -> {
392                 attr.type(AttributeType.GRAPH).g(g.name(name));
393             }
394             case float[] floats -> {
395                 attr.type(AttributeType.FLOATS);
396                 attr.floats(floats);
397             }
398             case long[] longs -> {
399                 attr.type(AttributeType.INTS);
400                 attr.ints(longs);
401             }
402             case Tensor<?> t -> {
403                 attr.type(AttributeType.TENSOR);
404                 attr.t(tensorProto(t));
405             }
406             default -> {
407                 throw new UnsupportedOperationException(value.getClass().toString()); // @@@ ToDo
408             }
409         }
410         return attr;
411     }
412 }