1 /*
  2  * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 
 26 package oracle.code.onnx;
 27 
 28 import java.lang.foreign.ValueLayout;
 29 import java.util.ArrayList;
 30 import java.util.HashMap;
 31 import java.util.List;
 32 import java.util.Map;
 33 import java.util.Optional;
 34 import java.util.SequencedMap;
 35 import java.util.function.Function;
 36 import java.util.stream.IntStream;
 37 import jdk.incubator.code.Block;
 38 import jdk.incubator.code.CodeItem;
 39 import jdk.incubator.code.Op;
 40 import jdk.incubator.code.Value;
 41 import jdk.incubator.code.dialect.core.CoreOp;
 42 import jdk.incubator.code.dialect.core.TupleType;
 43 import jdk.incubator.code.dialect.java.JavaOp;
 44 import jdk.incubator.code.dialect.java.JavaType;
 45 import jdk.incubator.code.extern.OpWriter;
 46 import oracle.code.onnx.ir.OnnxOp;
 47 import oracle.code.onnx.ir.OnnxOps;
 48 import oracle.code.onnx.ir.OnnxType;
 49 import oracle.code.onnx.proto.OnnxBuilder.*;
 50 import oracle.code.onnx.proto.OnnxConstants.*;
 51 
 52 public final class OnnxProtoBuilder {
 53 
    // IR version written into every ModelProto produced by this builder.
    static final int IR_VERSION = 10;
    // Version of the opset import that carries no explicit domain, used for models and functions.
    static final int OPSET_VERSION = 21;
 56 
 57     private static final class Indexer {
 58 
 59         private final Function<CodeItem, String> baseNames;
 60         private final HashMap<String, String> remap;
 61 
 62 
 63         Indexer(Op root, Map<Value, String> explicitNames) {
 64             this.baseNames = OpWriter.computeGlobalNames(root);
 65             this.remap = new HashMap<>();
 66             explicitNames.forEach(this::setName);
 67         }
 68 
 69         void setName(Value val, String name) {
 70             switch (val) {
 71                 case Op.Result or when or.op() instanceof CoreOp.TupleOp to -> {
 72                     remap.put(baseName(val), name);
 73                     for (int i = 0; i < to.operands().size(); i++) {
 74                         setName(to.operands().get(i), name + "." + i);
 75                     }
 76                 }
 77                 case Block.Parameter bp when val.type() instanceof TupleType tt -> {
 78                     for (int i = 0; i < tt.componentTypes().size(); i++) {
 79                         remap.put(baseName(val, i), name +"." + i);
 80                     }
 81                 }
 82                 default -> {
 83                     remap.put(baseName(val), name);
 84                     if (val instanceof Op.Result or && or.op() instanceof CoreOp.TupleLoadOp tlo) {
 85                         Value tr = tlo.operands().getFirst();
 86                         remap.put(baseName(tr, tlo.index()), name);
 87                         if (tr instanceof Op.Result tor && tor.op() instanceof CoreOp.TupleOp to) {
 88                             setName(to.operands().get(tlo.index()), name);
 89                         }
 90                     }
 91                 }
 92             }
 93         }
 94 
 95         private String baseName(Value value) {
 96             return "%" + baseNames.apply(value);
 97         }
 98 
 99         private String baseName(Value value, int elementIndex) {
100             var name = baseName(value);
101             return elementIndex > 0 ? name + '.' + elementIndex : name;
102         }
103 
104         String nameOf(Value value) {
105             var name = baseName(value);
106             return remap.getOrDefault(name, name);
107         }
108 
109         String nameOf(Value tuple, int elementIndex) {
110             var name = baseName(tuple, elementIndex);
111             return remap.getOrDefault(name, name);
112         }
113 
114         void mapTupleLoad(Value tupleLoadResult, Value tuple, int elementIndex) {
115             remap.putIfAbsent(baseName(tupleLoadResult), nameOf(tuple, elementIndex));
116         }
117 
118         void mapTupleElements(Value tuple, List<Value> elements) {
119             for (int i = 0; i < elements.size(); i++) {
120                 remap.putIfAbsent(baseName(tuple, i), nameOf(elements.get(i)));
121             }
122         }
123     }
124 
    /**
     * Builds a serialized ONNX model from the given module.
     * <p>
     * Delegates to the five-argument overload with no explicit value names and with an
     * externalizer that returns {@code null} for every tensor.
     *
     * @param domain domain passed on to the five-argument overload
     * @param module module whose functions are translated
     * @param initializers initializer values passed on to the five-argument overload
     * @return the serialized model
     */
    public static byte[] buildModel(String domain, CoreOp.ModuleOp module, List<Object> initializers) {
        return buildModel(domain, module, initializers, Map.of(), _ -> null);
    }
