1 /*
2 * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25
26 package oracle.code.onnx;
27
28 import java.io.File;
29 import java.io.IOException;
30 import java.lang.foreign.*;
31 import java.lang.invoke.MethodHandles;
32 import java.lang.invoke.VarHandle;
33 import java.lang.reflect.AccessFlag;
34 import java.lang.reflect.Field;
35 import java.lang.reflect.ParameterizedType;
36 import java.nio.file.Files;
37 import java.nio.file.Path;
38 import java.nio.file.StandardCopyOption;
39 import java.util.*;
40 import java.util.function.Consumer;
41 import java.util.function.Supplier;
42 import java.util.stream.IntStream;
43 import java.util.stream.Stream;
44
45 import jdk.incubator.code.*;
46
47 import jdk.incubator.code.dialect.core.CoreOp;
48 import jdk.incubator.code.dialect.core.TupleType;
49 import jdk.incubator.code.dialect.java.ArrayType;
50 import jdk.incubator.code.dialect.java.ClassType;
51 import jdk.incubator.code.dialect.java.FieldRef;
52 import jdk.incubator.code.dialect.java.JavaOp;
53 import jdk.incubator.code.dialect.java.JavaType;
54 import oracle.code.onnx.compiler.OnnxTransformer;
55 import oracle.code.onnx.foreign.OrtApi;
56 import oracle.code.onnx.foreign.OrtApiBase;
57
58 import static oracle.code.onnx.foreign.onnxruntime_c_api_h.*;
59
60 public final class OnnxRuntime {
61
62 static final boolean DEBUG = Boolean.getBoolean("oracle.code.onnx.OnnxRuntime.DEBUG");
63 static final System.Logger LOG = System.getLogger("oracle.code.onnx");
64 static final JavaType TENSOR_RAW_TYPE = JavaType.type(Tensor.class);
65 static final JavaType LIST_RAW_TYPE = JavaType.type(List.class);
66 private static final CachedSessionClassValue SESSION_CACHE = new CachedSessionClassValue();
67 private static final String LOG_ID = "onnx-ffm-java";
68 private static OnnxRuntime INSTANCE;
69
70 static {
71 try {
72 System.loadLibrary("onnxruntime");
73 } catch (UnsatisfiedLinkError _) {
74 // fallback to extract the library from onnxruntime dependency jar
75 String arch = System.getProperty("os.arch", "generic").toLowerCase(Locale.ENGLISH).startsWith("aarch64") ? "aarch64" : "x64";
76 String os = System.getProperty("os.name", "generic").toLowerCase(Locale.ENGLISH);
77 String libResource;
78 if (os.contains("mac") || os.contains("darwin")) {
79 libResource = "/ai/onnxruntime/native/osx-" + arch + "/libonnxruntime.dylib";
80 } else if (os.contains("win")) {
81 libResource = "/ai/onnxruntime/native/win-" + arch + "/libonnxruntime.dll";
82 } else if (os.contains("nux")) {
83 libResource = "/ai/onnxruntime/native/linux-" + arch + "/libonnxruntime.so";
84 } else {
85 throw new IllegalStateException("Unsupported os:" + os);
86 }
87 try {
88 // workaround to avoid CNFE when the ReleaseEnv class is attempted to load in the shutdown hook from already closed classloader
89 Class.forName("oracle.code.onnx.foreign.OrtApi$ReleaseEnv");
90 } catch (ClassNotFoundException e) {
91 throw new IllegalStateException(e);
92 }
93 try (var libStream = OnnxRuntime.class.getResourceAsStream(libResource)) {
94 var libFile = File.createTempFile("libonnxruntime", "");
95 Path libFilePath = libFile.toPath();
96 Files.copy(Objects.requireNonNull(libStream), libFilePath, StandardCopyOption.REPLACE_EXISTING);
97 System.load(libFilePath.toAbsolutePath().toString());
98 libFile.deleteOnExit();
99 } catch (IOException e) {
100 throw new RuntimeException(e);
101 }
102 }
103 LOG.log(System.Logger.Level.DEBUG, "ONNX Runtime initialized");
104 }
105
106 private final MemorySegment runtimeAddress, ret, envAddress, defaultAllocatorAddress;
107
108 private OnnxRuntime() {
109 var arena = Arena.ofAuto();
110 ret = arena.allocate(C_POINTER);
111 // const OrtApi* ortPtr = OrtGetApiBase()->GetApi((uint32_t)apiVersion);
112 var apiBase = OrtApiBase.reinterpret(OrtGetApiBase(), arena, null);
113 runtimeAddress = OrtApi.reinterpret(OrtApiBase.GetApi.invoke(OrtApiBase.GetApi(apiBase), ORT_API_VERSION()), arena, null);
114 envAddress = retAddr(OrtApi.CreateEnv.invoke(OrtApi.CreateEnv(runtimeAddress), ORT_LOGGING_LEVEL_ERROR(), arena.allocateFrom(LOG_ID), ret));
115 defaultAllocatorAddress = retAddr(OrtApi.GetAllocatorWithDefaultOptions.invoke(OrtApi.GetAllocatorWithDefaultOptions(runtimeAddress), ret)).reinterpret(arena, null);
116 Runtime.getRuntime().addShutdownHook(new Thread(() -> {
117 LOG.log(System.Logger.Level.DEBUG, "ONNX Runtime released");
118 OrtApi.ReleaseEnv.invoke(OrtApi.ReleaseEnv(runtimeAddress), envAddress);
119 }));
120 }
121
122 // @@@ temporary set public for ongoing experiments
123 public static List<Object> getInitValues(MethodHandles.Lookup lookup, SequencedCollection<FieldRef> initializers, SequencedCollection<Object> possibleReceivers) {
124 return initializers.stream().map(i -> {
125 try {
126 Field initializerField = i.resolveToField(lookup);
127 VarHandle handle = lookup.unreflectVarHandle(initializerField);
128 if (initializerField.accessFlags().contains(AccessFlag.STATIC)) {
129 return handle.get();
130 } else {
131 Class<?> initializerClass = initializerField.getDeclaringClass();
132 return handle.get(possibleReceivers.stream().filter(initializerClass::isInstance).findFirst().orElseThrow());
133 }
134 } catch (ReflectiveOperationException ex) {
135 throw new RuntimeException(ex);
136 }
137 }).toList();
138 }
139
140 public static <T> T execute(Supplier<T> codeLambda) {
141 return execute(MethodHandles.lookup(), codeLambda);
142 }
143
144 public static <T> T execute(MethodHandles.Lookup l, Supplier<T> codeLambda) {
145 return execute(Arena.ofAuto(), l, codeLambda);
146 }
147
148 private static void expandArg(Object val, Consumer<Tensor> args) {
149 switch (val) {
150 case CoreOp.Var<?> v -> expandArg(v.value(), args);
151 case Tensor t -> args.accept(t);
152 case Record r -> {
153 for (var rc : r.getClass().getRecordComponents())
154 try {
155 expandArg(rc.getAccessor().invoke(r), args);
156 } catch (ReflectiveOperationException e) {
157 throw new IllegalStateException(e);
158 }
159 }
160 // @@@ constant array last object must be consumed or the statically detected size and the actual size missmatch
161 case Object[] os -> {
162 for (var o : os) {
163 expandArg(o, args);
164 }
165 }
166 default -> {
167 }
168 }
169 }
170
171 public static <T> T execute(Arena arena, MethodHandles.Lookup l, Supplier<T> codeLambda) {
172 return execute(arena, l, codeLambda, null);
173 }
174
175 public static <T> T execute(Arena arena, MethodHandles.Lookup l, Supplier<T> codeLambda, SessionOptions options) {
176 var q = Op.ofLambda(codeLambda).orElseThrow();
177
178 var model = SESSION_CACHE.computeIfAbsent(codeLambda.getClass(), l, q, options);
179
180 List<Tensor> arguments = Stream.concat(q.capturedValues().sequencedValues().stream(),
181 model.bypassedInitValues().stream())
182 .mapMulti(OnnxRuntime::expandArg)
183 .toList();
184 LOG.log(System.Logger.Level.DEBUG, "Running ONNX session " + codeLambda.getClass().getSimpleName().split("\\$")[0]);
185 List<Tensor> ret = model.session().run(arena, arguments);
186
187 var lambdaOp = q.op();
188 CodeType type = lambdaOp.invokableSignature().returnType();
189 if (type instanceof ArrayType) {
190 return (T) ret.toArray(Tensor[]::new);
191 }
192 ClassType retType = ((ClassType) type).rawType();
193 if (retType.equals(TENSOR_RAW_TYPE)) {
194 return (T) ret.getFirst();
195 } else if (retType.equals(LIST_RAW_TYPE)) {
196 return (T) ret;
197 } else if (getRecordClass(l, retType) instanceof Class cls) {
198 try {
199 return (T) cls.getConstructors()[0].newInstance(unflat(ret, (TupleType) model.returnType()));
200 } catch (ReflectiveOperationException e) {
201 throw new IllegalStateException(e);
202 }
203 } else {
204 throw new UnsupportedOperationException("Unsupported return type: " + q.op().resultType());
205 }
206 }
207
208 static Object[] unflat(List<Tensor> values, TupleType returnTupleType) {
209 var returnTypes = returnTupleType.componentTypes();
210 Object[] ret = new Object[returnTypes.size()];
211 for (int i = 0, j = 0; i < ret.length; i++) {
212 ret[i] = returnTypes.get(i) instanceof TupleType tt ? values.subList(j, j += tt.componentTypes().size()).toArray(Tensor[]::new) : values.get(j++);
213 }
214 return ret;
215 }
216
217 static Class getRecordClass(MethodHandles.Lookup l, ClassType ct) {
218 try {
219 var t = ct.resolve(l);
220 while (t instanceof ParameterizedType pt) t = pt.getRawType();
221 if (t instanceof Class c && c.isRecord()) return c;
222 } catch (ReflectiveOperationException _) {
223 }
224 return null;
225 }
226
227 public static OnnxRuntime getInstance() {
228 if (INSTANCE == null) {
229 INSTANCE = new OnnxRuntime();
230 }
231 return INSTANCE;
232 }
233
234 private static MemorySegment autoShape(Arena arena, long[] shape, long elementsCount) {
235 int auto = -1;
236 long elCount = 1;
237 for (int i = 0; i < shape.length; i++) {
238 long dim = shape[i];
239 if (dim == -1) {
240 if (auto == -1) {
241 auto = i;
242 } else {
243 throw new IllegalArgumentException("Multiple automatic dimensions in shape");
244 }
245 } else {
246 elCount *= dim;
247 }
248 }
249 var ms = arena.allocateFrom(C_LONG_LONG, shape);
250 if (auto != -1) {
251 long autoDim = elementsCount / elCount;
252 ms.setAtIndex(C_LONG, auto, autoDim);
253 elCount *= autoDim;
254 }
255 if (elCount != elementsCount) {
256 throw new IllegalArgumentException("Tensor shape does not match data");
257 }
258 return ms;
259 }
260
261 public List<Tensor> runOp(Arena arena, String opName, List<Tensor> inputValues, int numOutputs, Map<String, Object> attributes) {
262 var outputNames = IntStream.range(0, numOutputs).mapToObj(o -> "o" + o).toList();
263 var protoModel = OnnxProtoBuilder.buildModel(
264 List.of(),
265 IntStream.range(0, inputValues.size()).mapToObj(i -> OnnxProtoBuilder.tensorInfo("i" + i, inputValues.get(i).elementType().id)).toList(),
266 List.of(OnnxProtoBuilder.node(
267 opName,
268 IntStream.range(0, inputValues.size()).mapToObj(i -> "i" + i).toList(),
269 outputNames,
270 attributes)),
271 outputNames);
272 return createSession(arena, protoModel)
273 .run(arena, inputValues);
274 }
275
276 public List<Tensor> run(Arena arena, Block block, List<Tensor> inputValues, int initializers) {
277 var protoModel = OnnxProtoBuilder.buildModel(block, inputValues.subList(0, initializers));
278 return createSession(arena, protoModel)
279 .run(arena, inputValues.subList(initializers, inputValues.size()));
280 }
281
282 public Session createSession(Arena arena, String modelPath) {
283 return createSession(arena, modelPath, createSessionOptions(arena));
284 }
285
286 public Session createSession(Arena arena, String modelPath, SessionOptions options) {
287 return new Session(arena, retAddr(OrtApi.CreateSession.invoke(OrtApi.CreateSession(runtimeAddress), envAddress, arena.allocateFrom(modelPath), options.sessionOptionsAddress, ret)));
288 }
289
290 public Session createSession(Arena arena, byte[] model) {
291 return createSession(arena, model, createSessionOptions(arena));
292 }
293
294 private Session createSession(Arena arena, byte[] model, SessionOptions options) {
295 return new Session(arena, retAddr(OrtApi.CreateSessionFromArray.invoke(OrtApi.CreateSessionFromArray(runtimeAddress), envAddress, arena.allocateFrom(ValueLayout.JAVA_BYTE, model), model.length, options.sessionOptionsAddress, ret)));
296 }
297
298 public MemorySegment createTensor(Arena arena, MemorySegment flatData, Tensor.ElementType elementType, long[] shape) {
299 var allocatorInfo = retAddr(OrtApi.AllocatorGetInfo.invoke(OrtApi.AllocatorGetInfo(runtimeAddress), defaultAllocatorAddress, ret));
300 return retAddr(OrtApi.CreateTensorWithDataAsOrtValue.invoke(
301 OrtApi.CreateTensorWithDataAsOrtValue(runtimeAddress),
302 allocatorInfo,
303 flatData, flatData.byteSize(),
304 shape.length == 0 ? MemorySegment.NULL : autoShape(arena, shape, 8l * flatData.byteSize() / elementType.bitSize()), (long) shape.length,
305 elementType.id,
306 ret)).reinterpret(arena, value -> OrtApi.ReleaseValue.invoke(OrtApi.ReleaseValue(runtimeAddress), value));
307 }
308
309 public MemorySegment createStringTensor(Arena arena, String[] data, long[] shape) {
310 var allocatorInfo = retAddr(OrtApi.AllocatorGetInfo.invoke(OrtApi.AllocatorGetInfo(runtimeAddress), defaultAllocatorAddress, ret));
311 MemorySegment flatDataPlacholder = arena.allocate(data.length * 24L); // @@@ calling of CreateTensorAsOrtValue crashes
312 var tensor = retAddr(OrtApi.CreateTensorWithDataAsOrtValue.invoke(
313 OrtApi.CreateTensorWithDataAsOrtValue(runtimeAddress),
314 allocatorInfo,
315 flatDataPlacholder, flatDataPlacholder.byteSize(),
316 autoShape(arena, shape, data.length), (long) shape.length,
317 Tensor.ElementType.STRING.id,
318 ret));
319 for (int i = 0; i < data.length; i++) {
320 checkStatus(OrtApi.FillStringTensorElement.invoke(
321 OrtApi.FillStringTensorElement(runtimeAddress),
322 tensor,
323 arena.allocateFrom(data[i]),
324 (long)i));
325 }
326 return tensor.reinterpret(arena, value -> OrtApi.ReleaseValue.invoke(OrtApi.ReleaseValue(runtimeAddress), value));
327 }
328
329 public Tensor.ElementType tensorElementType(MemorySegment tensorAddr) {
330 var infoAddr = retAddr(OrtApi.GetTensorTypeAndShape.invoke(OrtApi.GetTensorTypeAndShape(runtimeAddress), tensorAddr, ret));
331 return Tensor.ElementType.fromOnnxId(retInt(OrtApi.GetTensorElementType.invoke(OrtApi.GetTensorElementType(runtimeAddress), infoAddr, ret)));
332 }
333
334 public long[] tensorShape(MemorySegment tensorAddr) {
335 try (var arena = Arena.ofConfined()) {
336 var infoAddr = retAddr(OrtApi.GetTensorTypeAndShape.invoke(OrtApi.GetTensorTypeAndShape(runtimeAddress), tensorAddr, ret));
337 long dims = retLong(OrtApi.GetDimensionsCount.invoke(OrtApi.GetDimensionsCount(runtimeAddress), infoAddr, ret));
338 var shape = arena.allocate(C_LONG_LONG, dims);
339 checkStatus(OrtApi.GetDimensions.invoke(OrtApi.GetDimensions(runtimeAddress), infoAddr, shape, dims));
340 return shape.toArray(C_LONG_LONG);
341 }
342 }
343
344 public long tensorShapeElementCount(MemorySegment tensorAddr) {
345 var infoAddr = retAddr(OrtApi.GetTensorTypeAndShape.invoke(OrtApi.GetTensorTypeAndShape(runtimeAddress), tensorAddr, ret));
346 return retLong(OrtApi.GetTensorShapeElementCount.invoke(OrtApi.GetTensorShapeElementCount(runtimeAddress), infoAddr, ret));
347 }
348
349 public String stringTensorElement(MemorySegment tensorAddr, long elementIndex) {
350 try (var arena = Arena.ofConfined()) {
351 long elemLen = retLong(OrtApi.GetStringTensorElementLength.invoke(OrtApi.GetStringTensorElementLength(runtimeAddress), tensorAddr, elementIndex, ret));
352 var buf = arena.allocate(ValueLayout.JAVA_BYTE, elemLen + 1);
353 checkStatus(OrtApi.GetStringTensorElement.invoke(OrtApi.GetStringTensorElement(runtimeAddress), tensorAddr, elemLen, elementIndex, buf));
354 return buf.getString(0);
355 }
356 }
357
358 public MemorySegment tensorData(MemorySegment tensorAddr) {
359 var infoAddr = retAddr(OrtApi.GetTensorTypeAndShape.invoke(OrtApi.GetTensorTypeAndShape(runtimeAddress), tensorAddr, ret));
360 long size = retLong(OrtApi.GetTensorShapeElementCount.invoke(OrtApi.GetTensorShapeElementCount(runtimeAddress), infoAddr, ret))
361 * Tensor.ElementType.fromOnnxId(retInt(OrtApi.GetTensorElementType.invoke(OrtApi.GetTensorElementType(runtimeAddress), infoAddr, ret))).bitSize() / 8;
362 return retAddr(OrtApi.GetTensorMutableData.invoke(OrtApi.GetTensorMutableData(runtimeAddress), tensorAddr, ret))
363 .reinterpret(size);
364 }
365
366 public SessionOptions createSessionOptions(Arena arena) {
367 return new SessionOptions(retAddr(OrtApi.CreateSessionOptions.invoke(OrtApi.CreateSessionOptions(runtimeAddress), ret))
368 .reinterpret(arena, opts -> OrtApi.ReleaseSessionOptions.invoke(OrtApi.ReleaseSessionOptions(runtimeAddress), opts)));
369 }
370
371 public void appendExecutionProvider(Arena arena, SessionOptions sessionOptions, OnnxProvider provider) {
372 ExecutionProviderOptions executionProviderOptions = toNativeOptions(arena, provider);
373
374 MemorySegment funcPtr = OrtApi.SessionOptionsAppendExecutionProvider(runtimeAddress);
375 var status = OrtApi.SessionOptionsAppendExecutionProvider.invoke(
376 funcPtr,
377 sessionOptions.getSessionOptionsAddress(),
378 arena.allocateFrom(provider.name()),
379 executionProviderOptions.keySegment, executionProviderOptions.valSegment, executionProviderOptions.size);
380 checkStatus(status);
381 }
382
383 public void appendExecutionProvider_V2(Arena arena, SessionOptions sessionOptions, OnnxProvider provider) {
384 MemorySegment getEpDevicesFn = OrtApi.GetEpDevices(runtimeAddress);
385 MemorySegment devicesOut = arena.allocate(ValueLayout.ADDRESS);
386 MemorySegment countOut = arena.allocate(ValueLayout.JAVA_LONG);
387 checkStatus(OrtApi.GetEpDevices.invoke(getEpDevicesFn, envAddress, devicesOut, countOut));
388 long numDevices = countOut.get(ValueLayout.JAVA_LONG, 0);
389
390 MemorySegment devicesBasePtr = devicesOut.get(ValueLayout.ADDRESS, 0);
391 MemorySegment devicesArr = (numDevices > 0 && devicesBasePtr != MemorySegment.NULL)
392 ? devicesBasePtr.reinterpret(numDevices * ValueLayout.ADDRESS.byteSize())
393 : MemorySegment.NULL;
394
395 String target = provider.name();
396 var matches = new java.util.ArrayList<MemorySegment>();
397 if (devicesArr != MemorySegment.NULL) {
398 MemorySegment epNameFn = OrtApi.EpDevice_EpName(runtimeAddress);
399 for (int j = 0; j < numDevices; j++) {
400 MemorySegment dev = devicesArr.getAtIndex(ValueLayout.ADDRESS, j);
401 MemorySegment cstr = OrtApi.EpDevice_EpName.invoke(epNameFn, dev);
402 String epName = (cstr == MemorySegment.NULL) ? "" : cstr.getString(0);
403 if (target.equals(epName)) matches.add(dev);
404 }
405 }
406
407 if (matches.isEmpty()) {
408 appendExecutionProvider(arena, sessionOptions, provider);
409 } else {
410 ExecutionProviderOptions executionProviderOptions = toNativeOptions(arena, provider);
411
412 long deviceCount = matches.size();
413 MemorySegment deviceArrayPtr = arena.allocate(ValueLayout.ADDRESS, deviceCount);
414 for (int j = 0; j < matches.size(); j++)
415 deviceArrayPtr.setAtIndex(ValueLayout.ADDRESS, j, matches.get(j));
416
417 MemorySegment functionPtr = OrtApi.SessionOptionsAppendExecutionProvider_V2(runtimeAddress);
418 checkStatus(OrtApi.SessionOptionsAppendExecutionProvider_V2.invoke(
419 functionPtr,
420 sessionOptions.getSessionOptionsAddress(),
421 envAddress,
422 deviceArrayPtr, deviceCount,
423 executionProviderOptions.keySegment(), executionProviderOptions.valSegment(), executionProviderOptions.size()
424 ));
425 }
426 }
427
428 private static ExecutionProviderOptions toNativeOptions(Arena arena, OnnxProvider provider) {
429 var providerOptions = provider.options();
430 MemorySegment keySegment = MemorySegment.NULL;
431 MemorySegment valSegment = MemorySegment.NULL;
432 int size = 0;
433
434 if (Objects.nonNull(providerOptions) && !providerOptions.isEmpty()) {
435 size = providerOptions.size();
436 keySegment = arena.allocate(ValueLayout.ADDRESS, size);
437 valSegment = arena.allocate(ValueLayout.ADDRESS, size);
438 int i = 0;
439
440 for (Map.Entry<String, String> e : providerOptions.entrySet()) {
441 keySegment.setAtIndex(ValueLayout.ADDRESS, i, arena.allocateFrom(e.getKey()));
442 valSegment.setAtIndex(ValueLayout.ADDRESS, i, arena.allocateFrom(e.getValue()));
443 i++;
444 }
445 }
446 ExecutionProviderOptions executionProviderOptions = new ExecutionProviderOptions(keySegment, valSegment, size);
447 return executionProviderOptions;
448 }
449
450 private record ExecutionProviderOptions(MemorySegment keySegment, MemorySegment valSegment, int size) {}
451
452 private MemorySegment retAddr(MemorySegment res) {
453 checkStatus(res);
454 return ret.get(C_POINTER, 0);
455 }
456
457 private int retInt(MemorySegment res) {
458 checkStatus(res);
459 return ret.get(C_INT, 0);
460 }
461
462 private long retLong(MemorySegment res) {
463 checkStatus(res);
464 return ret.get(C_LONG_LONG, 0);
465 }
466
467 private String retString(MemorySegment res) {
468 return retAddr(res).reinterpret(Long.MAX_VALUE)
469 .getString(0);
470 }
471
472 private void checkStatus(MemorySegment status) {
473 try {
474 if (!status.equals(MemorySegment.NULL)) {
475 status = status.reinterpret(Long.MAX_VALUE);
476 if (status.get(C_INT, 0) != 0) {
477 throw new RuntimeException(status.getString(C_INT.byteSize()));
478 }
479 }
480 } finally {
481 OrtApi.ReleaseStatus.invoke(OrtApi.ReleaseStatus(runtimeAddress), status);
482 }
483 }
484
485 record SessionWithReturnType(Session session, CodeType returnType, List<Object> bypassedInitValues) {
486 }
487
488 static class CachedSessionClassValue extends ClassValue<SessionWithReturnType> {
489
490 private MethodHandles.Lookup l;
491 private Quoted<JavaOp.LambdaOp> q;
492 private SessionOptions options;
493
494 // Static helper for cache with options
495 protected SessionWithReturnType computeIfAbsent(
496 Class<?> lambdaClass, MethodHandles.Lookup l, Quoted<JavaOp.LambdaOp> q, SessionOptions options) {
497 try {
498 this.l = l;
499 this.q = q;
500 this.options = options;
501 // not very nice way to pass additional arguments to computeValue method
502 return get(lambdaClass);
503 } finally {
504 this.l = null;
505 this.q = null;
506 this.options = null;
507 }
508 }
509
510 @Override
511 protected SessionWithReturnType computeValue(Class<?> type) {
512 OnnxTransformer.ModuleAndInitializers mi = OnnxTransformer.transform(l, q);
513
514 String domainName = type.getSimpleName().split("\\$")[0];
515 boolean bypassInits = options != null && options.bypassInitilizers;
516 List<Object> initValues = getInitValues(l, mi.initializers(), q.capturedValues().sequencedValues());
517 LOG.log(System.Logger.Level.DEBUG, "Building ONNX binary " + domainName);
518 byte[] protobufModel = OnnxProtoBuilder.buildModel(domainName, mi.module(), bypassInits ? List.of() : initValues);
519
520 if (DEBUG) {
521 System.out.println(mi.module().toText());
522 // System.out.println(OnnxModel.readFrom(protobufModel).toText());
523 try {
524 var export = Path.of(domainName + ".onnx");
525 Files.write(export, protobufModel);
526 System.out.println("Onnx model exported to: " + export.toAbsolutePath());
527 } catch (IOException _) {
528 }
529 }
530
531 LOG.log(System.Logger.Level.DEBUG, "Creating ONNX session " + domainName);
532 // cached session must be created under its own auto arena
533 Session session = (options != null) ?
534 getInstance().createSession(Arena.ofAuto(), protobufModel, options) :
535 getInstance().createSession(Arena.ofAuto(), protobufModel);
536
537 return new SessionWithReturnType(
538 session,
539 mi.module().functionTable().lastEntry().getValue().invokableSignature().returnType(),
540 bypassInits ? initValues : List.of());
541
542
543 }
544 }
545
546 public final class Session {
547
548 private final MemorySegment sessionAddress;
549
550 private Session(Arena arena, MemorySegment sessionAddress) {
551 this.sessionAddress = sessionAddress.reinterpret(arena,
552 session -> OrtApi.ReleaseSession.invoke(OrtApi.ReleaseSession(runtimeAddress), session));
553 }
554
555 public int getNumberOfInputs() {
556 return retInt(OrtApi.SessionGetInputCount.invoke(OrtApi.SessionGetInputCount(runtimeAddress), sessionAddress, ret));
557 }
558
559 public String getInputName(int inputIndex) {
560 return retString(OrtApi.SessionGetInputName.invoke(OrtApi.SessionGetInputName(runtimeAddress), sessionAddress, inputIndex, defaultAllocatorAddress, ret));
561 }
562
563 public int getNumberOfOutputs() {
564 return retInt(OrtApi.SessionGetOutputCount.invoke(OrtApi.SessionGetOutputCount(runtimeAddress), sessionAddress, ret));
565 }
566
567 public String getOutputName(int inputIndex) {
568 return retString(OrtApi.SessionGetOutputName.invoke(OrtApi.SessionGetOutputName(runtimeAddress), sessionAddress, inputIndex, defaultAllocatorAddress, ret));
569 }
570
571 // @@@ only tensors are supported yet
572 public List<Tensor> run(Arena arena, List<Tensor> inputValues) {
573 var runOptions = MemorySegment.NULL;
574 int inputLen = getNumberOfInputs();
575 int outputLen = getNumberOfOutputs();
576 var inputNames = arena.allocate(C_POINTER, inputLen);
577 var inputs = arena.allocate(C_POINTER, inputLen);
578 long index = 0;
579 for (int i = 0; i < inputLen; i++) {
580 inputNames.setAtIndex(C_POINTER, index, arena.allocateFrom(getInputName(i)));
581 inputs.setAtIndex(C_POINTER, index++, inputValues.get(i).tensorAddr);
582 }
583 var outputNames = arena.allocate(C_POINTER, outputLen);
584 var outputs = arena.allocate(C_POINTER, outputLen);
585 for (int i = 0; i < outputLen; i++) {
586 outputNames.setAtIndex(C_POINTER, i, arena.allocateFrom(getOutputName(i)));
587 outputs.setAtIndex(C_POINTER, i, MemorySegment.NULL);
588 }
589 checkStatus(OrtApi.Run.invoke(OrtApi.Run(runtimeAddress), sessionAddress, runOptions, inputNames, inputs, (long) inputLen, outputNames, (long) outputLen, outputs));
590 var retArr = new Tensor[outputLen];
591 for (int i = 0; i < outputLen; i++) {
592 var tensorAddr = outputs.getAtIndex(C_POINTER, i)
593 .reinterpret(arena, value -> OrtApi.ReleaseValue.invoke(OrtApi.ReleaseValue(runtimeAddress), value));
594 retArr[i] = new Tensor(tensorData(tensorAddr).reinterpret(arena, null),
595 tensorAddr);
596 }
597 return List.of(retArr);
598 }
599 }
600
601 public final class SessionOptions {
602
603 private final MemorySegment sessionOptionsAddress;
604
605 boolean bypassInitilizers = false;
606
607 public SessionOptions(MemorySegment sessionOptionsAddress) {
608 this.sessionOptionsAddress = sessionOptionsAddress;
609 setInterOpNumThreads(1);
610 }
611
612 public void registerCustomOpsLibrary(MemorySegment path) {
613 checkStatus(OrtApi.RegisterCustomOpsLibrary_V2.invoke(OrtApi.RegisterCustomOpsLibrary_V2(runtimeAddress), sessionOptionsAddress, path));
614 }
615
616 public void setInterOpNumThreads(int numThreads) {
617 checkStatus(OrtApi.SetInterOpNumThreads.invoke(OrtApi.SetInterOpNumThreads(runtimeAddress), sessionOptionsAddress, numThreads));
618 }
619
620 public void setIntraOpNumThreads(int numThreads) {
621 checkStatus(OrtApi.SetIntraOpNumThreads.invoke(OrtApi.SetIntraOpNumThreads(runtimeAddress), sessionOptionsAddress, numThreads));
622 }
623
624 public void setSessionExecutionMode(int executionMode) {
625 checkStatus(OrtApi.SetSessionExecutionMode.invoke(OrtApi.SetSessionExecutionMode(runtimeAddress), sessionOptionsAddress, executionMode));
626 }
627
628 public MemorySegment getSessionOptionsAddress() {
629 return sessionOptionsAddress;
630 }
631
632 public void bypassInitizers() {
633 bypassInitilizers = true;
634 }
635 }
636 }