1 /*
  2  * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 
 26 package oracle.code.onnx;
 27 
 28 import java.io.File;
 29 import java.io.IOException;
 30 import java.lang.foreign.*;
 31 import java.lang.invoke.MethodHandles;
 32 import java.lang.invoke.VarHandle;
 33 import java.lang.reflect.AccessFlag;
 34 import java.lang.reflect.Field;
 35 import java.lang.reflect.ParameterizedType;
 36 import java.nio.file.Files;
 37 import java.nio.file.Path;
 38 import java.nio.file.StandardCopyOption;
 39 import java.util.*;
 40 import java.util.function.Consumer;
 41 import java.util.function.Supplier;
 42 import java.util.stream.IntStream;
 43 import java.util.stream.Stream;
 44 
 45 import jdk.incubator.code.*;
 46 
 47 import jdk.incubator.code.dialect.core.CoreOp;
 48 import jdk.incubator.code.dialect.core.TupleType;
 49 import jdk.incubator.code.dialect.java.ArrayType;
 50 import jdk.incubator.code.dialect.java.ClassType;
 51 import jdk.incubator.code.dialect.java.FieldRef;
 52 import jdk.incubator.code.dialect.java.JavaOp;
 53 import jdk.incubator.code.dialect.java.JavaType;
 54 import oracle.code.onnx.compiler.OnnxTransformer;
 55 import oracle.code.onnx.foreign.OrtApi;
 56 import oracle.code.onnx.foreign.OrtApiBase;
 57 
 58 import static oracle.code.onnx.foreign.onnxruntime_c_api_h.*;
 59 
 60 public final class OnnxRuntime {
 61 
 62     static final boolean DEBUG = Boolean.getBoolean("oracle.code.onnx.OnnxRuntime.DEBUG");
 63     static final System.Logger LOG = System.getLogger("oracle.code.onnx");
 64     static final JavaType TENSOR_RAW_TYPE = JavaType.type(Tensor.class);
 65     static final JavaType LIST_RAW_TYPE = JavaType.type(List.class);
 66     private static final CachedSessionClassValue SESSION_CACHE = new CachedSessionClassValue();
 67     private static final String LOG_ID = "onnx-ffm-java";
 68     private static OnnxRuntime INSTANCE;
 69 
 70     static {
 71         try {
 72             System.loadLibrary("onnxruntime");
 73         } catch (UnsatisfiedLinkError _) {
 74             // fallback to extract the library from onnxruntime dependency jar
 75             String arch = System.getProperty("os.arch", "generic").toLowerCase(Locale.ENGLISH).startsWith("aarch64") ? "aarch64" : "x64";
 76             String os = System.getProperty("os.name", "generic").toLowerCase(Locale.ENGLISH);
 77             String libResource;
 78             if (os.contains("mac") || os.contains("darwin")) {
 79                 libResource = "/ai/onnxruntime/native/osx-" + arch + "/libonnxruntime.dylib";
 80             } else if (os.contains("win")) {
 81                 libResource = "/ai/onnxruntime/native/win-" + arch + "/libonnxruntime.dll";
 82             } else if (os.contains("nux")) {
 83                 libResource = "/ai/onnxruntime/native/linux-" + arch + "/libonnxruntime.so";
 84             } else {
 85                 throw new IllegalStateException("Unsupported os:" + os);
 86             }
 87             try {
 88                 // workaround to avoid CNFE when the ReleaseEnv class is attempted to load in the shutdown hook from already closed classloader
 89                 Class.forName("oracle.code.onnx.foreign.OrtApi$ReleaseEnv");
 90             } catch (ClassNotFoundException e) {
 91                 throw new IllegalStateException(e);
 92             }
 93             try (var libStream = OnnxRuntime.class.getResourceAsStream(libResource)) {
 94                 var libFile = File.createTempFile("libonnxruntime", "");
 95                 Path libFilePath = libFile.toPath();
 96                 Files.copy(Objects.requireNonNull(libStream), libFilePath, StandardCopyOption.REPLACE_EXISTING);
 97                 System.load(libFilePath.toAbsolutePath().toString());
 98                 libFile.deleteOnExit();
 99             } catch (IOException e) {
100                 throw new RuntimeException(e);
101             }
102         }
103         LOG.log(System.Logger.Level.DEBUG, "ONNX Runtime initialized");
104     }
105 
106     private final MemorySegment runtimeAddress, ret, envAddress, defaultAllocatorAddress;
107 
108     private OnnxRuntime() {
109         var arena = Arena.ofAuto();
110         ret = arena.allocate(C_POINTER);
111         //  const OrtApi* ortPtr = OrtGetApiBase()->GetApi((uint32_t)apiVersion);
112         var apiBase = OrtApiBase.reinterpret(OrtGetApiBase(), arena, null);
113         runtimeAddress = OrtApi.reinterpret(OrtApiBase.GetApi.invoke(OrtApiBase.GetApi(apiBase), ORT_API_VERSION()), arena, null);
114         envAddress = retAddr(OrtApi.CreateEnv.invoke(OrtApi.CreateEnv(runtimeAddress), ORT_LOGGING_LEVEL_ERROR(), arena.allocateFrom(LOG_ID), ret));
115         defaultAllocatorAddress = retAddr(OrtApi.GetAllocatorWithDefaultOptions.invoke(OrtApi.GetAllocatorWithDefaultOptions(runtimeAddress), ret)).reinterpret(arena, null);
116         Runtime.getRuntime().addShutdownHook(new Thread(() -> {
117             LOG.log(System.Logger.Level.DEBUG, "ONNX Runtime released");
118             OrtApi.ReleaseEnv.invoke(OrtApi.ReleaseEnv(runtimeAddress), envAddress);
119         }));
120     }
121 
122     // @@@ temporary set public for ongoing experiments
123     public static List<Object> getInitValues(MethodHandles.Lookup lookup, SequencedCollection<FieldRef> initializers, SequencedCollection<Object> possibleReceivers) {
124         return initializers.stream().map(i -> {
125             try {
126                 Field initializerField = i.resolveToField(lookup);
127                 VarHandle handle = lookup.unreflectVarHandle(initializerField);
128                 if (initializerField.accessFlags().contains(AccessFlag.STATIC)) {
129                     return handle.get();
130                 } else {
131                     Class<?> initializerClass = initializerField.getDeclaringClass();
132                     return handle.get(possibleReceivers.stream().filter(initializerClass::isInstance).findFirst().orElseThrow());
133                 }
134             } catch (ReflectiveOperationException ex) {
135                 throw new RuntimeException(ex);
136             }
137         }).toList();
138     }
139 
140     public static <T> T execute(Supplier<T> codeLambda) {
141         return execute(MethodHandles.lookup(), codeLambda);
142     }
143 
144     public static <T> T execute(MethodHandles.Lookup l, Supplier<T> codeLambda) {
145         return execute(Arena.ofAuto(), l, codeLambda);
146     }
147 
148     private static void expandArg(Object val, Consumer<Tensor> args) {
149         switch (val) {
150             case CoreOp.Var<?> v -> expandArg(v.value(), args);
151             case Tensor t -> args.accept(t);
152             case Record r -> {
153                 for (var rc : r.getClass().getRecordComponents())
154                     try {
155                         expandArg(rc.getAccessor().invoke(r), args);
156                     } catch (ReflectiveOperationException e) {
157                         throw new IllegalStateException(e);
158                     }
159             }
160             // @@@ constant array last object must be consumed or the statically detected size and the actual size missmatch
161             case Object[] os -> {
162                 for (var o : os) {
163                     expandArg(o, args);
164                 }
165             }
166             default -> {
167             }
168         }
169     }
170 
171     public static <T> T execute(Arena arena, MethodHandles.Lookup l, Supplier<T> codeLambda) {
172         return execute(arena, l, codeLambda, null);
173     }
174 
175     public static <T> T execute(Arena arena, MethodHandles.Lookup l, Supplier<T> codeLambda, SessionOptions options) {
176         var q = Op.ofLambda(codeLambda).orElseThrow();
177 
178         var model = SESSION_CACHE.computeIfAbsent(codeLambda.getClass(), l, q, options);
179 
180         List<Tensor> arguments = Stream.concat(q.capturedValues().sequencedValues().stream(),
181                                                model.bypassedInitValues().stream())
182                 .mapMulti(OnnxRuntime::expandArg)
183                 .toList();
184         LOG.log(System.Logger.Level.DEBUG, "Running ONNX session " + codeLambda.getClass().getSimpleName().split("\\$")[0]);
185         List<Tensor> ret = model.session().run(arena, arguments);
186 
187         var lambdaOp = q.op();
188         CodeType type = lambdaOp.invokableSignature().returnType();
189         if (type instanceof ArrayType) {
190             return (T) ret.toArray(Tensor[]::new);
191         }
192         ClassType retType = ((ClassType) type).rawType();
193         if (retType.equals(TENSOR_RAW_TYPE)) {
194             return (T) ret.getFirst();
195         } else if (retType.equals(LIST_RAW_TYPE)) {
196             return (T) ret;
197         } else if (getRecordClass(l, retType) instanceof Class cls) {
198             try {
199                 return (T) cls.getConstructors()[0].newInstance(unflat(ret, (TupleType) model.returnType()));
200             } catch (ReflectiveOperationException e) {
201                 throw new IllegalStateException(e);
202             }
203         } else {
204             throw new UnsupportedOperationException("Unsupported return type: " + q.op().resultType());
205         }
206     }
207 
208     static Object[] unflat(List<Tensor> values, TupleType returnTupleType) {
209         var returnTypes = returnTupleType.componentTypes();
210         Object[] ret = new Object[returnTypes.size()];
211         for (int i = 0, j = 0; i < ret.length; i++) {
212             ret[i] = returnTypes.get(i) instanceof TupleType tt ? values.subList(j, j += tt.componentTypes().size()).toArray(Tensor[]::new) : values.get(j++);
213         }
214         return ret;
215     }
216 
217     static Class getRecordClass(MethodHandles.Lookup l, ClassType ct) {
218         try {
219             var t = ct.resolve(l);
220             while (t instanceof ParameterizedType pt) t = pt.getRawType();
221             if (t instanceof Class c && c.isRecord()) return c;
222         } catch (ReflectiveOperationException _) {
223         }
224         return null;
225     }
226 
227     public static OnnxRuntime getInstance() {
228         if (INSTANCE == null) {
229             INSTANCE = new OnnxRuntime();
230         }
231         return INSTANCE;
232     }
233 
234     private static MemorySegment autoShape(Arena arena, long[] shape, long elementsCount) {
235         int auto = -1;
236         long elCount = 1;
237         for (int i = 0; i < shape.length; i++) {
238             long dim = shape[i];
239             if (dim == -1) {
240                 if (auto == -1) {
241                     auto = i;
242                 } else {
243                     throw new IllegalArgumentException("Multiple automatic dimensions in shape");
244                 }
245             } else {
246                 elCount *= dim;
247             }
248         }
249         var ms = arena.allocateFrom(C_LONG_LONG, shape);
250         if (auto != -1) {
251             long autoDim = elementsCount / elCount;
252             ms.setAtIndex(C_LONG, auto, autoDim);
253             elCount *= autoDim;
254         }
255         if (elCount != elementsCount) {
256             throw new IllegalArgumentException("Tensor shape does not match data");
257         }
258         return ms;
259     }
260 
261     public List<Tensor> runOp(Arena arena, String opName, List<Tensor> inputValues, int numOutputs, Map<String, Object> attributes) {
262         var outputNames = IntStream.range(0, numOutputs).mapToObj(o -> "o" + o).toList();
263         var protoModel = OnnxProtoBuilder.buildModel(
264                 List.of(),
265                 IntStream.range(0, inputValues.size()).mapToObj(i -> OnnxProtoBuilder.tensorInfo("i" + i, inputValues.get(i).elementType().id)).toList(),
266                 List.of(OnnxProtoBuilder.node(
267                         opName,
268                         IntStream.range(0, inputValues.size()).mapToObj(i -> "i" + i).toList(),
269                         outputNames,
270                         attributes)),
271                 outputNames);
272         return createSession(arena, protoModel)
273                 .run(arena, inputValues);
274     }
275 
276     public List<Tensor> run(Arena arena, Block block, List<Tensor> inputValues, int initializers) {
277         var protoModel = OnnxProtoBuilder.buildModel(block, inputValues.subList(0, initializers));
278         return createSession(arena, protoModel)
279                 .run(arena, inputValues.subList(initializers, inputValues.size()));
280     }
281 
282     public Session createSession(Arena arena, String modelPath) {
283         return createSession(arena, modelPath, createSessionOptions(arena));
284     }
285 
286     public Session createSession(Arena arena, String modelPath, SessionOptions options) {
287         return new Session(arena, retAddr(OrtApi.CreateSession.invoke(OrtApi.CreateSession(runtimeAddress), envAddress, arena.allocateFrom(modelPath), options.sessionOptionsAddress, ret)));
288     }
289 
290     public Session createSession(Arena arena, byte[] model) {
291         return createSession(arena, model, createSessionOptions(arena));
292     }
293 
294     private Session createSession(Arena arena, byte[] model, SessionOptions options) {
295         return new Session(arena, retAddr(OrtApi.CreateSessionFromArray.invoke(OrtApi.CreateSessionFromArray(runtimeAddress), envAddress, arena.allocateFrom(ValueLayout.JAVA_BYTE, model), model.length, options.sessionOptionsAddress, ret)));
296     }
297 
298     public MemorySegment createTensor(Arena arena, MemorySegment flatData, Tensor.ElementType elementType, long[] shape) {
299         var allocatorInfo = retAddr(OrtApi.AllocatorGetInfo.invoke(OrtApi.AllocatorGetInfo(runtimeAddress), defaultAllocatorAddress, ret));
300         return retAddr(OrtApi.CreateTensorWithDataAsOrtValue.invoke(
301                 OrtApi.CreateTensorWithDataAsOrtValue(runtimeAddress),
302                 allocatorInfo,
303                 flatData, flatData.byteSize(),
304                 shape.length == 0 ? MemorySegment.NULL : autoShape(arena, shape, 8l * flatData.byteSize() / elementType.bitSize()), (long) shape.length,
305                 elementType.id,
306                 ret)).reinterpret(arena, value -> OrtApi.ReleaseValue.invoke(OrtApi.ReleaseValue(runtimeAddress), value));
307     }
308 
309     public MemorySegment createStringTensor(Arena arena, String[] data, long[] shape) {
310         var allocatorInfo = retAddr(OrtApi.AllocatorGetInfo.invoke(OrtApi.AllocatorGetInfo(runtimeAddress), defaultAllocatorAddress, ret));
311         MemorySegment flatDataPlacholder = arena.allocate(data.length * 24L); // @@@ calling of CreateTensorAsOrtValue crashes
312         var tensor = retAddr(OrtApi.CreateTensorWithDataAsOrtValue.invoke(
313                 OrtApi.CreateTensorWithDataAsOrtValue(runtimeAddress),
314                 allocatorInfo,
315                 flatDataPlacholder, flatDataPlacholder.byteSize(),
316                 autoShape(arena, shape, data.length), (long) shape.length,
317                 Tensor.ElementType.STRING.id,
318                 ret));
319         for (int i = 0; i < data.length; i++) {
320             checkStatus(OrtApi.FillStringTensorElement.invoke(
321                     OrtApi.FillStringTensorElement(runtimeAddress),
322                     tensor,
323                     arena.allocateFrom(data[i]),
324                     (long)i));
325         }
326         return tensor.reinterpret(arena, value -> OrtApi.ReleaseValue.invoke(OrtApi.ReleaseValue(runtimeAddress), value));
327     }
328 
329     public Tensor.ElementType tensorElementType(MemorySegment tensorAddr) {
330         var infoAddr = retAddr(OrtApi.GetTensorTypeAndShape.invoke(OrtApi.GetTensorTypeAndShape(runtimeAddress), tensorAddr, ret));
331         return Tensor.ElementType.fromOnnxId(retInt(OrtApi.GetTensorElementType.invoke(OrtApi.GetTensorElementType(runtimeAddress), infoAddr, ret)));
332     }
333 
334     public long[] tensorShape(MemorySegment tensorAddr) {
335         try (var arena = Arena.ofConfined()) {
336             var infoAddr = retAddr(OrtApi.GetTensorTypeAndShape.invoke(OrtApi.GetTensorTypeAndShape(runtimeAddress), tensorAddr, ret));
337             long dims = retLong(OrtApi.GetDimensionsCount.invoke(OrtApi.GetDimensionsCount(runtimeAddress), infoAddr, ret));
338             var shape = arena.allocate(C_LONG_LONG, dims);
339             checkStatus(OrtApi.GetDimensions.invoke(OrtApi.GetDimensions(runtimeAddress), infoAddr, shape, dims));
340             return shape.toArray(C_LONG_LONG);
341         }
342     }
343 
344     public long tensorShapeElementCount(MemorySegment tensorAddr) {
345         var infoAddr = retAddr(OrtApi.GetTensorTypeAndShape.invoke(OrtApi.GetTensorTypeAndShape(runtimeAddress), tensorAddr, ret));
346         return retLong(OrtApi.GetTensorShapeElementCount.invoke(OrtApi.GetTensorShapeElementCount(runtimeAddress), infoAddr, ret));
347     }
348 
349     public String stringTensorElement(MemorySegment tensorAddr, long elementIndex) {
350         try (var arena = Arena.ofConfined()) {
351             long elemLen = retLong(OrtApi.GetStringTensorElementLength.invoke(OrtApi.GetStringTensorElementLength(runtimeAddress), tensorAddr, elementIndex, ret));
352             var buf = arena.allocate(ValueLayout.JAVA_BYTE, elemLen + 1);
353             checkStatus(OrtApi.GetStringTensorElement.invoke(OrtApi.GetStringTensorElement(runtimeAddress), tensorAddr, elemLen, elementIndex, buf));
354             return buf.getString(0);
355         }
356     }
357 
358     public MemorySegment tensorData(MemorySegment tensorAddr) {
359         var infoAddr = retAddr(OrtApi.GetTensorTypeAndShape.invoke(OrtApi.GetTensorTypeAndShape(runtimeAddress), tensorAddr, ret));
360         long size = retLong(OrtApi.GetTensorShapeElementCount.invoke(OrtApi.GetTensorShapeElementCount(runtimeAddress), infoAddr, ret))
361                 * Tensor.ElementType.fromOnnxId(retInt(OrtApi.GetTensorElementType.invoke(OrtApi.GetTensorElementType(runtimeAddress), infoAddr, ret))).bitSize() / 8;
362         return retAddr(OrtApi.GetTensorMutableData.invoke(OrtApi.GetTensorMutableData(runtimeAddress), tensorAddr, ret))
363                 .reinterpret(size);
364     }
365 
366     public SessionOptions createSessionOptions(Arena arena) {
367         return new SessionOptions(retAddr(OrtApi.CreateSessionOptions.invoke(OrtApi.CreateSessionOptions(runtimeAddress), ret))
368                 .reinterpret(arena, opts -> OrtApi.ReleaseSessionOptions.invoke(OrtApi.ReleaseSessionOptions(runtimeAddress), opts)));
369     }
370 
371     public void appendExecutionProvider(Arena arena, SessionOptions sessionOptions, OnnxProvider provider) {
372         ExecutionProviderOptions executionProviderOptions = toNativeOptions(arena, provider);
373 
374         MemorySegment funcPtr = OrtApi.SessionOptionsAppendExecutionProvider(runtimeAddress);
375         var status = OrtApi.SessionOptionsAppendExecutionProvider.invoke(
376                 funcPtr,
377                 sessionOptions.getSessionOptionsAddress(),
378                 arena.allocateFrom(provider.name()),
379                 executionProviderOptions.keySegment, executionProviderOptions.valSegment, executionProviderOptions.size);
380         checkStatus(status);
381     }
382 
383     public void appendExecutionProvider_V2(Arena arena, SessionOptions sessionOptions, OnnxProvider provider) {
384         MemorySegment getEpDevicesFn = OrtApi.GetEpDevices(runtimeAddress);
385         MemorySegment devicesOut = arena.allocate(ValueLayout.ADDRESS);
386         MemorySegment countOut = arena.allocate(ValueLayout.JAVA_LONG);
387         checkStatus(OrtApi.GetEpDevices.invoke(getEpDevicesFn, envAddress, devicesOut, countOut));
388         long numDevices = countOut.get(ValueLayout.JAVA_LONG, 0);
389 
390         MemorySegment devicesBasePtr = devicesOut.get(ValueLayout.ADDRESS, 0);
391         MemorySegment devicesArr = (numDevices > 0 && devicesBasePtr != MemorySegment.NULL)
392                 ? devicesBasePtr.reinterpret(numDevices * ValueLayout.ADDRESS.byteSize())
393                 : MemorySegment.NULL;
394 
395         String target = provider.name();
396         var matches = new java.util.ArrayList<MemorySegment>();
397         if (devicesArr != MemorySegment.NULL) {
398             MemorySegment epNameFn = OrtApi.EpDevice_EpName(runtimeAddress);
399             for (int j = 0; j < numDevices; j++) {
400                 MemorySegment dev = devicesArr.getAtIndex(ValueLayout.ADDRESS, j);
401                 MemorySegment cstr = OrtApi.EpDevice_EpName.invoke(epNameFn, dev);
402                 String epName = (cstr == MemorySegment.NULL) ? "" : cstr.getString(0);
403                 if (target.equals(epName)) matches.add(dev);
404             }
405         }
406 
407         if (matches.isEmpty()) {
408             appendExecutionProvider(arena, sessionOptions, provider);
409         } else {
410             ExecutionProviderOptions executionProviderOptions = toNativeOptions(arena, provider);
411 
412             long deviceCount = matches.size();
413             MemorySegment deviceArrayPtr = arena.allocate(ValueLayout.ADDRESS, deviceCount);
414             for (int j = 0; j < matches.size(); j++)
415                 deviceArrayPtr.setAtIndex(ValueLayout.ADDRESS, j, matches.get(j));
416 
417             MemorySegment functionPtr = OrtApi.SessionOptionsAppendExecutionProvider_V2(runtimeAddress);
418             checkStatus(OrtApi.SessionOptionsAppendExecutionProvider_V2.invoke(
419                     functionPtr,
420                     sessionOptions.getSessionOptionsAddress(),
421                     envAddress,
422                     deviceArrayPtr, deviceCount,
423                     executionProviderOptions.keySegment(), executionProviderOptions.valSegment(), executionProviderOptions.size()
424             ));
425         }
426     }
427 
428     private static ExecutionProviderOptions toNativeOptions(Arena arena, OnnxProvider provider) {
429         var providerOptions = provider.options();
430         MemorySegment keySegment = MemorySegment.NULL;
431         MemorySegment valSegment = MemorySegment.NULL;
432         int size = 0;
433 
434         if (Objects.nonNull(providerOptions) && !providerOptions.isEmpty()) {
435             size = providerOptions.size();
436             keySegment = arena.allocate(ValueLayout.ADDRESS, size);
437             valSegment = arena.allocate(ValueLayout.ADDRESS, size);
438             int i = 0;
439 
440             for (Map.Entry<String, String> e : providerOptions.entrySet()) {
441                 keySegment.setAtIndex(ValueLayout.ADDRESS, i, arena.allocateFrom(e.getKey()));
442                 valSegment.setAtIndex(ValueLayout.ADDRESS, i, arena.allocateFrom(e.getValue()));
443                 i++;
444             }
445         }
446         ExecutionProviderOptions executionProviderOptions = new ExecutionProviderOptions(keySegment, valSegment, size);
447         return executionProviderOptions;
448     }
449 
450     private record ExecutionProviderOptions(MemorySegment keySegment, MemorySegment valSegment, int size) {}
451 
452     private MemorySegment retAddr(MemorySegment res) {
453         checkStatus(res);
454         return ret.get(C_POINTER, 0);
455     }
456 
457     private int retInt(MemorySegment res) {
458         checkStatus(res);
459         return ret.get(C_INT, 0);
460     }
461 
462     private long retLong(MemorySegment res) {
463         checkStatus(res);
464         return ret.get(C_LONG_LONG, 0);
465     }
466 
467     private String retString(MemorySegment res) {
468         return retAddr(res).reinterpret(Long.MAX_VALUE)
469                 .getString(0);
470     }
471 
472     private void checkStatus(MemorySegment status) {
473         try {
474             if (!status.equals(MemorySegment.NULL)) {
475                 status = status.reinterpret(Long.MAX_VALUE);
476                 if (status.get(C_INT, 0) != 0) {
477                     throw new RuntimeException(status.getString(C_INT.byteSize()));
478                 }
479             }
480         } finally {
481             OrtApi.ReleaseStatus.invoke(OrtApi.ReleaseStatus(runtimeAddress), status);
482         }
483     }
484 
485     record SessionWithReturnType(Session session, CodeType returnType, List<Object> bypassedInitValues) {
486     }
487 
488     static class CachedSessionClassValue extends ClassValue<SessionWithReturnType> {
489 
490         private MethodHandles.Lookup l;
491         private Quoted<JavaOp.LambdaOp> q;
492         private SessionOptions options;
493 
494         // Static helper for cache with options
495         protected SessionWithReturnType computeIfAbsent(
496                 Class<?> lambdaClass, MethodHandles.Lookup l, Quoted<JavaOp.LambdaOp> q, SessionOptions options) {
497             try {
498                 this.l = l;
499                 this.q = q;
500                 this.options = options;
501                 // not very nice way to pass additional arguments to computeValue method
502                 return get(lambdaClass);
503             } finally {
504                 this.l = null;
505                 this.q = null;
506                 this.options = null;
507             }
508         }
509 
510         @Override
511         protected SessionWithReturnType computeValue(Class<?> type) {
512             OnnxTransformer.ModuleAndInitializers mi = OnnxTransformer.transform(l, q);
513 
514             String domainName = type.getSimpleName().split("\\$")[0];
515             boolean bypassInits = options != null && options.bypassInitilizers;
516             List<Object> initValues = getInitValues(l, mi.initializers(), q.capturedValues().sequencedValues());
517             LOG.log(System.Logger.Level.DEBUG, "Building ONNX binary " + domainName);
518             byte[] protobufModel = OnnxProtoBuilder.buildModel(domainName, mi.module(), bypassInits ? List.of() : initValues);
519 
520             if (DEBUG) {
521                 System.out.println(mi.module().toText());
522 //                System.out.println(OnnxModel.readFrom(protobufModel).toText());
523                 try {
524                     var export = Path.of(domainName + ".onnx");
525                     Files.write(export, protobufModel);
526                     System.out.println("Onnx model exported to: " + export.toAbsolutePath());
527                 } catch (IOException _) {
528                 }
529             }
530 
531             LOG.log(System.Logger.Level.DEBUG, "Creating ONNX session " + domainName);
532             // cached session must be created under its own auto arena
533             Session session = (options != null) ?
534                     getInstance().createSession(Arena.ofAuto(), protobufModel, options) :
535                     getInstance().createSession(Arena.ofAuto(), protobufModel);
536 
537             return new SessionWithReturnType(
538                     session,
539                     mi.module().functionTable().lastEntry().getValue().invokableSignature().returnType(),
540                     bypassInits ? initValues : List.of());
541 
542 
543         }
544     }
545 
546     public final class Session {
547 
548         private final MemorySegment sessionAddress;
549 
550         private Session(Arena arena, MemorySegment sessionAddress) {
551             this.sessionAddress = sessionAddress.reinterpret(arena,
552                     session -> OrtApi.ReleaseSession.invoke(OrtApi.ReleaseSession(runtimeAddress), session));
553         }
554 
555         public int getNumberOfInputs() {
556             return retInt(OrtApi.SessionGetInputCount.invoke(OrtApi.SessionGetInputCount(runtimeAddress), sessionAddress, ret));
557         }
558 
559         public String getInputName(int inputIndex) {
560             return retString(OrtApi.SessionGetInputName.invoke(OrtApi.SessionGetInputName(runtimeAddress), sessionAddress, inputIndex, defaultAllocatorAddress, ret));
561         }
562 
563         public int getNumberOfOutputs() {
564             return retInt(OrtApi.SessionGetOutputCount.invoke(OrtApi.SessionGetOutputCount(runtimeAddress), sessionAddress, ret));
565         }
566 
567         public String getOutputName(int inputIndex) {
568             return retString(OrtApi.SessionGetOutputName.invoke(OrtApi.SessionGetOutputName(runtimeAddress), sessionAddress, inputIndex, defaultAllocatorAddress, ret));
569         }
570 
571         // @@@ only tensors are supported yet
572         public List<Tensor> run(Arena arena, List<Tensor> inputValues) {
573             var runOptions = MemorySegment.NULL;
574             int inputLen = getNumberOfInputs();
575             int outputLen = getNumberOfOutputs();
576             var inputNames = arena.allocate(C_POINTER, inputLen);
577             var inputs = arena.allocate(C_POINTER, inputLen);
578             long index = 0;
579             for (int i = 0; i < inputLen; i++) {
580                 inputNames.setAtIndex(C_POINTER, index, arena.allocateFrom(getInputName(i)));
581                 inputs.setAtIndex(C_POINTER, index++, inputValues.get(i).tensorAddr);
582             }
583             var outputNames = arena.allocate(C_POINTER, outputLen);
584             var outputs = arena.allocate(C_POINTER, outputLen);
585             for (int i = 0; i < outputLen; i++) {
586                 outputNames.setAtIndex(C_POINTER, i, arena.allocateFrom(getOutputName(i)));
587                 outputs.setAtIndex(C_POINTER, i, MemorySegment.NULL);
588             }
589             checkStatus(OrtApi.Run.invoke(OrtApi.Run(runtimeAddress), sessionAddress, runOptions, inputNames, inputs, (long) inputLen, outputNames, (long) outputLen, outputs));
590             var retArr = new Tensor[outputLen];
591             for (int i = 0; i < outputLen; i++) {
592                 var tensorAddr = outputs.getAtIndex(C_POINTER, i)
593                         .reinterpret(arena, value -> OrtApi.ReleaseValue.invoke(OrtApi.ReleaseValue(runtimeAddress), value));
594                 retArr[i] = new Tensor(tensorData(tensorAddr).reinterpret(arena, null),
595                         tensorAddr);
596             }
597             return List.of(retArr);
598         }
599     }
600 
601     public final class SessionOptions {
602 
603         private final MemorySegment sessionOptionsAddress;
604 
605         boolean bypassInitilizers = false;
606 
607         public SessionOptions(MemorySegment sessionOptionsAddress) {
608             this.sessionOptionsAddress = sessionOptionsAddress;
609             setInterOpNumThreads(1);
610         }
611 
612         public void registerCustomOpsLibrary(MemorySegment path) {
613             checkStatus(OrtApi.RegisterCustomOpsLibrary_V2.invoke(OrtApi.RegisterCustomOpsLibrary_V2(runtimeAddress), sessionOptionsAddress, path));
614         }
615 
616         public void setInterOpNumThreads(int numThreads) {
617             checkStatus(OrtApi.SetInterOpNumThreads.invoke(OrtApi.SetInterOpNumThreads(runtimeAddress), sessionOptionsAddress, numThreads));
618         }
619 
620         public void setIntraOpNumThreads(int numThreads) {
621             checkStatus(OrtApi.SetIntraOpNumThreads.invoke(OrtApi.SetIntraOpNumThreads(runtimeAddress), sessionOptionsAddress, numThreads));
622         }
623 
624         public void setSessionExecutionMode(int executionMode) {
625             checkStatus(OrtApi.SetSessionExecutionMode.invoke(OrtApi.SetSessionExecutionMode(runtimeAddress), sessionOptionsAddress, executionMode));
626         }
627 
628         public MemorySegment getSessionOptionsAddress() {
629             return sessionOptionsAddress;
630         }
631 
632         public void bypassInitizers() {
633             bypassInitilizers = true;
634         }
635     }
636 }