128 
    /**
     * Describes tensor data that is referenced by the model instead of being embedded in it.
     * The components are written to the tensor's external data entries under the keys
     * {@code "location"}, {@code "offset"} and {@code "length"}.
     *
     * @param location value of the {@code "location"} entry (presumably a file path; confirm against consumers)
     * @param offset value of the {@code "offset"} entry (presumably a byte offset; confirm against consumers)
     * @param length value of the {@code "length"} entry (presumably a byte count; confirm against consumers)
     */
    public record ExternalTensorDataInfo(String location, long offset, long length) {
    }
131 
132     public static byte[] buildModel(String domain, CoreOp.ModuleOp module, List<Object> initializers, Map<Value, String> explicitValueNames, Function<Tensor, ExternalTensorDataInfo> tensorDataExternalizer) {
133         var indexer = new Indexer(module, explicitValueNames);
134 
135         var functions = new ArrayList<>(module.functionTable().sequencedValues());
136         var imports = new ArrayList<String>();
137         if (functions.size() > 1) imports.add(domain); // self domain import if additional functions
138         for (var f : functions) {
139             for (var op : f.body().entryBlock().ops()) { // auto import of op domains
140                 if (op instanceof OnnxOp oop) {
141                     String name = oop.schema().name();
142                     int di = name.lastIndexOf('.');
143                     if (di > 0) {
144                         String dn = name.substring(0, di);
145                         if (!imports.contains(dn)) imports.add(dn);
146                     }
147                 }
148             }
149         }
150         var mainFunc = functions.removeLast();
151         var mainBlock = mainFunc.body().entryBlock();
152 
153         var model = buildModel(
154                 graph(domain, mainFunc.funcName(), indexer, mainBlock, initializers, 0, tensorDataExternalizer),
155                 imports,
156                 functions.stream().map(f ->
157                         function(domain, imports, f.funcName(),
158                                  expandTuples(indexer, f.parameters()),
159                                  expandTuples(indexer, f.body().entryBlock().terminatingOp().operands()),
160                                  nodes(domain, indexer, f.body().entryBlock().ops()))).toList());
161 
162 //        System.out.println(OnnxModel.readFrom(model).toText());
163         return model;
164     }
165 
    // @@@ unchecked constraints:
    //         tensor FuncOp parameters and single tensor return type
    //         OnnxOps (with tensor operands and single tensor return value) and ReturnOp (returning single tensor)
    //         entry block only
    /**
     * Builds a serialized ONNX model whose graph is derived from a single block.
     * Values are named relative to the op that is the ancestor of the block; the model is
     * built without a domain, without a graph name, and without additional imports or functions.
     *
     * @param block block translated into the model graph
     * @param initializers initializer tensors passed on to the graph builder
     * @return the serialized model
     */
    static byte[] buildModel(Block block, List<oracle.code.onnx.Tensor> initializers) {
        var indexer = new Indexer(block.ancestorOp(), Map.of());
        var model = buildModel(graph(null, null, indexer, block, initializers, 0), List.of(), List.of());
//        System.out.println(OnnxModel.readFrom(model).toText());
        return model;
    }
176 
    /**
     * Builds a serialized ONNX model from ready-made graph parts, using an unnamed graph
     * and no additional imports or functions.
     *
     * @param initializers initializer tensors of the graph
     * @param inputs inputs of the graph
     * @param ops nodes of the graph
     * @param outputNames names of the graph outputs
     * @return the serialized model
     */
    static byte[] buildModel(List<TensorProto> initializers, List<ValueInfoProto> inputs, List<NodeProto> ops, List<String> outputNames) {
        return buildModel(graph(null, initializers, inputs, ops, outputNames), List.of(), List.of());
    }
180 
    /**
     * Builds a serialized ONNX model from ready-made graph parts, using an unnamed graph
     * together with the given import domains and functions.
     *
     * @param initializers initializer tensors of the graph
     * @param inputs inputs of the graph
     * @param ops nodes of the graph
     * @param outputNames names of the graph outputs
     * @param customImportDomains domains to import in addition to the default opset import
     * @param functions functions added to the model
     * @return the serialized model
     */
    static byte[] buildModel(List<TensorProto> initializers, List<ValueInfoProto> inputs, List<NodeProto> ops, List<String> outputNames, List<String> customImportDomains, List<FunctionProto> functions) {
        return buildModel(graph(null, initializers, inputs, ops, outputNames), customImportDomains, functions);
    }
