1 /*
  2  * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 
 26 package oracle.code.onnx;
 27 
 28 import java.lang.foreign.ValueLayout;
 29 import java.util.List;
 30 import java.util.Optional;
 31 import jdk.incubator.code.Quotable;
 32 import oracle.code.onnx.ir.OnnxOps;
 33 
 34 class ExplicitOnnxOperators {
 35 
 36     // Explicit constant operators
 37 
 38     public static Tensor<Long> Constant(
 39             Long c) {
 40         return OnnxOperators.Constant(
 41                 Optional.of(c),Optional.empty(), Optional.empty(), Optional.empty(),
 42                 Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty());
 43     }
 44 
 45     public static Tensor<Long> Constant(
 46             long[] c) {
 47         return OnnxOperators.Constant(
 48                 Optional.empty(),Optional.empty(), Optional.empty(), Optional.empty(),
 49                 Optional.empty(), Optional.of(c), Optional.empty(), Optional.empty());
 50     }
 51 
 52     public static Tensor<Float> Constant(
 53             Float c) {
 54         return OnnxOperators.Constant(
 55                 Optional.empty(),Optional.empty(), Optional.empty(), Optional.of(c),
 56                 Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty());
 57     }
 58 
 59     public static Tensor<Float> Constant(
 60             float[] c) {
 61         return OnnxOperators.Constant(
 62                 Optional.empty(),Optional.of(c), Optional.empty(), Optional.empty(),
 63                 Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty());
 64     }
 65 
 66     public static Tensor<Integer> Constant(
 67             String c) {
 68         return OnnxOperators.Constant(
 69                 Optional.empty(),Optional.empty(), Optional.empty(), Optional.empty(),
 70                 Optional.of(c), Optional.empty(), Optional.empty(), Optional.empty());
 71     }
 72 
 73     public static Tensor<Integer> Constant(
 74             String[] c) {
 75         return OnnxOperators.Constant(
 76                 Optional.empty(),Optional.empty(), Optional.of(c), Optional.empty(),
 77                 Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty());
 78     }
 79 
 80     // @@@ Constants for value - TENSOR and sparse_value - SPARSE_TENSOR
 81 
 82 
 83     public interface IfBody<T> extends Quotable {
 84         T invoke();
 85     }
 86 
 87     public static <T> T If(Tensor<Boolean> cond, IfBody<T> thenBody, IfBody<T> elseBody) {
 88         return booleanValue(cond) ? thenBody.invoke() : elseBody.invoke();
 89     }
 90 
 91     public record LoopResult<T>(Tensor<Boolean> cond, T output) {}
 92     public interface LoopBody<T> extends Quotable {
 93         LoopResult<T> invoke(Tensor<Long> i, Tensor<Boolean> cond, T input);
 94     }
 95 
 96     public static <T> T Loop(Tensor<Long> max, Tensor<Boolean> cond, T values, LoopBody<T> loopBody) {
 97         long m = max.data().get(ValueLayout.JAVA_LONG, 0);
 98         for (var i = Tensor.ofScalar(0l); longValue(i) < m && booleanValue(cond); set(i, longValue(i) + 1)) {
 99             LoopResult<T> ret = loopBody.invoke(i, cond, values);
100             cond = ret.cond();
101             values = ret.output();
102         }
103         return values;
104     }
105 
106     // @@@ this should be generated from contrib operators
107 
108     public record GroupQueryAttention<T>(Tensor<T> output, Tensor<T> present_key, Tensor<T> present_value) { }
109     public static <T, M> GroupQueryAttention<T> GroupQueryAttention(Tensor<T> query, java.util.Optional<Tensor<T>> key, java.util.Optional<Tensor<T>> value, java.util.Optional<Tensor<T>> past_key, java.util.Optional<Tensor<T>> past_value, Tensor<M> seqlens_k, Tensor<M> total_sequence_length, java.util.Optional<Tensor<T>> cos_cache, java.util.Optional<Tensor<T>> sin_cache, java.util.Optional<Long> do_rotary, long kv_num_heads, java.util.Optional<Long> local_window_size, long num_heads, java.util.Optional<Long> rotary_interleaved, java.util.Optional<Float> scale) {
110         Object result = OnnxInterpreter.interpret(OnnxOps.GroupQueryAttention.class, List.of(query, key, value, past_key, past_value, seqlens_k, total_sequence_length, cos_cache, sin_cache), List.of(do_rotary, kv_num_heads, local_window_size, num_heads, rotary_interleaved, scale));
111         Object[] resultArray = (Object[]) result;
112         return new GroupQueryAttention<>((Tensor<T>)resultArray[0], (Tensor<T>)resultArray[1], (Tensor<T>)resultArray[2]);
113     }
114 
115     public static <T1, T2, T3, T4> Tensor<T1> MatMulNBits(Tensor<T1> a, Tensor<T2> b, Tensor<T1> scales, java.util.Optional<Tensor<T3>> zero_points, java.util.Optional<Tensor<T4>> g_idx, java.util.Optional<Tensor<T1>> bias, long K, long N, java.util.Optional<Long> accuracy_level, long bits, long block_size) {
116         Object result = OnnxInterpreter.interpret(OnnxOps.MatMulNBits.class, List.of(a, b, scales, zero_points, g_idx, bias), List.of(K, N, accuracy_level, bits, block_size));
117         return (Tensor<T1>)result;
118     }
119 
120     public record SkipSimplifiedLayerNormalization<T>(Tensor<T> output, Tensor<Float> mean, Tensor<Float> inv_std_var, Tensor<Float> input_skip_bias_sum) { }
121     public static <T> SkipSimplifiedLayerNormalization<T> SkipSimplifiedLayerNormalization(Tensor<T> input, Tensor<T> skip, Tensor<T> gamma, java.util.Optional<Tensor<T>> bias, java.util.Optional<Float> epsilon) {
122         Object result = OnnxInterpreter.interpret(OnnxOps.SkipSimplifiedLayerNormalization.class, List.of(input, skip, gamma, bias), List.of(epsilon));
123         Object[] resultArray = (Object[]) result;
124         return new SkipSimplifiedLayerNormalization<>((Tensor<T>)resultArray[0], (Tensor<Float>)resultArray[1], (Tensor<Float>)resultArray[2], (Tensor<Float>)resultArray[3]);
125     }
126 
127     // @@@ move to Tensor API
128 
129     private static boolean booleanValue(Tensor<Boolean> t) {
130         return t.data().get(ValueLayout.JAVA_BOOLEAN, 0);
131     }
132 
133     private static long longValue(Tensor<Long> t) {
134         return t.data().get(ValueLayout.JAVA_LONG, 0);
135     }
136 
137     private static void set(Tensor<Long> t, long value) {
138         t.data().set(ValueLayout.JAVA_LONG, 0, value);
139     }
140 }