1 /*
  2  * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 package oracle.code.onnx.lift;
 25 
 26 import java.lang.reflect.Method;
 27 import java.lang.foreign.ValueLayout;
 28 import java.util.List;
 29 import java.util.Map;
 30 import java.util.Optional;
 31 import java.util.SequencedMap;
 32 import java.util.function.Function;
 33 import java.util.stream.Collectors;
 34 import java.util.stream.IntStream;
 35 import java.util.stream.LongStream;
 36 import jdk.incubator.code.Block;
 37 import jdk.incubator.code.CodeItem;
 38 import jdk.incubator.code.Op;
 39 import jdk.incubator.code.TypeElement;
 40 import jdk.incubator.code.Value;
 41 import jdk.incubator.code.dialect.core.CoreOp;
 42 import jdk.incubator.code.dialect.java.JavaType;
 43 import oracle.code.onnx.OnnxOperators;
 44 import oracle.code.onnx.Tensor;
 45 import oracle.code.onnx.ir.OnnxOp;
 46 import oracle.code.onnx.ir.OnnxType;
 47 import oracle.code.onnx.proto.OnnxModel;
 48 
 49 final class JavaTemplate {
 50 
 51     private static final String WEIGHT_FIELD_TEMPLATE = """
 52             final %s %s = load("%s", %dL, %d, %s%s);
 53         """;
 54 
 55     private static final String TEMPLATE = """
 56         import java.io.IOException;
 57         import java.io.RandomAccessFile;
 58         import java.lang.foreign.Arena;
 59         import java.lang.foreign.MemorySegment;
 60         import java.nio.channels.FileChannel;
 61         import java.util.List;
 62         import jdk.incubator.code.Reflect;
 63         import oracle.code.onnx.Tensor;
 64 
 65         import static java.util.Optional.*;
 66         import static oracle.code.onnx.OnnxOperators.*;
 67         import static oracle.code.onnx.Tensor.ElementType.*;
 68 
 69         public class %s {
 70 
 71             final Arena arena = Arena.ofAuto();
 72 
 73             MemorySegment mmap(String pathname, long offset, long length) {
 74                 try (var f = new RandomAccessFile(pathname, "r")) {
 75                     return f.getChannel().map(FileChannel.MapMode.READ_ONLY, offset, length < 0 ? f.length() : length, arena);
 76                 } catch (IOException e) {
 77                     throw new RuntimeException(e);
 78                 }
 79             }
 80 
 81             <T> Tensor<T> load(String path, long offset, long length, Tensor.ElementType type, long... shape) {
 82                 return new Tensor<>(arena, mmap(path, offset, length), type, shape);
 83             }
 84 
 85         %s
 86             @Reflect
 87             public Object mainGraph(
 88         %s) {%s
 89             }
 90         }
 91         """;
 92 
 93     static String toJava(OnnxLift.LiftedModelWrapper model, String className) {
 94         Block entryBlock = model.func().bodies().getFirst().entryBlock();
 95         List<Block.Parameter> parameters = entryBlock.parameters();
 96         Function<CodeItem, String> namer = OnnxLift.namer(model.names());
 97         parameters.forEach(namer::apply); // initialize namer with all parameters first
 98 
 99         return TEMPLATE.formatted(
100                 className,
101                 weightFields(namer, parameters, model.weights()),
102                 parameters(namer, parameters, model.weights()),
103                 body(namer, entryBlock.ops()));
104     }
105 
106     private static String weightFields(Function<CodeItem, String> namer, List<Block.Parameter> parameters, List<OnnxModel.TensorProto> weights) {
107         StringBuilder out = new StringBuilder();
108         List<jdk.incubator.code.Block.Parameter> weightParams = parameters.subList(parameters.size() - weights.size(), parameters.size());
109         Map<String, oracle.code.onnx.proto.OnnxModel.TensorProto> wMap = weights.stream().collect(Collectors.toUnmodifiableMap(OnnxModel.TensorProto::name, Function.identity()));
110         for (int i = 0; i < weightParams.size(); i++) {
111             Block.Parameter wp = weightParams.get(i);
112             OnnxModel.TensorProto w = wMap.get(namer.apply(wp));
113             String name = OnnxLift.toJavaName(w.name());
114             long[] dims = OnnxLift.joinLongArray(w.dims());
115             String location;
116             long offset;
117             int length;
118             if (w.externalData() instanceof List<OnnxModel.StringStringEntryProto> ssep) {
119                 var map = ssep.stream().collect(Collectors.toUnmodifiableMap(OnnxModel.StringStringEntryProto::key, OnnxModel.StringStringEntryProto::value));
120                 location = map.get("location");
121                 offset = Long.parseLong(map.get("offset"));
122                 length = Integer.parseInt(map.get("length"));
123             } else {
124                 location = name;
125                 offset = 0;
126                 length = -1;
127             }
128             out.append(WEIGHT_FIELD_TEMPLATE.formatted(toJavaType(wp.type()), name, location, offset, length, Tensor.ElementType.fromOnnxId(w.dataType()).name(), dims.length > 0 ? (", " + longJoin(dims)) : ""));
129         }
130         return out.toString();
131     }
132 
133     private static String parameters(Function<CodeItem, String> namer, List<Block.Parameter> parameters, List<OnnxModel.TensorProto> weights) {
134         StringBuilder out = new StringBuilder();
135         int realParamsSize = parameters.size() - weights.size();
136         for (int i = 0; i < realParamsSize; i++) {
137             if (i > 0) {
138                 out.append(",\n");
139             }
140             Block.Parameter param = parameters.get(i);
141             out.append("            ").append(toJavaType(param.type())).append(' ').append(OnnxLift.toJavaName(namer.apply(param)));
142         }
143         return out.toString();
144     }
145 
146     private static String body(Function<CodeItem, String> namer, List<Op> ops) {
147         StringBuilder out = new StringBuilder();
148         for (jdk.incubator.code.Op op : ops) {
149             if (!(op instanceof CoreOp.TupleLoadOp)) {
150                 // lazy tupple loads
151                 out.append("\n        ");
152                 if (!op.resultType().equals(JavaType.VOID)) {
153                     out.append(toJavaType(op.resultType())).append(' ').append(OnnxLift.toJavaName(namer.apply(op.result()))).append(" = ");
154                 }
155                 switch (op) {
156                     case OnnxOp oo -> {
157                         String opName = op.externalizeOpName();
158                         out.append(opName.substring(opName.lastIndexOf('.') + 1)).append('(');
159                         OnnxOp.OnnxSchema schema = getSchema(oo);
160                         SequencedMap<OnnxOp.OnnxParameter, Object> inputs = oo.onnxInputs();
161                         boolean first = true;
162                         for (OnnxOp.OnnxParameter oi : schema.inputs()) {
163                             if (first) {
164                                 first = false;
165                             } else {
166                                 out.append(", ");
167                             }
168                             out.append(toJava(namer, inputs.get(oi)));
169                         }
170                         Map<String, Object> attrs = oo.onnxAttributes();
171                         for (OnnxOp.OnnxAttribute oa : schema.attributes()) {
172                             if (first) {
173                                 first = false;
174                             } else {
175                                 out.append(", ");
176                             }
177                             Object a = attrs.get(oa.name());
178                             if (a == null) {
179                                 out.append("empty()");
180                             } else if (oa.isOptional()) {
181                                 out.append("of(").append(toString(a)).append(')');
182                             } else {
183                                 out.append(toString(a));
184                             }
185                         }
186                         out.append(");");
187                     }
188                     case CoreOp.TupleOp to -> {
189                         out.append("List.of(");
190                         boolean first = true;
191                         for (jdk.incubator.code.Value te : to.operands()) {
192                             if (first) {
193                                 first = false;
194                             } else {
195                                 out.append(", ");
196                             }
197                             out.append(toJava(namer, te));
198                         }
199                         out.append(");");
200                     }
201                     case CoreOp.ReturnOp ro -> {
202                         out.append("return ").append(toJava(namer, ro.operands().getFirst())).append(';');
203                     }
204                     default -> throw new UnsupportedOperationException(op.toText());
205                 }
206             } else {
207                 namer.apply(op.result());
208             }
209         }
210         return out.toString();
211     }
212 
213     private static String toString(Object o) {
214         return switch (o) {
215             case long[] la -> newArray(la);
216             case float[] fa -> newArray(fa);
217             case Long l -> l.toString() + "L";
218             case Float f -> f.toString() + "F";
219             case String s -> "\"" + s + "\"";
220             case Tensor t -> "Tensor.ofShape(" + newArray(t.shape()) + ", " + toString(getData(t)) + ")";
221             default -> o.toString();
222         };
223     }
224 
225     private static Object getData(Tensor t) {
226         return switch (t.elementType()) {
227             case FLOAT -> t.data().toArray(ValueLayout.JAVA_FLOAT);
228             case INT64 -> t.data().toArray(ValueLayout.JAVA_LONG);
229             default -> throw new UnsupportedOperationException(t.elementType().name() + " tensor type");
230         };
231     }
232 
233     private static String newArray(long[] la) {
234         for (long l : la) {
235             if (l != 0l) {
236                 return "new long[] {" + longJoin(la) + "}";
237             }
238         }
239         return "new long[" + la.length + "]";
240     }
241 
242     private static String longJoin(long[] la) {
243         return LongStream.of(la).mapToObj(d -> String.valueOf(d) + "L").collect(Collectors.joining(", "));
244     }
245 
246     private static String newArray(float[] fa) {
247         for (float f : fa) {
248             if (f != 0f) {
249                 return IntStream.range(0, fa.length).mapToObj(i -> String.valueOf(fa[i]) + "F").collect(Collectors.joining(", ", "new float[] {", "}"));
250             }
251         }
252         return "new float[" + fa.length + "]";
253     }
254 
255     private static String tupleAccessor(Value tuple, int componentIndex) {
256         if (tuple instanceof Op.Result or && or.op() instanceof OnnxOp oo) {
257             String mName = oo.externalizeOpName();
258             mName = mName.substring(mName.lastIndexOf('.') + 1);
259             for (Method m : OnnxOperators.class.getMethods()) {
260                 if (m.getName().equals(mName)) {
261                     return m.getReturnType().getRecordComponents()[componentIndex].getAccessor().getName() + "()";
262                 }
263             }
264             throw new IllegalStateException(mName);
265         }
266         return "get(" + componentIndex + ")"; // fallback to List
267     }
268 
269     private static String toJava(Function<CodeItem, String> namer, Object value) {
270         return switch (value) {
271             case Optional o when o.isEmpty() -> "empty()";
272             case Optional o -> "of(" + toJava(namer, o.get()) + ")";
273             case List l -> "List.of(" + l.stream().map(le -> toJava(namer, le)).collect(Collectors.joining(", ")) + ")";
274             case Op.Result or when or.op() instanceof CoreOp.TupleLoadOp tlo -> OnnxLift.toJavaName(namer.apply(tlo.operands().getFirst())) + '.' + tupleAccessor(tlo.operands().getFirst(), tlo.index());
275             case Value v -> OnnxLift.toJavaName(namer.apply(v));
276             default -> throw new UnsupportedOperationException(value.toString());
277         };
278     }
279 
280     private static OnnxOp.OnnxSchema getSchema(OnnxOp oo) {
281         try {
282             return (OnnxOp.OnnxSchema) oo.getClass().getDeclaredField("SCHEMA").get(null);
283         } catch (ReflectiveOperationException ex) {
284             throw new RuntimeException(ex);
285         }
286     }
287 
288     private static String toJavaType(TypeElement t) {
289         return switch (t) {
290             case OnnxType.TensorType tt ->
291                 "Tensor<" + switch (tt.eType()) {
292                     case OnnxType.Float32Type _ -> "Float";
293                     case OnnxType.Int64Type _ -> "Long";
294                     case OnnxType.Int32Type _ -> "Integer";
295                     case OnnxType.UInt8Type _ -> "Byte";
296                     case OnnxType.BoolType _ -> "Boolean";
297                     case OnnxType.StringType _ -> "String";
298                     default -> throw new UnsupportedOperationException(t.toString());
299                 } + ">";
300             default -> "var";
301         };
302     }
303 }