184 
185     static byte[] buildModel(GraphProto graph, List<String> imports, List<FunctionProto> functions) {
186         return new ModelProto()
187                 .irVersion(IR_VERSION)
188                 .opsetImport(new OperatorSetIdProto().version(OPSET_VERSION))
189                 .forEach(imports, (m, d) -> m.opsetImport(new OperatorSetIdProto().domain(d).version(1)))
190                 .forEach(functions, (m, f) -> m.functions(f))
191                 .graph(graph)
192                 .getBytes();
193     }
194 
195     static List<String> expandTuples(Indexer indexer, List<? extends Value> values) {
196         var names = new ArrayList<String>();
197         expandTuples(indexer, names, values);
198         return names;
199     }
200 
201     static void expandTuples(Indexer indexer, List<String> names, List<? extends Value> values) {
202         for (var v : values) {
203             if (v instanceof Op.Result or && or.op() instanceof CoreOp.TupleOp op) {
204                 expandTuples(indexer, names, op.operands());
205             } else if (v.type() instanceof TupleType tt) {
206                 var ct = tt.componentTypes();
207                 for (int i = 0; i < ct.size(); i++) {
208                     names.add(indexer.nameOf(v, i));
209                 }
210             } else {
211                 names.add(indexer.nameOf(v));
212             }
213         }
214     }
215 
    /**
     * Builds a graph proto from a block, using an externalizer that returns {@code null}
     * for every tensor. Delegates to the seven-argument overload.
     *
     * @param domain domain passed on to the node builder
     * @param graphName name of the graph
     * @param indexer source of the value names
     * @param block block translated into the graph
     * @param initializers initializer values bound to the trailing block parameters
     * @param scalarArgs passed on to the seven-argument overload
     * @return the graph proto
     */
    static GraphProto graph(String domain, String graphName, Indexer indexer, Block block, List<? extends Object> initializers, int scalarArgs) {
        return graph(domain, graphName, indexer, block, initializers, scalarArgs, _ -> null);
    }
