1 /*
  2  * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 
 26 package oracle.code.onnx;
 27 
 28 import java.lang.foreign.ValueLayout;
 29 import java.util.ArrayList;
 30 import java.util.HashMap;
 31 import java.util.List;
 32 import java.util.Map;
 33 import java.util.Optional;
 34 import java.util.SequencedMap;
 35 import java.util.function.Function;
 36 import java.util.stream.IntStream;
 37 import jdk.incubator.code.Block;
 38 import jdk.incubator.code.CodeItem;
 39 import jdk.incubator.code.Op;
 40 import jdk.incubator.code.Value;
 41 import jdk.incubator.code.dialect.core.CoreOp;
 42 import jdk.incubator.code.dialect.core.TupleType;
 43 import jdk.incubator.code.dialect.java.JavaOp;
 44 import jdk.incubator.code.dialect.java.JavaType;
 45 import jdk.incubator.code.extern.OpWriter;
 46 import oracle.code.onnx.ir.OnnxOp;
 47 import oracle.code.onnx.ir.OnnxOps;
 48 import oracle.code.onnx.ir.OnnxType;
 49 import oracle.code.onnx.proto.OnnxBuilder.*;
 50 import oracle.code.onnx.proto.OnnxConstants.*;
 51 
 52 public final class OnnxProtoBuilder {
 53 
    /** IR version written into every {@code ModelProto} produced by this class. */
    static final int IR_VERSION = 10;
    /** Version used for the operator-set import that carries no explicit domain (models and functions). */
    static final int OPSET_VERSION = 21;
 56 
 57     private static final class Indexer {
 58 
 59         private final Function<CodeItem, String> baseNames;
 60         private final HashMap<String, String> remap;
 61 
 62 
 63         Indexer(Op root, Map<Value, String> explicitNames) {
 64             this.baseNames = OpWriter.computeGlobalNames(root);
 65             this.remap = new HashMap<>();
 66             explicitNames.forEach(this::setName);
 67         }
 68 
 69         void setName(Value val, String name) {
 70             switch (val) {
 71                 case Op.Result or when or.op() instanceof CoreOp.TupleOp to -> {
 72                     remap.put(baseName(val), name);
 73                     for (int i = 0; i < to.operands().size(); i++) {
 74                         setName(to.operands().get(i), name + "." + i);
 75                     }
 76                 }
 77                 case Block.Parameter bp when val.type() instanceof TupleType tt -> {
 78                     for (int i = 0; i < tt.componentTypes().size(); i++) {
 79                         remap.put(baseName(val, i), name +"." + i);
 80                     }
 81                 }
 82                 default -> {
 83                     remap.put(baseName(val), name);
 84                     if (val instanceof Op.Result or && or.op() instanceof CoreOp.TupleLoadOp tlo) {
 85                         Value tr = tlo.operands().getFirst();
 86                         remap.put(baseName(tr, tlo.index()), name);
 87                         if (tr instanceof Op.Result tor && tor.op() instanceof CoreOp.TupleOp to) {
 88                             setName(to.operands().get(tlo.index()), name);
 89                         }
 90                     }
 91                 }
 92             }
 93         }
 94 
 95         private String baseName(Value value) {
 96             return "%" + baseNames.apply(value);
 97         }
 98 
 99         private String baseName(Value value, int elementIndex) {
100             var name = baseName(value);
101             return elementIndex > 0 ? name + '.' + elementIndex : name;
102         }
103 
104         String nameOf(Value value) {
105             var name = baseName(value);
106             return remap.getOrDefault(name, name);
107         }
108 
109         String nameOf(Value tuple, int elementIndex) {
110             var name = baseName(tuple, elementIndex);
111             return remap.getOrDefault(name, name);
112         }
113 
114         void mapTupleLoad(Value tupleLoadResult, Value tuple, int elementIndex) {
115             remap.putIfAbsent(baseName(tupleLoadResult), nameOf(tuple, elementIndex));
116         }
117 
118         void mapTupleElements(Value tuple, List<Value> elements) {
119             for (int i = 0; i < elements.size(); i++) {
120                 remap.putIfAbsent(baseName(tuple, i), nameOf(elements.get(i)));
121             }
122         }
123     }
124 
125     public static byte[] buildModel(String domain, CoreOp.ModuleOp module, List<Object> initializers) {
126         return buildModel(domain, module, initializers, Map.of(), _ -> null);
127     }
128 
129     public record ExternalTensorDataInfo(String location, long offset, long length) {
130     }
131 
132     public static byte[] buildModel(String domain, CoreOp.ModuleOp module, List<Object> initializers, Map<Value, String> explicitValueNames, Function<Tensor, ExternalTensorDataInfo> tensorDataExternalizer) {
133         var indexer = new Indexer(module, explicitValueNames);
134 
135         var functions = new ArrayList<>(module.functionTable().sequencedValues());
136         var imports = new ArrayList<String>();
137         if (functions.size() > 1) imports.add(domain); // self domain import if additional functions
138         for (var f : functions) {
139             for (var op : f.body().entryBlock().ops()) { // auto import of op domains
140                 if (op instanceof OnnxOp) {
141                     int di = op.opName().lastIndexOf('.');
142                     if (di > 0) {
143                         String dn = op.opName().substring(0, di);
144                         if (!imports.contains(dn)) imports.add(dn);
145                     }
146                 }
147             }
148         }
149         var mainFunc = functions.removeLast();
150         var mainBlock = mainFunc.body().entryBlock();
151 
152         var model = buildModel(
153                 graph(domain, mainFunc.funcName(), indexer, mainBlock, initializers, 0, tensorDataExternalizer),
154                 imports,
155                 functions.stream().map(f ->
156                         function(domain, imports, f.funcName(),
157                                  expandTuples(indexer, f.parameters()),
158                                  expandTuples(indexer, f.body().entryBlock().terminatingOp().operands()),
159                                  nodes(domain, indexer, f.body().entryBlock().ops()))).toList());
160 
161 //        System.out.println(OnnxModel.readFrom(model).toText());
162         return model;
163     }
164 
165     // @@@ unchecked constraints:
166     //         tensor FuncOp parameters and single tensor return type
167     //         OnnxOps (with tensor operands and single tensor return value) and ReturnOp (returning single tensor)
168     //         entry block only
169     static byte[] buildModel(Block block, List<oracle.code.onnx.Tensor> initializers) {
170         var indexer = new Indexer(block.parentBody().parentOp(), Map.of());
171         var model = buildModel(graph(null, null, indexer, block, initializers, 0), List.of(), List.of());
172 //        System.out.println(OnnxModel.readFrom(model).toText());
173         return model;
174     }
175 
176     static byte[] buildModel(List<TensorProto> initializers, List<ValueInfoProto> inputs, List<NodeProto> ops, List<String> outputNames) {
177         return buildModel(graph(null, initializers, inputs, ops, outputNames), List.of(), List.of());
178     }
179 
180     static byte[] buildModel(List<TensorProto> initializers, List<ValueInfoProto> inputs, List<NodeProto> ops, List<String> outputNames, List<String> customImportDomains, List<FunctionProto> functions) {
181         return buildModel(graph(null, initializers, inputs, ops, outputNames), customImportDomains, functions);
182     }
183 
184     static byte[] buildModel(GraphProto graph, List<String> imports, List<FunctionProto> functions) {
185         return new ModelProto()
186                 .irVersion(IR_VERSION)
187                 .opsetImport(new OperatorSetIdProto().version(OPSET_VERSION))
188                 .forEach(imports, (m, d) -> m.opsetImport(new OperatorSetIdProto().domain(d).version(1)))
189                 .forEach(functions, (m, f) -> m.functions(f))
190                 .graph(graph)
191                 .getBytes();
192     }
193 
194     static List<String> expandTuples(Indexer indexer, List<? extends Value> values) {
195         var names = new ArrayList<String>();
196         expandTuples(indexer, names, values);
197         return names;
198     }
199 
200     static void expandTuples(Indexer indexer, List<String> names, List<? extends Value> values) {
201         for (var v : values) {
202             if (v instanceof Op.Result or && or.op() instanceof CoreOp.TupleOp op) {
203                 expandTuples(indexer, names, op.operands());
204             } else if (v.type() instanceof TupleType tt) {
205                 var ct = tt.componentTypes();
206                 for (int i = 0; i < ct.size(); i++) {
207                     names.add(indexer.nameOf(v, i));
208                 }
209             } else {
210                 names.add(indexer.nameOf(v));
211             }
212         }
213     }
214 
215     static GraphProto graph(String domain, String graphName, Indexer indexer, Block block, List<? extends Object> initializers, int scalarArgs) {
216         return graph(domain, graphName, indexer, block, initializers, scalarArgs, _ -> null);
217     }
218 
219     static GraphProto graph(String domain, String graphName, Indexer indexer, Block block, List<? extends Object> initializers, int scalarArgs, Function<Tensor, ExternalTensorDataInfo> tensorDataExternalizer) {
220         var params = block.parameters();
221         params.forEach(indexer::nameOf);
222         int firstInitializer = params.size() - initializers.size();
223         var args = params.subList(0, firstInitializer);
224         return graph(graphName,
225                 IntStream.range(0, initializers.size()).boxed().<TensorProto>mapMulti((i, tps) -> {
226                     Object val = initializers.get(i);
227                     if (val instanceof Record) {
228                         var rcs = val.getClass().getRecordComponents();
229                         for (int rci = 0; rci < rcs.length; rci++) try {
230                             tps.accept(tensorProto(indexer.nameOf(params.get(i + firstInitializer), rci), (Tensor)(rcs[rci].getAccessor().invoke(val)), tensorDataExternalizer));
231                         } catch (ReflectiveOperationException e) {
232                             throw new IllegalArgumentException(e);
233                         }
234                     } else if (val instanceof Tensor[] tarr) {
235                         for (int tai = 0; tai < tarr.length; tai++) {
236                             tps.accept(tensorProto(indexer.nameOf(params.get(i + firstInitializer), tai), tarr[tai], tensorDataExternalizer));
237                         }
238                     } else {
239                         tps.accept(tensorProto(indexer.nameOf(params.get(i + firstInitializer)), (Tensor)val, tensorDataExternalizer));
240                     }
241                 }).toList(),
242                 tensorInfos(indexer, args, scalarArgs),
243                 nodes(domain, indexer, block.ops()),
244                 expandTuples(indexer, block.terminatingOp().operands()));
245     }
246 
247     static List<String> opInputNames(Indexer indexer, SequencedMap<OnnxOp.OnnxParameter, Object> inputs) {
248         List<String> inputNames = inputs.sequencedValues().stream()
249                 .<String>mapMulti((v, dump) -> {
250                     switch (v) {
251                         case Value val -> dump.accept(indexer.nameOf(val));
252                         case java.util.Optional<?> o when o.isPresent() && o.get() instanceof Value val -> dump.accept(indexer.nameOf(val));
253                         case List l -> l.forEach(val -> dump.accept(indexer.nameOf((Value)val)));
254                         default -> dump.accept(""); // empty names for unused optional inputs
255                     }
256                 }).toList();
257         // trim trailing empty names
258         return inputNames.reversed().stream().dropWhile(String::isEmpty).toList().reversed();
259     }
260 
261     static List<NodeProto> nodes(String domain, Indexer indexer, List<Op> ops) {
262         return ops.stream().<NodeProto>mapMulti((op, opNodes) -> {
263             switch (op) {
264                 case OnnxOps.If ifOp ->
265                     opNodes.accept(node(
266                             ifOp.opName(),
267                             List.of(indexer.nameOf(ifOp.operands().getFirst())),
268                             IntStream.range(0, ifOp.resultType() instanceof TupleType tt ? tt.componentTypes().size() : 1).mapToObj(o -> indexer.nameOf(ifOp.result(), o)).toList(),
269                             java.util.Map.of(
270                                     "then_branch", graph(domain, null, indexer, ifOp.thenBranch().entryBlock(), List.of(), 0),
271                                     "else_branch", graph(domain, null, indexer, ifOp.elseBranch().entryBlock(), List.of(), 0))));
272                 case OnnxOps.Loop loopOp -> {
273                     opNodes.accept(node(loopOp.opName(),
274                             expandTuples(indexer, loopOp.operands()),
275                             IntStream.range(0, loopOp.resultType() instanceof TupleType tt ? tt.componentTypes().size() : 1).mapToObj(o -> indexer.nameOf(loopOp.result(), o)).toList(),
276                             java.util.Map.of(
277                                     "body", graph(domain, null, indexer, loopOp.loopBody().entryBlock(), List.of(), 2))));
278                 }
279                 case OnnxOp onnxOp ->
280                     opNodes.accept(node(
281                             onnxOp.opName(),
282                             opInputNames(indexer, onnxOp.onnxInputs()),
283                             IntStream.range(0, onnxOp.onnxOutputs().size()).mapToObj(o -> indexer.nameOf(onnxOp.result(), o)).toList(),
284                             onnxOp.onnxAttributes()));
285                 case CoreOp.FuncCallOp fco ->
286                     opNodes.accept(node(
287                             domain,
288                             fco.funcName(),
289                             expandTuples(indexer, fco.operands()),
290                             expandTuples(indexer, List.of(fco.result())),
291                             java.util.Map.of()));
292                 case CoreOp.ReturnOp _, CoreOp.ConstantOp _ -> { // skip
293                 }
294                 case CoreOp.TupleLoadOp tlo ->
295                     indexer.mapTupleLoad(tlo.result(), tlo.operands().getFirst(), tlo.index());
296                 case CoreOp.TupleOp to ->
297                     indexer.mapTupleElements(to.result(), to.operands());
298                 case JavaOp.InvokeOp io when io.invokeDescriptor().refType().equals(JavaType.type(List.class)) -> {
299                     if (io.invokeDescriptor().name().equals("get") && io.operands().getLast() instanceof Op.Result or && or.op() instanceof CoreOp.ConstantOp co && co.value() instanceof Integer i) {
300                         indexer.mapTupleLoad(io.result(), io.operands().getFirst(), i);
301                     } else if (io.invokeDescriptor().name().equals("of")) {
302                         indexer.mapTupleElements(io.result(), io.operands());
303                     } else {
304                         throw new UnsupportedOperationException(op.toText());
305                     }
306                 }
307                 default -> {
308                     throw new UnsupportedOperationException(op.toText());
309                 }
310             }
311         }).toList();
312     }
313 
314     static List<ValueInfoProto> tensorInfos(Indexer indexer, List<Block.Parameter> args, int scalarArgs) {
315         var infos = new ArrayList<ValueInfoProto>();
316         for (var arg : args) {
317             switch (arg.type()) {
318                 case OnnxType.TensorType tt ->
319                     infos.add(tensorInfo(indexer.nameOf(arg), tt.eType().id(), infos.size() < scalarArgs));
320                 case TupleType tt -> {
321                     var ct = tt.componentTypes();
322                     for (int i = 0; i < ct.size(); i++) {
323                         infos.add(tensorInfo(indexer.nameOf(arg, i), ((OnnxType.TensorType)ct.get(i)).eType().id(), infos.size() < scalarArgs));
324                     }
325                 }
326                 default ->
327                     throw new UnsupportedOperationException(arg.type().toString());
328             }
329         }
330         return infos;
331     }
332 
333     static GraphProto graph(String name, List<TensorProto> initializers, List<ValueInfoProto> inputs, List<NodeProto> ops, List<String> outputNames) {
334         return new GraphProto()
335                 .name(name)
336                 .forEach(initializers, (g, i) -> g.initializer(i))
337                 .forEach(inputs, (g, i) -> g.input(i))
338                 .forEach(ops, (g, op) -> g.node(op))
339                 .forEach(outputNames, (g, oName) -> g.output(new ValueInfoProto().name(oName)));
340     }
341 
342     static FunctionProto function(String functionDomain, List<String> imports, String functionName, List<String> inputNames, List<String> outputNames, List<NodeProto> ops) {
343         int di = functionName.lastIndexOf('.');
344         return new FunctionProto()
345                 .domain(functionDomain)
346                 .name(functionName)
347                 .forEach(inputNames, (f, i) -> f.input(i))
348                 .forEach(ops, (g, op) -> g.node(op))
349                 .forEach(outputNames, (f, o) -> f.output(o))
350                 .opsetImport(new OperatorSetIdProto().version(OPSET_VERSION))
351                 .forEach(imports, (f, d) -> f.opsetImport(new OperatorSetIdProto().domain(d).version(1)));
352     }
353 
354     static NodeProto node(String domain, String opName, List<String> inputNames, List<String> outputNames, java.util.Map<String, Object> attributes) {
355         return new NodeProto()
356                 .domain(domain)
357                 .opType(opName)
358                 .forEach(inputNames, (n, iName) -> n.input(iName))
359                 .forEach(attributes.entrySet(), (n, ae) -> n.attribute(attribute(ae.getKey(), ae.getValue())))
360                 .forEach(outputNames, (n, oName) -> n.output(oName));
361     }
362 
363     static NodeProto node(String opName, List<String> inputNames, List<String> outputNames, java.util.Map<String, Object> attributes) {
364         int di = opName.lastIndexOf('.');
365         return node(di < 0 ? null : opName.substring(0, di), opName.substring(di + 1), inputNames, outputNames, attributes);
366     }
367 
368     static ValueInfoProto tensorInfo(String name, int tensorElementType) {
369         return tensorInfo(name, tensorElementType, false);
370     }
371 
372     static ValueInfoProto tensorInfo(String name, int tensorElementType, boolean addScalarShape) {
373         var t = new TypeProto.Tensor().elemType(tensorElementType);
374         if (addScalarShape) t.shape(new TensorShapeProto());
375         return new ValueInfoProto()
376                 .name(name)
377                 .type(new TypeProto().tensorType(t));
378     }
379 
380     static TensorProto tensorProto(String name, oracle.code.onnx.Tensor tensor, Function<Tensor, ExternalTensorDataInfo> tensorDataExternalizer) {
381         ExternalTensorDataInfo extInfo = tensorDataExternalizer.apply(tensor);
382         TensorProto tp = new TensorProto()
383                 .name(name)
384                 .dataType(tensor.elementType().id)
385                 .dims(tensor.shape());
386         return extInfo == null
387                 ? tp.rawData(tensor.data().toArray(ValueLayout.JAVA_BYTE))
388                 : tp.externalData(new StringStringEntryProto().key("location").value(extInfo.location()))
389                     .externalData(new StringStringEntryProto().key("offset").value(String.valueOf(extInfo.offset())))
390                     .externalData(new StringStringEntryProto().key("length").value(String.valueOf(extInfo.length())))
391                     .dataLocation(DataLocation.EXTERNAL);
392     }
393 
394     static TensorProto tensorProto(oracle.code.onnx.Tensor tensor) {
395         return new TensorProto()
396                 .dataType(tensor.elementType().id)
397                 .dims(tensor.shape())
398                 .rawData(tensor.data().toArray(ValueLayout.JAVA_BYTE));
399     }
400 
401     static AttributeProto attribute(String name, Object value) {
402         var attr = new AttributeProto().name(name);
403         switch (value) {
404             case Float f -> {
405                 attr.type(AttributeType.FLOAT).f(f);
406             }
407             case Long l -> {
408                 attr.type(AttributeType.INT).i(l);
409             }
410             case GraphProto g -> {
411                 attr.type(AttributeType.GRAPH).g(g.name(name));
412             }
413             case float[] floats -> {
414                 attr.type(AttributeType.FLOATS);
415                 attr.floats(floats);
416             }
417             case long[] longs -> {
418                 attr.type(AttributeType.INTS);
419                 attr.ints(longs);
420             }
421             case Tensor<?> t -> {
422                 attr.type(AttributeType.TENSOR);
423                 attr.t(tensorProto(t));
424             }
425             default -> {
426                 throw new UnsupportedOperationException(value.getClass().toString()); // @@@ ToDo
427             }
428         }
429         return attr;
430     }
431 }