1 /*
2 * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation.
8 *
9 * This code is distributed in the hope that it will be useful, but WITHOUT
10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
12 * version 2 for more details (a copy is included in the LICENSE file that
13 * accompanied this code).
14 *
15 * You should have received a copy of the GNU General Public License version
16 * 2 along with this work; if not, write to the Free Software Foundation,
17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18 *
19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20 * or visit www.oracle.com if you need additional information or have any
21 * questions.
22 */
23
24 package oracle.code.onnx.lift;
25
26 import java.lang.reflect.Method;
27 import java.lang.foreign.ValueLayout;
28 import java.util.List;
29 import java.util.Map;
30 import java.util.Optional;
31 import java.util.SequencedMap;
32 import java.util.function.Function;
33 import java.util.stream.Collectors;
34 import java.util.stream.IntStream;
35 import java.util.stream.LongStream;
36 import jdk.incubator.code.Block;
37 import jdk.incubator.code.CodeItem;
38 import jdk.incubator.code.Op;
39 import jdk.incubator.code.TypeElement;
40 import jdk.incubator.code.Value;
41 import jdk.incubator.code.dialect.core.CoreOp;
42 import jdk.incubator.code.dialect.java.JavaType;
43 import oracle.code.onnx.OnnxOperators;
44 import oracle.code.onnx.Tensor;
45 import oracle.code.onnx.ir.OnnxOp;
46 import oracle.code.onnx.ir.OnnxType;
47 import oracle.code.onnx.proto.OnnxModel;
48
49 final class JavaTemplate {
50
51 private static final String WEIGHT_FIELD_TEMPLATE = """
52 final %s %s = load("%s", %dL, %d, %s%s);
53 """;
54
55 private static final String TEMPLATE = """
56 import java.io.IOException;
57 import java.io.RandomAccessFile;
58 import java.lang.foreign.Arena;
59 import java.lang.foreign.MemorySegment;
60 import java.nio.channels.FileChannel;
61 import java.util.List;
62 import jdk.incubator.code.Reflect;
63 import oracle.code.onnx.Tensor;
64
65 import static java.util.Optional.*;
66 import static oracle.code.onnx.OnnxOperators.*;
67 import static oracle.code.onnx.Tensor.ElementType.*;
68
69 public class %s {
70
71 final Arena arena = Arena.ofAuto();
72
73 MemorySegment mmap(String pathname, long offset, long length) {
74 try (var f = new RandomAccessFile(pathname, "r")) {
75 return f.getChannel().map(FileChannel.MapMode.READ_ONLY, offset, length < 0 ? f.length() : length, arena);
76 } catch (IOException e) {
77 throw new RuntimeException(e);
78 }
79 }
80
81 <T> Tensor<T> load(String path, long offset, long length, Tensor.ElementType type, long... shape) {
82 return new Tensor<>(arena, mmap(path, offset, length), type, shape);
83 }
84
85 %s
86 @Reflect
87 public Object mainGraph(
88 %s) {%s
89 }
90 }
91 """;
92
93 static String toJava(OnnxLift.LiftedModelWrapper model, String className) {
94 Block entryBlock = model.func().bodies().getFirst().entryBlock();
95 List<Block.Parameter> parameters = entryBlock.parameters();
96 Function<CodeItem, String> namer = OnnxLift.namer(model.names());
97 parameters.forEach(namer::apply); // initialize namer with all parameters first
98
99 return TEMPLATE.formatted(
100 className,
101 weightFields(namer, parameters, model.weights()),
102 parameters(namer, parameters, model.weights()),
103 body(namer, entryBlock.ops()));
104 }
105
106 private static String weightFields(Function<CodeItem, String> namer, List<Block.Parameter> parameters, List<OnnxModel.TensorProto> weights) {
107 StringBuilder out = new StringBuilder();
108 List<jdk.incubator.code.Block.Parameter> weightParams = parameters.subList(parameters.size() - weights.size(), parameters.size());
109 Map<String, oracle.code.onnx.proto.OnnxModel.TensorProto> wMap = weights.stream().collect(Collectors.toUnmodifiableMap(OnnxModel.TensorProto::name, Function.identity()));
110 for (int i = 0; i < weightParams.size(); i++) {
111 Block.Parameter wp = weightParams.get(i);
112 OnnxModel.TensorProto w = wMap.get(namer.apply(wp));
113 String name = OnnxLift.toJavaName(w.name());
114 long[] dims = OnnxLift.joinLongArray(w.dims());
115 String location;
116 long offset;
117 int length;
118 if (w.externalData() instanceof List<OnnxModel.StringStringEntryProto> ssep) {
119 var map = ssep.stream().collect(Collectors.toUnmodifiableMap(OnnxModel.StringStringEntryProto::key, OnnxModel.StringStringEntryProto::value));
120 location = map.get("location");
121 offset = Long.parseLong(map.get("offset"));
122 length = Integer.parseInt(map.get("length"));
123 } else {
124 location = name;
125 offset = 0;
126 length = -1;
127 }
128 out.append(WEIGHT_FIELD_TEMPLATE.formatted(toJavaType(wp.type()), name, location, offset, length, Tensor.ElementType.fromOnnxId(w.dataType()).name(), dims.length > 0 ? (", " + longJoin(dims)) : ""));
129 }
130 return out.toString();
131 }
132
133 private static String parameters(Function<CodeItem, String> namer, List<Block.Parameter> parameters, List<OnnxModel.TensorProto> weights) {
134 StringBuilder out = new StringBuilder();
135 int realParamsSize = parameters.size() - weights.size();
136 for (int i = 0; i < realParamsSize; i++) {
137 if (i > 0) {
138 out.append(",\n");
139 }
140 Block.Parameter param = parameters.get(i);
141 out.append(" ").append(toJavaType(param.type())).append(' ').append(OnnxLift.toJavaName(namer.apply(param)));
142 }
143 return out.toString();
144 }
145
146 private static String body(Function<CodeItem, String> namer, List<Op> ops) {
147 StringBuilder out = new StringBuilder();
148 for (jdk.incubator.code.Op op : ops) {
149 if (!(op instanceof CoreOp.TupleLoadOp)) {
150 // lazy tupple loads
151 out.append("\n ");
152 if (!op.resultType().equals(JavaType.VOID)) {
153 out.append(toJavaType(op.resultType())).append(' ').append(OnnxLift.toJavaName(namer.apply(op.result()))).append(" = ");
154 }
155 switch (op) {
156 case OnnxOp oo -> {
157 String opName = op.externalizeOpName();
158 out.append(opName.substring(opName.lastIndexOf('.') + 1)).append('(');
159 OnnxOp.OnnxSchema schema = getSchema(oo);
160 SequencedMap<OnnxOp.OnnxParameter, Object> inputs = oo.onnxInputs();
161 boolean first = true;
162 for (OnnxOp.OnnxParameter oi : schema.inputs()) {
163 if (first) {
164 first = false;
165 } else {
166 out.append(", ");
167 }
168 out.append(toJava(namer, inputs.get(oi)));
169 }
170 Map<String, Object> attrs = oo.onnxAttributes();
171 for (OnnxOp.OnnxAttribute oa : schema.attributes()) {
172 if (first) {
173 first = false;
174 } else {
175 out.append(", ");
176 }
177 Object a = attrs.get(oa.name());
178 if (a == null) {
179 out.append("empty()");
180 } else if (oa.isOptional()) {
181 out.append("of(").append(toString(a)).append(')');
182 } else {
183 out.append(toString(a));
184 }
185 }
186 out.append(");");
187 }
188 case CoreOp.TupleOp to -> {
189 out.append("List.of(");
190 boolean first = true;
191 for (jdk.incubator.code.Value te : to.operands()) {
192 if (first) {
193 first = false;
194 } else {
195 out.append(", ");
196 }
197 out.append(toJava(namer, te));
198 }
199 out.append(");");
200 }
201 case CoreOp.ReturnOp ro -> {
202 out.append("return ").append(toJava(namer, ro.operands().getFirst())).append(';');
203 }
204 default -> throw new UnsupportedOperationException(op.toText());
205 }
206 } else {
207 namer.apply(op.result());
208 }
209 }
210 return out.toString();
211 }
212
213 private static String toString(Object o) {
214 return switch (o) {
215 case long[] la -> newArray(la);
216 case float[] fa -> newArray(fa);
217 case Long l -> l.toString() + "L";
218 case Float f -> f.toString() + "F";
219 case String s -> "\"" + s + "\"";
220 case Tensor t -> "Tensor.ofShape(" + newArray(t.shape()) + ", " + toString(getData(t)) + ")";
221 default -> o.toString();
222 };
223 }
224
225 private static Object getData(Tensor t) {
226 return switch (t.elementType()) {
227 case FLOAT -> t.data().toArray(ValueLayout.JAVA_FLOAT);
228 case INT64 -> t.data().toArray(ValueLayout.JAVA_LONG);
229 default -> throw new UnsupportedOperationException(t.elementType().name() + " tensor type");
230 };
231 }
232
233 private static String newArray(long[] la) {
234 for (long l : la) {
235 if (l != 0l) {
236 return "new long[] {" + longJoin(la) + "}";
237 }
238 }
239 return "new long[" + la.length + "]";
240 }
241
242 private static String longJoin(long[] la) {
243 return LongStream.of(la).mapToObj(d -> String.valueOf(d) + "L").collect(Collectors.joining(", "));
244 }
245
246 private static String newArray(float[] fa) {
247 for (float f : fa) {
248 if (f != 0f) {
249 return IntStream.range(0, fa.length).mapToObj(i -> String.valueOf(fa[i]) + "F").collect(Collectors.joining(", ", "new float[] {", "}"));
250 }
251 }
252 return "new float[" + fa.length + "]";
253 }
254
255 private static String tupleAccessor(Value tuple, int componentIndex) {
256 if (tuple instanceof Op.Result or && or.op() instanceof OnnxOp oo) {
257 String mName = oo.externalizeOpName();
258 mName = mName.substring(mName.lastIndexOf('.') + 1);
259 for (Method m : OnnxOperators.class.getMethods()) {
260 if (m.getName().equals(mName)) {
261 return m.getReturnType().getRecordComponents()[componentIndex].getAccessor().getName() + "()";
262 }
263 }
264 throw new IllegalStateException(mName);
265 }
266 return "get(" + componentIndex + ")"; // fallback to List
267 }
268
269 private static String toJava(Function<CodeItem, String> namer, Object value) {
270 return switch (value) {
271 case Optional o when o.isEmpty() -> "empty()";
272 case Optional o -> "of(" + toJava(namer, o.get()) + ")";
273 case List l -> "List.of(" + l.stream().map(le -> toJava(namer, le)).collect(Collectors.joining(", ")) + ")";
274 case Op.Result or when or.op() instanceof CoreOp.TupleLoadOp tlo -> OnnxLift.toJavaName(namer.apply(tlo.operands().getFirst())) + '.' + tupleAccessor(tlo.operands().getFirst(), tlo.index());
275 case Value v -> OnnxLift.toJavaName(namer.apply(v));
276 default -> throw new UnsupportedOperationException(value.toString());
277 };
278 }
279
280 private static OnnxOp.OnnxSchema getSchema(OnnxOp oo) {
281 try {
282 return (OnnxOp.OnnxSchema) oo.getClass().getDeclaredField("SCHEMA").get(null);
283 } catch (ReflectiveOperationException ex) {
284 throw new RuntimeException(ex);
285 }
286 }
287
288 private static String toJavaType(TypeElement t) {
289 return switch (t) {
290 case OnnxType.TensorType tt ->
291 "Tensor<" + switch (tt.eType()) {
292 case OnnxType.Float32Type _ -> "Float";
293 case OnnxType.Int64Type _ -> "Long";
294 case OnnxType.Int32Type _ -> "Integer";
295 case OnnxType.UInt8Type _ -> "Byte";
296 case OnnxType.BoolType _ -> "Boolean";
297 case OnnxType.StringType _ -> "String";
298 default -> throw new UnsupportedOperationException(t.toString());
299 } + ">";
300 default -> "var";
301 };
302 }
303 }