219 
220     static GraphProto graph(String domain, String graphName, Indexer indexer, Block block, List<? extends Object> initializers, int scalarArgs, Function<Tensor, ExternalTensorDataInfo> tensorDataExternalizer) {
221         var params = block.parameters();
222         params.forEach(indexer::nameOf);
223         int firstInitializer = params.size() - initializers.size();
224         var args = params.subList(0, firstInitializer);
225         return graph(graphName,
226                 IntStream.range(0, initializers.size()).boxed().<TensorProto>mapMulti((i, tps) -> {
227                     Object val = initializers.get(i);
228                     if (val instanceof Record) {
229                         var rcs = val.getClass().getRecordComponents();
230                         for (int rci = 0; rci < rcs.length; rci++) try {
231                             tps.accept(tensorProto(indexer.nameOf(params.get(i + firstInitializer), rci), (Tensor)(rcs[rci].getAccessor().invoke(val)), tensorDataExternalizer));
232                         } catch (ReflectiveOperationException e) {
233                             throw new IllegalArgumentException(e);
234                         }
235                     } else if (val instanceof Tensor[] tarr) {
236                         for (int tai = 0; tai < tarr.length; tai++) {
237                             tps.accept(tensorProto(indexer.nameOf(params.get(i + firstInitializer), tai), tarr[tai], tensorDataExternalizer));
238                         }
239                     } else {
240                         tps.accept(tensorProto(indexer.nameOf(params.get(i + firstInitializer)), (Tensor)val, tensorDataExternalizer));
241                     }
242                 }).toList(),
243                 tensorInfos(indexer, args, scalarArgs),
244                 nodes(domain, indexer, block.ops()),
245                 expandTuples(indexer, block.terminatingOp().operands()));
246     }
247 
248     static List<String> opInputNames(Indexer indexer, SequencedMap<OnnxOp.OnnxParameter, Object> inputs) {
249         List<String> inputNames = inputs.sequencedValues().stream()
250                 .<String>mapMulti((v, dump) -> {
251                     switch (v) {
252                         case Value val -> dump.accept(indexer.nameOf(val));
253                         case java.util.Optional<?> o when o.isPresent() && o.get() instanceof Value val -> dump.accept(indexer.nameOf(val));
254                         case List l -> l.forEach(val -> dump.accept(indexer.nameOf((Value)val)));
255                         default -> dump.accept(""); // empty names for unused optional inputs
256                     }
257                 }).toList();
258         // trim trailing empty names
259         return inputNames.reversed().stream().dropWhile(String::isEmpty).toList().reversed();
260     }
261 
262     static List<NodeProto> nodes(String domain, Indexer indexer, List<Op> ops) {
263         return ops.stream().<NodeProto>mapMulti((op, opNodes) -> {
264             switch (op) {
265                 case OnnxOps.If ifOp ->
266                     opNodes.accept(node(
267                             ifOp.schema().name(),
268                             List.of(indexer.nameOf(ifOp.operands().getFirst())),
269                             IntStream.range(0, ifOp.resultType() instanceof TupleType tt ? tt.componentTypes().size() : 1).mapToObj(o -> indexer.nameOf(ifOp.result(), o)).toList(),
270                             java.util.Map.of(
271                                     "then_branch", graph(domain, null, indexer, ifOp.thenBranch().entryBlock(), List.of(), 0),
272                                     "else_branch", graph(domain, null, indexer, ifOp.elseBranch().entryBlock(), List.of(), 0))));
273                 case OnnxOps.Loop loopOp -> {
274                     opNodes.accept(node(loopOp.schema().name(),
275                             expandTuples(indexer, loopOp.operands()),
276                             IntStream.range(0, loopOp.resultType() instanceof TupleType tt ? tt.componentTypes().size() : 1).mapToObj(o -> indexer.nameOf(loopOp.result(), o)).toList(),
277                             java.util.Map.of(
278                                     "body", graph(domain, null, indexer, loopOp.loopBody().entryBlock(), List.of(), 2))));
279                 }
280                 case OnnxOp onnxOp ->
281                     opNodes.accept(node(
282                             onnxOp.schema().name(),
283                             opInputNames(indexer, onnxOp.onnxInputs()),
284                             IntStream.range(0, onnxOp.onnxOutputs().size()).mapToObj(o -> indexer.nameOf(onnxOp.result(), o)).toList(),
285                             onnxOp.onnxAttributes()));
286                 case CoreOp.FuncCallOp fco ->
287                     opNodes.accept(node(
288                             domain,
289                             fco.funcName(),
290                             expandTuples(indexer, fco.operands()),
291                             expandTuples(indexer, List.of(fco.result())),
292                             java.util.Map.of()));
293                 case CoreOp.ReturnOp _, CoreOp.ConstantOp _ -> { // skip
294                 }
295                 case CoreOp.TupleLoadOp tlo ->
296                     indexer.mapTupleLoad(tlo.result(), tlo.operands().getFirst(), tlo.index());
297                 case CoreOp.TupleOp to ->
298                     indexer.mapTupleElements(to.result(), to.operands());
299                 case JavaOp.InvokeOp io when io.invokeDescriptor().refType().equals(JavaType.type(List.class)) -> {
300                     if (io.invokeDescriptor().name().equals("get") && io.operands().getLast() instanceof Op.Result or && or.op() instanceof CoreOp.ConstantOp co && co.value() instanceof Integer i) {
301                         indexer.mapTupleLoad(io.result(), io.operands().getFirst(), i);
302                     } else if (io.invokeDescriptor().name().equals("of")) {
303                         indexer.mapTupleElements(io.result(), io.operands());
304                     } else {
305                         throw new UnsupportedOperationException(op.toText());
306                     }
307                 }
308                 default -> {
309                     throw new UnsupportedOperationException(op.toText());
310                 }
311             }
312         }).toList();
313     }
314 
315     static List<ValueInfoProto> tensorInfos(Indexer indexer, List<Block.Parameter> args, int scalarArgs) {
316         var infos = new ArrayList<ValueInfoProto>();
317         for (var arg : args) {
318             switch (arg.type()) {
319                 case OnnxType.TensorType tt ->
320                     infos.add(tensorInfo(indexer.nameOf(arg), tt.eType().id(), infos.size() < scalarArgs));
321                 case TupleType tt -> {
322                     var ct = tt.componentTypes();
323                     for (int i = 0; i < ct.size(); i++) {
324                         infos.add(tensorInfo(indexer.nameOf(arg, i), ((OnnxType.TensorType)ct.get(i)).eType().id(), infos.size() < scalarArgs));
325                     }
326                 }
327                 default ->
328                     throw new UnsupportedOperationException(arg.type().toString());
329             }
330         }
331         return infos;
332     }
333 
334     static GraphProto graph(String name, List<TensorProto> initializers, List<ValueInfoProto> inputs, List<NodeProto> ops, List<String> outputNames) {
335         return new GraphProto()
336                 .name(name)
337                 .forEach(initializers, (g, i) -> g.initializer(i))
338                 .forEach(inputs, (g, i) -> g.input(i))
339                 .forEach(ops, (g, op) -> g.node(op))
340                 .forEach(outputNames, (g, oName) -> g.output(new ValueInfoProto().name(oName)));
341     }
342 
343     static FunctionProto function(String functionDomain, List<String> imports, String functionName, List<String> inputNames, List<String> outputNames, List<NodeProto> ops) {
344         int di = functionName.lastIndexOf('.');
345         return new FunctionProto()
346                 .domain(functionDomain)
347                 .name(functionName)
348                 .forEach(inputNames, (f, i) -> f.input(i))
349                 .forEach(ops, (g, op) -> g.node(op))
350                 .forEach(outputNames, (f, o) -> f.output(o))
351                 .opsetImport(new OperatorSetIdProto().version(OPSET_VERSION))
352                 .forEach(imports, (f, d) -> f.opsetImport(new OperatorSetIdProto().domain(d).version(1)));
353     }
354 
355     static NodeProto node(String domain, String opName, List<String> inputNames, List<String> outputNames, java.util.Map<String, Object> attributes) {
356         return new NodeProto()
357                 .domain(domain)
358                 .opType(opName)
359                 .forEach(inputNames, (n, iName) -> n.input(iName))
360                 .forEach(attributes.entrySet(), (n, ae) -> n.attribute(attribute(ae.getKey(), ae.getValue())))
361                 .forEach(outputNames, (n, oName) -> n.output(oName));
362     }
363 
364     static NodeProto node(String opName, List<String> inputNames, List<String> outputNames, java.util.Map<String, Object> attributes) {
365         int di = opName.lastIndexOf('.');
366         return node(di < 0 ? null : opName.substring(0, di), opName.substring(di + 1), inputNames, outputNames, attributes);
367     }
368 
    /**
     * Creates the value info of a tensor without adding a shape.
     *
     * @param name name of the value
     * @param tensorElementType element type id of the tensor
     * @return the value info
     */
    static ValueInfoProto tensorInfo(String name, int tensorElementType) {
        return tensorInfo(name, tensorElementType, false);
    }
