1 /*
  2  * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 
 26 package oracle.code.onnx;
 27 
 28 import java.lang.foreign.ValueLayout;
 29 import java.util.ArrayList;
 30 import java.util.HashMap;
 31 import java.util.List;
 32 import java.util.Map;
 33 import java.util.SequencedMap;
 34 import java.util.function.Function;
 35 import java.util.stream.IntStream;
 36 import jdk.incubator.code.Block;
 37 import jdk.incubator.code.CodeItem;
 38 import jdk.incubator.code.Op;
 39 import jdk.incubator.code.Value;
 40 import jdk.incubator.code.dialect.core.CoreOp;
 41 import jdk.incubator.code.dialect.core.TupleType;
 42 import jdk.incubator.code.dialect.java.JavaOp;
 43 import jdk.incubator.code.dialect.java.JavaType;
 44 import jdk.incubator.code.extern.OpWriter;
 45 import oracle.code.onnx.ir.OnnxOp;
 46 import oracle.code.onnx.ir.OnnxOps;
 47 import oracle.code.onnx.ir.OnnxType;
 48 import oracle.code.onnx.proto.OnnxBuilder.*;
 49 import oracle.code.onnx.proto.OnnxConstants.*;
 50 
 51 public final class OnnxProtoBuilder {
 52 
 53     static final int IR_VERSION = 10;
 54     static final int OPSET_VERSION = 21;
 55 
 56     private static final class Indexer {
 57 
 58         private final Function<CodeItem, String> baseNames;
 59         private final HashMap<String, String> remap;
 60 
 61 
 62         Indexer(Op root, Map<Value, String> explicitNames) {
 63             this.baseNames = OpWriter.computeGlobalNames(root);
 64             this.remap = new HashMap<>();
 65             explicitNames.forEach(this::setName);
 66         }
 67 
 68         void setName(Value val, String name) {
 69             switch (val) {
 70                 case Op.Result or when or.op() instanceof CoreOp.TupleOp to -> {
 71                     remap.put(baseName(val), name);
 72                     for (int i = 0; i < to.operands().size(); i++) {
 73                         setName(to.operands().get(i), name + "." + i);
 74                     }
 75                 }
 76                 case Block.Parameter bp when val.type() instanceof TupleType tt -> {
 77                     for (int i = 0; i < tt.componentTypes().size(); i++) {
 78                         remap.put(baseName(val, i), name +"." + i);
 79                     }
 80                 }
 81                 default -> {
 82                     remap.put(baseName(val), name);
 83                     if (val instanceof Op.Result or && or.op() instanceof CoreOp.TupleLoadOp tlo) {
 84                         Value tr = tlo.operands().getFirst();
 85                         remap.put(baseName(tr, tlo.index()), name);
 86                         if (tr instanceof Op.Result tor && tor.op() instanceof CoreOp.TupleOp to) {
 87                             setName(to.operands().get(tlo.index()), name);
 88                         }
 89                     }
 90                 }
 91             }
 92         }
 93 
 94         private String baseName(Value value) {
 95             return "%" + baseNames.apply(value);
 96         }
 97 
 98         private String baseName(Value value, int elementIndex) {
 99             var name = baseName(value);
100             return elementIndex > 0 ? name + '.' + elementIndex : name;
101         }
102 
103         String nameOf(Value value) {
104             var name = baseName(value);
105             return remap.getOrDefault(name, name);
106         }
107 
108         String nameOf(Value tuple, int elementIndex) {
109             var name = baseName(tuple, elementIndex);
110             return remap.getOrDefault(name, name);
111         }
112 
113         void mapTupleLoad(Value tupleLoadResult, Value tuple, int elementIndex) {
114             remap.putIfAbsent(baseName(tupleLoadResult), nameOf(tuple, elementIndex));
115         }
116 
117         void mapTupleElements(Value tuple, List<Value> elements) {
118             for (int i = 0; i < elements.size(); i++) {
119                 remap.putIfAbsent(baseName(tuple, i), nameOf(elements.get(i)));
120             }
121         }
122     }
123 
124     public static byte[] buildModel(String domain, CoreOp.ModuleOp module, List<Object> initializers) {
125         return buildModel(domain, module, initializers, Map.of(), _ -> null);
126     }
127 
128     public record ExternalTensorDataInfo(String location, long offset, long length) {
129     }
130 
131     public static byte[] buildModel(String domain, CoreOp.ModuleOp module, List<Object> initializers, Map<Value, String> explicitValueNames, Function<Tensor, ExternalTensorDataInfo> tensorDataExternalizer) {
132         var indexer = new Indexer(module, explicitValueNames);
133 
134         var functions = new ArrayList<>(module.functionTable().sequencedValues());
135         var imports = new ArrayList<String>();
136         if (functions.size() > 1) imports.add(domain); // self domain import if additional functions
137         for (var f : functions) {
138             for (var op : f.body().entryBlock().ops()) { // auto import of op domains
139                 if (op instanceof OnnxOp oop) {
140                     String name = oop.schema().name();
141                     int di = name.lastIndexOf('.');
142                     if (di > 0) {
143                         String dn = name.substring(0, di);
144                         if (!imports.contains(dn)) imports.add(dn);
145                     }
146                 }
147             }
148         }
149         var mainFunc = functions.removeLast();
150         var mainBlock = mainFunc.body().entryBlock();
151 
152         var model = buildModel(
153                 graph(domain, mainFunc.funcName(), indexer, mainBlock, initializers, 0, tensorDataExternalizer),
154                 imports,
155                 functions.stream().map(f ->
156                         function(domain, imports, f.funcName(),
157                                  expandTuples(indexer, f.parameters()),
158                                  expandTuples(indexer, f.body().entryBlock().terminatingOp().operands()),
159                                  nodes(domain, indexer, f.body().entryBlock().ops()))).toList());
160 
161 //        System.out.println(OnnxModel.readFrom(model).toText());
162         return model;
163     }
164 
165     // @@@ unchecked constraints:
166     //         tensor FuncOp parameters and single tensor return type
167     //         OnnxOps (with tensor operands and single tensor return value) and ReturnOp (returning single tensor)
168     //         entry block only
169     static byte[] buildModel(Block block, List<oracle.code.onnx.Tensor> initializers) {
170         var indexer = new Indexer(block.ancestorOp(), Map.of());
171         var model = buildModel(graph(null, null, indexer, block, initializers, 0), List.of(), List.of());
172 //        System.out.println(OnnxModel.readFrom(model).toText());
173         return model;
174     }
175 
176     static byte[] buildModel(List<TensorProto> initializers, List<ValueInfoProto> inputs, List<NodeProto> ops, List<String> outputNames) {
177         return buildModel(graph(null, initializers, inputs, ops, outputNames), List.of(), List.of());
178     }
179 
180     static byte[] buildModel(List<TensorProto> initializers, List<ValueInfoProto> inputs, List<NodeProto> ops, List<String> outputNames, List<String> customImportDomains, List<FunctionProto> functions) {
181         return buildModel(graph(null, initializers, inputs, ops, outputNames), customImportDomains, functions);
182     }
183 
184     static byte[] buildModel(GraphProto graph, List<String> imports, List<FunctionProto> functions) {
185         return new ModelProto()
186                 .irVersion(IR_VERSION)
187                 .opsetImport(new OperatorSetIdProto().version(OPSET_VERSION))
188                 .forEach(imports, (m, d) -> m.opsetImport(new OperatorSetIdProto().domain(d).version(1)))
189                 .forEach(functions, (m, f) -> m.functions(f))
190                 .graph(graph)
191                 .getBytes();
192     }
193 
194     static List<String> expandTuples(Indexer indexer, List<? extends Value> values) {
195         var names = new ArrayList<String>();
196         expandTuples(indexer, names, values);
197         return names;
198     }
199 
200     static void expandTuples(Indexer indexer, List<String> names, List<? extends Value> values) {
201         for (var v : values) {
202             if (v instanceof Op.Result or && or.op() instanceof CoreOp.TupleOp op) {
203                 expandTuples(indexer, names, op.operands());
204             } else if (v instanceof Op.Result or && or.op() instanceof CoreOp.TupleLoadOp op) {
205                 names.add(indexer.nameOf(op.operands().getFirst(), op.index()));
206             } else if (v.type() instanceof TupleType tt) {
207                 var ct = tt.componentTypes();
208                 for (int i = 0; i < ct.size(); i++) {
209                     names.add(indexer.nameOf(v, i));
210                 }
211             } else {
212                 names.add(indexer.nameOf(v));
213             }
214         }
215     }
216 
217     static GraphProto graph(String domain, String graphName, Indexer indexer, Block block, List<? extends Object> initializers, int scalarArgs) {
218         return graph(domain, graphName, indexer, block, initializers, scalarArgs, _ -> null);
219     }
220 
221     static GraphProto graph(String domain, String graphName, Indexer indexer, Block block, List<? extends Object> initializers, int scalarArgs, Function<Tensor, ExternalTensorDataInfo> tensorDataExternalizer) {
222         var params = block.parameters();
223         params.forEach(indexer::nameOf);
224         int firstInitializer = params.size() - initializers.size();
225         var args = params.subList(0, firstInitializer);
226         return graph(graphName,
227                 IntStream.range(0, initializers.size()).boxed().<TensorProto>mapMulti((i, tps) -> {
228                     Object val = initializers.get(i);
229                     if (val instanceof Record) {
230                         var rcs = val.getClass().getRecordComponents();
231                         for (int rci = 0; rci < rcs.length; rci++) try {
232                             tps.accept(tensorProto(indexer.nameOf(params.get(i + firstInitializer), rci), (Tensor)(rcs[rci].getAccessor().invoke(val)), tensorDataExternalizer));
233                         } catch (ReflectiveOperationException e) {
234                             throw new IllegalArgumentException(e);
235                         }
236                     } else if (val instanceof Tensor[] tarr) {
237                         for (int tai = 0; tai < tarr.length; tai++) {
238                             tps.accept(tensorProto(indexer.nameOf(params.get(i + firstInitializer), tai), tarr[tai], tensorDataExternalizer));
239                         }
240                     } else {
241                         tps.accept(tensorProto(indexer.nameOf(params.get(i + firstInitializer)), (Tensor)val, tensorDataExternalizer));
242                     }
243                 }).toList(),
244                 tensorInfos(indexer, args, scalarArgs),
245                 nodes(domain, indexer, block.ops()),
246                 expandTuples(indexer, block.terminatingOp().operands()));
247     }
248 
249     static List<String> opInputNames(Indexer indexer, SequencedMap<OnnxOp.OnnxParameter, Object> inputs) {
250         List<String> inputNames = inputs.sequencedValues().stream()
251                 .<String>mapMulti((v, dump) -> {
252                     switch (v) {
253                         case Value val -> dump.accept(indexer.nameOf(val));
254                         case java.util.Optional<?> o when o.isPresent() && o.get() instanceof Value val -> dump.accept(indexer.nameOf(val));
255                         case List l -> l.forEach(val -> dump.accept(indexer.nameOf((Value)val)));
256                         default -> dump.accept(""); // empty names for unused optional inputs
257                     }
258                 }).toList();
259         // trim trailing empty names
260         return inputNames.reversed().stream().dropWhile(String::isEmpty).toList().reversed();
261     }
262 
263     static List<NodeProto> nodes(String domain, Indexer indexer, List<Op> ops) {
264         return ops.stream().<NodeProto>mapMulti((op, opNodes) -> {
265             switch (op) {
266                 case OnnxOps.If ifOp ->
267                     opNodes.accept(node(
268                             ifOp.schema().name(),
269                             List.of(indexer.nameOf(ifOp.operands().getFirst())),
270                             IntStream.range(0, ifOp.resultType() instanceof TupleType tt ? tt.componentTypes().size() : 1).mapToObj(o -> indexer.nameOf(ifOp.result(), o)).toList(),
271                             java.util.Map.of(
272                                     "then_branch", graph(domain, null, indexer, ifOp.thenBranch().entryBlock(), List.of(), 0),
273                                     "else_branch", graph(domain, null, indexer, ifOp.elseBranch().entryBlock(), List.of(), 0))));
274                 case OnnxOps.Loop loopOp -> {
275                     opNodes.accept(node(loopOp.schema().name(),
276                             expandTuples(indexer, loopOp.operands()),
277                             IntStream.range(0, loopOp.resultType() instanceof TupleType tt ? tt.componentTypes().size() : 1).mapToObj(o -> indexer.nameOf(loopOp.result(), o)).toList(),
278                             java.util.Map.of(
279                                     "body", graph(domain, null, indexer, loopOp.loopBody().entryBlock(), List.of(), 2))));
280                 }
281                 case OnnxOp onnxOp ->
282                     opNodes.accept(node(
283                             onnxOp.schema().name(),
284                             opInputNames(indexer, onnxOp.onnxInputs()),
285                             IntStream.range(0, onnxOp.onnxOutputs().size()).mapToObj(o -> indexer.nameOf(onnxOp.result(), o)).toList(),
286                             onnxOp.onnxAttributes()));
287                 case CoreOp.FuncCallOp fco ->
288                     opNodes.accept(node(
289                             domain,
290                             fco.funcName(),
291                             expandTuples(indexer, fco.operands()),
292                             expandTuples(indexer, List.of(fco.result())),
293                             java.util.Map.of()));
294                 case CoreOp.ReturnOp _, CoreOp.ConstantOp _ -> { // skip
295                 }
296                 case CoreOp.TupleLoadOp tlo ->
297                     indexer.mapTupleLoad(tlo.result(), tlo.operands().getFirst(), tlo.index());
298                 case CoreOp.TupleOp to ->
299                     indexer.mapTupleElements(to.result(), to.operands());
300                 case JavaOp.InvokeOp io when io.invokeDescriptor().refType().equals(JavaType.type(List.class)) -> {
301                     if (io.invokeDescriptor().name().equals("get") && io.operands().getLast() instanceof Op.Result or && or.op() instanceof CoreOp.ConstantOp co && co.value() instanceof Integer i) {
302                         indexer.mapTupleLoad(io.result(), io.operands().getFirst(), i);
303                     } else if (io.invokeDescriptor().name().equals("of")) {
304                         indexer.mapTupleElements(io.result(), io.operands());
305                     } else {
306                         throw new UnsupportedOperationException(op.toText());
307                     }
308                 }
309                 default -> {
310                     throw new UnsupportedOperationException(op.toText());
311                 }
312             }
313         }).toList();
314     }
315 
316     static List<ValueInfoProto> tensorInfos(Indexer indexer, List<Block.Parameter> args, int scalarArgs) {
317         var infos = new ArrayList<ValueInfoProto>();
318         for (var arg : args) {
319             switch (arg.type()) {
320                 case OnnxType.TensorType tt ->
321                     infos.add(tensorInfo(indexer.nameOf(arg), tt.eType().id(), infos.size() < scalarArgs));
322                 case TupleType tt -> {
323                     var ct = tt.componentTypes();
324                     for (int i = 0; i < ct.size(); i++) {
325                         infos.add(tensorInfo(indexer.nameOf(arg, i), ((OnnxType.TensorType)ct.get(i)).eType().id(), infos.size() < scalarArgs));
326                     }
327                 }
328                 default ->
329                     throw new UnsupportedOperationException(arg.type().toString());
330             }
331         }
332         return infos;
333     }
334 
335     static GraphProto graph(String name, List<TensorProto> initializers, List<ValueInfoProto> inputs, List<NodeProto> ops, List<String> outputNames) {
336         return new GraphProto()
337                 .name(name)
338                 .forEach(initializers, (g, i) -> g.initializer(i))
339                 .forEach(inputs, (g, i) -> g.input(i))
340                 .forEach(ops, (g, op) -> g.node(op))
341                 .forEach(outputNames, (g, oName) -> g.output(new ValueInfoProto().name(oName)));
342     }
343 
344     static FunctionProto function(String functionDomain, List<String> imports, String functionName, List<String> inputNames, List<String> outputNames, List<NodeProto> ops) {
345         int di = functionName.lastIndexOf('.');
346         return new FunctionProto()
347                 .domain(functionDomain)
348                 .name(functionName)
349                 .forEach(inputNames, (f, i) -> f.input(i))
350                 .forEach(ops, (g, op) -> g.node(op))
351                 .forEach(outputNames, (f, o) -> f.output(o))
352                 .opsetImport(new OperatorSetIdProto().version(OPSET_VERSION))
353                 .forEach(imports, (f, d) -> f.opsetImport(new OperatorSetIdProto().domain(d).version(1)));
354     }
355 
356     static NodeProto node(String domain, String opName, List<String> inputNames, List<String> outputNames, java.util.Map<String, Object> attributes) {
357         return new NodeProto()
358                 .domain(domain)
359                 .opType(opName)
360                 .forEach(inputNames, (n, iName) -> n.input(iName))
361                 .forEach(attributes.entrySet(), (n, ae) -> n.attribute(attribute(ae.getKey(), ae.getValue())))
362                 .forEach(outputNames, (n, oName) -> n.output(oName));
363     }
364 
365     static NodeProto node(String opName, List<String> inputNames, List<String> outputNames, java.util.Map<String, Object> attributes) {
366         int di = opName.lastIndexOf('.');
367         return node(di < 0 ? null : opName.substring(0, di), opName.substring(di + 1), inputNames, outputNames, attributes);
368     }
369 
370     static ValueInfoProto tensorInfo(String name, int tensorElementType) {
371         return tensorInfo(name, tensorElementType, false);
372     }
373 
374     static ValueInfoProto tensorInfo(String name, int tensorElementType, boolean addScalarShape) {
375         var t = new TypeProto.Tensor().elemType(tensorElementType);
376         if (addScalarShape) t.shape(new TensorShapeProto());
377         return new ValueInfoProto()
378                 .name(name)
379                 .type(new TypeProto().tensorType(t));
380     }
381 
382     static TensorProto tensorProto(String name, oracle.code.onnx.Tensor tensor, Function<Tensor, ExternalTensorDataInfo> tensorDataExternalizer) {
383         ExternalTensorDataInfo extInfo = tensorDataExternalizer.apply(tensor);
384         TensorProto tp = new TensorProto()
385                 .name(name)
386                 .dataType(tensor.elementType().id)
387                 .dims(tensor.shape());
388         return extInfo == null
389                 ? tp.rawData(tensor.data().toArray(ValueLayout.JAVA_BYTE))
390                 : tp.externalData(new StringStringEntryProto().key("location").value(extInfo.location()))
391                     .externalData(new StringStringEntryProto().key("offset").value(String.valueOf(extInfo.offset())))
392                     .externalData(new StringStringEntryProto().key("length").value(String.valueOf(extInfo.length())))
393                     .dataLocation(DataLocation.EXTERNAL);
394     }
395 
396     static TensorProto tensorProto(oracle.code.onnx.Tensor tensor) {
397         return new TensorProto()
398                 .dataType(tensor.elementType().id)
399                 .dims(tensor.shape())
400                 .rawData(tensor.data().toArray(ValueLayout.JAVA_BYTE));
401     }
402 
403     static AttributeProto attribute(String name, Object value) {
404         var attr = new AttributeProto().name(name);
405         switch (value) {
406             case Float f -> {
407                 attr.type(AttributeType.FLOAT).f(f);
408             }
409             case Long l -> {
410                 attr.type(AttributeType.INT).i(l);
411             }
412             case GraphProto g -> {
413                 attr.type(AttributeType.GRAPH).g(g.name(name));
414             }
415             case float[] floats -> {
416                 attr.type(AttributeType.FLOATS);
417                 attr.floats(floats);
418             }
419             case long[] longs -> {
420                 attr.type(AttributeType.INTS);
421                 attr.ints(longs);
422             }
423             case String s -> {
424                 attr.type(AttributeType.STRING);
425                 attr.s(s.getBytes());
426             }
427             case Tensor<?> t -> {
428                 attr.type(AttributeType.TENSOR);
429                 attr.t(tensorProto(t));
430             }
431             default -> {
432                 throw new UnsupportedOperationException(value.getClass().toString()); // @@@ ToDo
433             }
434         }
435         return attr;
436     }
437 }