1 /*
2 * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25
26 package oracle.code.onnx;
27
28 import java.lang.foreign.ValueLayout;
29 import java.util.List;
30 import java.util.Optional;
31 import jdk.incubator.code.Quotable;
32 import oracle.code.onnx.ir.OnnxOps;
33
34 class ExplicitOnnxOperators {
35
36 // Explicit constant operators
37
38 public static Tensor<Long> Constant(
39 Long c) {
40 return OnnxOperators.Constant(
41 Optional.of(c),Optional.empty(), Optional.empty(), Optional.empty(),
42 Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty());
43 }
44
45 public static Tensor<Long> Constant(
46 long[] c) {
47 return OnnxOperators.Constant(
48 Optional.empty(),Optional.empty(), Optional.empty(), Optional.empty(),
49 Optional.empty(), Optional.of(c), Optional.empty(), Optional.empty());
50 }
51
52 public static Tensor<Float> Constant(
53 Float c) {
54 return OnnxOperators.Constant(
55 Optional.empty(),Optional.empty(), Optional.empty(), Optional.of(c),
56 Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty());
57 }
58
59 public static Tensor<Float> Constant(
60 float[] c) {
61 return OnnxOperators.Constant(
62 Optional.empty(),Optional.of(c), Optional.empty(), Optional.empty(),
63 Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty());
64 }
65
66 public static Tensor<Integer> Constant(
67 String c) {
68 return OnnxOperators.Constant(
69 Optional.empty(),Optional.empty(), Optional.empty(), Optional.empty(),
70 Optional.of(c), Optional.empty(), Optional.empty(), Optional.empty());
71 }
72
73 public static Tensor<Integer> Constant(
74 String[] c) {
75 return OnnxOperators.Constant(
76 Optional.empty(),Optional.empty(), Optional.of(c), Optional.empty(),
77 Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty());
78 }
79
80 // @@@ Constants for value - TENSOR and sparse_value - SPARSE_TENSOR
81
82
83 public interface IfBody<T> extends Quotable {
84 T invoke();
85 }
86
87 public static <T> T If(Tensor<Boolean> cond, IfBody<T> thenBody, IfBody<T> elseBody) {
88 return booleanValue(cond) ? thenBody.invoke() : elseBody.invoke();
89 }
90
91 public record LoopResult<T>(Tensor<Boolean> cond, T output) {}
92 public interface LoopBody<T> extends Quotable {
93 LoopResult<T> invoke(Tensor<Long> i, Tensor<Boolean> cond, T input);
94 }
95
96 public static <T> T Loop(Tensor<Long> max, Tensor<Boolean> cond, T values, LoopBody<T> loopBody) {
97 long m = max.data().get(ValueLayout.JAVA_LONG, 0);
98 for (var i = Tensor.ofScalar(0l); longValue(i) < m && booleanValue(cond); set(i, longValue(i) + 1)) {
99 LoopResult<T> ret = loopBody.invoke(i, cond, values);
100 cond = ret.cond();
101 values = ret.output();
102 }
103 return values;
104 }
105
106 // @@@ this should be generated from contrib operators
107
108 public record GroupQueryAttention<T>(Tensor<T> output, Tensor<T> present_key, Tensor<T> present_value) { }
109 public static <T, M> GroupQueryAttention<T> GroupQueryAttention(Tensor<T> query, java.util.Optional<Tensor<T>> key, java.util.Optional<Tensor<T>> value, java.util.Optional<Tensor<T>> past_key, java.util.Optional<Tensor<T>> past_value, Tensor<M> seqlens_k, Tensor<M> total_sequence_length, java.util.Optional<Tensor<T>> cos_cache, java.util.Optional<Tensor<T>> sin_cache, java.util.Optional<Long> do_rotary, long kv_num_heads, java.util.Optional<Long> local_window_size, long num_heads, java.util.Optional<Long> rotary_interleaved, java.util.Optional<Float> scale) {
110 Object result = OnnxInterpreter.interpret(OnnxOps.GroupQueryAttention.class, List.of(query, key, value, past_key, past_value, seqlens_k, total_sequence_length, cos_cache, sin_cache), List.of(do_rotary, kv_num_heads, local_window_size, num_heads, rotary_interleaved, scale));
111 Object[] resultArray = (Object[]) result;
112 return new GroupQueryAttention<>((Tensor<T>)resultArray[0], (Tensor<T>)resultArray[1], (Tensor<T>)resultArray[2]);
113 }
114
115 public static <T1, T2, T3, T4> Tensor<T1> MatMulNBits(Tensor<T1> a, Tensor<T2> b, Tensor<T1> scales, java.util.Optional<Tensor<T3>> zero_points, java.util.Optional<Tensor<T4>> g_idx, java.util.Optional<Tensor<T1>> bias, long K, long N, java.util.Optional<Long> accuracy_level, long bits, long block_size) {
116 Object result = OnnxInterpreter.interpret(OnnxOps.MatMulNBits.class, List.of(a, b, scales, zero_points, g_idx, bias), List.of(K, N, accuracy_level, bits, block_size));
117 return (Tensor<T1>)result;
118 }
119
120 public record SkipSimplifiedLayerNormalization<T>(Tensor<T> output, Tensor<Float> mean, Tensor<Float> inv_std_var, Tensor<Float> input_skip_bias_sum) { }
121 public static <T> SkipSimplifiedLayerNormalization<T> SkipSimplifiedLayerNormalization(Tensor<T> input, Tensor<T> skip, Tensor<T> gamma, java.util.Optional<Tensor<T>> bias, java.util.Optional<Float> epsilon) {
122 Object result = OnnxInterpreter.interpret(OnnxOps.SkipSimplifiedLayerNormalization.class, List.of(input, skip, gamma, bias), List.of(epsilon));
123 Object[] resultArray = (Object[]) result;
124 return new SkipSimplifiedLayerNormalization<>((Tensor<T>)resultArray[0], (Tensor<Float>)resultArray[1], (Tensor<Float>)resultArray[2], (Tensor<Float>)resultArray[3]);
125 }
126
127 // @@@ move to Tensor API
128
129 private static boolean booleanValue(Tensor<Boolean> t) {
130 return t.data().get(ValueLayout.JAVA_BOOLEAN, 0);
131 }
132
133 private static long longValue(Tensor<Long> t) {
134 return t.data().get(ValueLayout.JAVA_LONG, 0);
135 }
136
137 private static void set(Tensor<Long> t, long value) {
138 t.data().set(ValueLayout.JAVA_LONG, 0, value);
139 }
140 }