372 
373     static ValueInfoProto tensorInfo(String name, int tensorElementType, boolean addScalarShape) {
374         var t = new TypeProto.Tensor().elemType(tensorElementType);
375         if (addScalarShape) t.shape(new TensorShapeProto());
376         return new ValueInfoProto()
377                 .name(name)
378                 .type(new TypeProto().tensorType(t));
379     }
380 
381     static TensorProto tensorProto(String name, oracle.code.onnx.Tensor tensor, Function<Tensor, ExternalTensorDataInfo> tensorDataExternalizer) {
382         ExternalTensorDataInfo extInfo = tensorDataExternalizer.apply(tensor);
383         TensorProto tp = new TensorProto()
384                 .name(name)
385                 .dataType(tensor.elementType().id)
386                 .dims(tensor.shape());
387         return extInfo == null
388                 ? tp.rawData(tensor.data().toArray(ValueLayout.JAVA_BYTE))
389                 : tp.externalData(new StringStringEntryProto().key("location").value(extInfo.location()))
390                     .externalData(new StringStringEntryProto().key("offset").value(String.valueOf(extInfo.offset())))
391                     .externalData(new StringStringEntryProto().key("length").value(String.valueOf(extInfo.length())))
392                     .dataLocation(DataLocation.EXTERNAL);
393     }
394 
395     static TensorProto tensorProto(oracle.code.onnx.Tensor tensor) {
396         return new TensorProto()
397                 .dataType(tensor.elementType().id)
398                 .dims(tensor.shape())
399                 .rawData(tensor.data().toArray(ValueLayout.JAVA_BYTE));
400     }
401 
402     static AttributeProto attribute(String name, Object value) {
403         var attr = new AttributeProto().name(name);
404         switch (value) {
405             case Float f -> {
406                 attr.type(AttributeType.FLOAT).f(f);
407             }
408             case Long l -> {
409                 attr.type(AttributeType.INT).i(l);
410             }
411             case GraphProto g -> {
412                 attr.type(AttributeType.GRAPH).g(g.name(name));
413             }
414             case float[] floats -> {
415                 attr.type(AttributeType.FLOATS);
416                 attr.floats(floats);
417             }
418             case long[] longs -> {
419                 attr.type(AttributeType.INTS);
420                 attr.ints(longs);
421             }
422             case Tensor<?> t -> {
423                 attr.type(AttributeType.TENSOR);
424                 attr.t(tensorProto(t));
425             }
426             default -> {
427                 throw new UnsupportedOperationException(value.getClass().toString()); // @@@ ToDo
428             }
429         }
430         return attr;
431     }
432 }