1 /*
  2  * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 package oracle.code.onnx;
 25 
 26 import java.lang.foreign.ValueLayout;
 27 import java.util.List;
 28 import java.util.Optional;
 29 import jdk.incubator.code.CodeReflection;
 30 import org.junit.jupiter.api.Assertions;
 31 import org.junit.jupiter.api.Test;
 32 
 33 import static java.util.Optional.empty;
 34 import static oracle.code.onnx.OnnxOperators.*;
 35 import static oracle.code.onnx.OnnxRuntime.execute;
 36 
 37 public class SimpleTest {
 38 
 39     @CodeReflection
 40     public Tensor<Float> add(Tensor<Float> a, Tensor<Float> b) {
 41         return Add(a, b);
 42     }
 43 
 44     @Test
 45     public void testAdd() throws Exception {
 46         var a = Tensor.ofFlat(1f, 2, 3);
 47         assertEquals(
 48                 add(a, a),
 49                 execute(() -> add(a, a)));
 50     }
 51 
 52     @CodeReflection
 53     public Tensor<Float> sub(Tensor<Float> a, Tensor<Float> b) {
 54         return Sub(a, b);
 55     }
 56 
 57     @Test
 58     public void testSub() throws Exception {
 59         var b = Tensor.ofFlat(6f, 5, 4);
 60         var a = Tensor.ofFlat(1f, 2, 3);
 61         assertEquals(
 62                 sub(a, b),
 63                 execute(() -> sub(a, b)));
 64     }
 65 
 66     @CodeReflection
 67     public Tensor<Float> fconstant() {
 68         return Constant(-1f);
 69     }
 70 
 71     @Test
 72     public void testFconstant() throws Exception {
 73         // tests the numbers are encoded correctly
 74         var expected = Tensor.ofScalar(-1f);
 75         assertEquals(expected, fconstant());
 76         assertEquals(expected, execute(() -> fconstant()));
 77     }
 78 
 79     @CodeReflection
 80     public Tensor<Float> fconstants() {
 81         return Constant(new float[]{-1f, 0, 1, Float.MIN_VALUE, Float.MAX_VALUE});
 82     }
 83 
 84     @Test
 85     public void testFconstants() throws Exception {
 86         // tests the numbers are encoded correctly
 87         var expected = Tensor.ofFlat(-1f, 0, 1, Float.MIN_VALUE, Float.MAX_VALUE);
 88         assertEquals(expected, fconstants());
 89         assertEquals(expected, execute(() -> fconstants()));
 90     }
 91 
 92     @CodeReflection
 93     public Tensor<Long> lconstant() {
 94         return Constant(-1l);
 95     }
 96 
 97     @Test
 98     public void testLconstant() throws Exception {
 99         // tests the numbers are encoded correctly
100         var expected = Tensor.ofScalar(-1l);
101         assertEquals(expected, lconstant());
102         assertEquals(expected, execute(() -> lconstant()));
103     }
104 
105     @CodeReflection
106     public Tensor<Long> lconstants() {
107         return Constant(new long[]{-1, 0, 1, Long.MIN_VALUE, Long.MAX_VALUE});
108     }
109 
110     @Test
111     public void testLconstants() throws Exception {
112         // tests the numbers are encoded correctly
113         var expected = Tensor.ofFlat(-1l, 0, 1, Long.MIN_VALUE, Long.MAX_VALUE);
114         assertEquals(expected, lconstants());
115         assertEquals(expected, execute(() -> lconstants()));
116     }
117 
118     @CodeReflection
119     public Tensor<Long> reshapeAndShape(Tensor<Float> data, Tensor<Long> shape) {
120         return Shape(Reshape(data, shape, empty()), empty(), empty());
121     }
122 
123     @Test
124     public void testReshapeAndShape() throws Exception {
125         var data = Tensor.ofFlat(1f, 2, 3, 4, 5, 6, 7, 8);
126         var shape = Tensor.ofFlat(2l, 2, 2);
127         assertEquals(
128                 reshapeAndShape(data, shape),
129                 execute(() -> reshapeAndShape(data, shape)));
130     }
131 
132     @CodeReflection
133     public Tensor<Long> indicesOfMaxPool(Tensor<Float> x) {
134         // testing secondary output
135         return MaxPool(x, empty(), empty(), empty(), empty(), empty(), empty(),  new long[]{2}).Indices();
136     }
137 
138     @Test
139     public void testIndicesOfMaxPool() throws Exception {
140         var x = Tensor.ofShape(new long[]{2, 2, 2}, 1f, 2, 3, 4, 5, 6, 7, 8);
141         assertEquals(
142                 indicesOfMaxPool(x),
143                 execute(() -> indicesOfMaxPool(x)));
144     }
145 
146     @CodeReflection
147     public Tensor<Float> concat(Tensor<Float> input1, Tensor<Float> input2) {
148         return Concat(List.of(input1, input2), 0);
149     }
150 
151     @Test
152     public void testConcat() throws Exception {
153         var input1 = Tensor.ofFlat(1f, 2, 3);
154         var input2 = Tensor.ofFlat(4f, 5);
155         assertEquals(
156                 concat(input1, input2),
157                 execute(()-> concat(input1, input2)));
158     }
159 
160     @CodeReflection
161     public Tensor<Float> split(Tensor<Float> input, Tensor<Long> split) {
162         return Split(input, Optional.of(split), empty(), empty()).get(0);
163     }
164 
165     @Test
166     public void testSplit() throws Exception {
167         var input = Tensor.ofFlat(1f, 2, 3, 4, 5);
168         var split = Tensor.ofFlat(5l);
169         assertEquals(
170                 split(input, split),
171                 execute(()-> split(input, split)));
172     }
173 
174     @CodeReflection
175     public Tensor<Float> ifConst(Tensor<Boolean> cond) {
176         return If(cond, () -> List.of(Constant(1f)), () -> List.of(Constant(-1f))).get(0);
177     }
178 
179     @CodeReflection
180     public List<Tensor<Float>> ifConstList(Tensor<Boolean> cond) {
181         return If(cond, () -> List.of(Constant(1f)), () -> List.of(Constant(-1f)));
182     }
183 
184     public record SingleValueTuple<T>(T val) {}
185 
186     @CodeReflection
187     public SingleValueTuple<Tensor<Float>> ifConstRecord(Tensor<Boolean> cond) {
188         return If(cond, () -> new SingleValueTuple<>(Constant(1f)), () -> new SingleValueTuple<>(Constant(-1f)));
189     }
190 
191     @Test
192     public void testIfConst() throws Exception {
193         var condFalse = Tensor.ofScalar(false);
194         var expFalse = Tensor.ofScalar(-1f);
195         var condTrue = Tensor.ofScalar(true);
196         var expTrue = Tensor.ofScalar(1f);
197 
198         assertEquals(expFalse, ifConst(condFalse));
199         assertEquals(expFalse, execute(() -> ifConst(condFalse)));
200 
201         assertEquals(expTrue, ifConst(condTrue));
202         assertEquals(expTrue, execute(() -> ifConst(condTrue)));
203 
204         assertEquals(expFalse, execute(() -> ifConstList(condFalse)).get(0));
205         assertEquals(expTrue, execute(() -> ifConstList(condTrue)).get(0));
206 
207         assertEquals(expFalse, execute(() -> ifConstRecord(condFalse)).val());
208         assertEquals(expTrue, execute(() -> ifConstRecord(condTrue)).val());
209     }
210 
211     @CodeReflection
212     public Tensor<Float> ifCapture(Tensor<Boolean> cond, Tensor<Float> trueValue) {
213         var falseValue = Constant(-1f);
214         return If(cond, () -> Identity(trueValue), () -> Identity(falseValue));
215     }
216 
217     @Test
218     public void testIfCapture() throws Exception {
219         var condFalse = Tensor.ofScalar(false);
220         var expFalse = Tensor.ofScalar(-1f);
221         var condTrue = Tensor.ofScalar(true);
222         var expTrue = Tensor.ofScalar(1f);
223 
224         assertEquals(expFalse, ifCapture(condFalse, expTrue));
225         assertEquals(expFalse, execute(() -> ifCapture(condFalse, expTrue)));
226 
227         assertEquals(expTrue, ifCapture(condTrue, expTrue));
228         assertEquals(expTrue, execute(() -> ifCapture(condTrue, expTrue)));
229     }
230 
231     final Tensor<Float> initialized = Tensor.ofFlat(42f);
232 
233     @CodeReflection
234     public Tensor<Float> initialized() {
235         return Identity(initialized);
236     }
237 
238     @Test
239     public void testInitialized() throws Exception {
240 
241         assertEquals(initialized(),
242                      execute(() -> initialized()));
243     }
244 
245     final Tensor<Float> initialized2 = Tensor.ofFlat(33f);
246     final Tensor<Float> initialized3 = Tensor.ofFlat(-1f);
247     final Tensor<Float> initialized4 = Tensor.ofFlat(-99f);
248 
249     @CodeReflection
250     public Tensor<Float> ifInitialized(Tensor<Boolean> cond1, Tensor<Boolean> cond2) {
251         return If(cond1,
252                 () -> If(cond2,
253                         () -> List.of(Identity(initialized)),
254                         () -> List.of(Identity(initialized2))),
255                 () -> If(cond2,
256                         () -> List.of(Identity(initialized3)),
257                         () -> List.of(Identity(initialized4)))).get(0);
258     }
259 
260     @Test
261     public void testIfInitialized() throws Exception {
262         var condFalse = Tensor.ofScalar(false);
263         var condTrue = Tensor.ofScalar(true);
264 
265         assertEquals(initialized, ifInitialized(condTrue, condTrue));
266         assertEquals(initialized, execute(() -> ifInitialized(condTrue, condTrue)));
267         assertEquals(initialized2, ifInitialized(condTrue, condFalse));
268         assertEquals(initialized2, execute(() -> ifInitialized(condTrue, condFalse)));
269         assertEquals(initialized3, ifInitialized(condFalse, condTrue));
270         assertEquals(initialized3, execute(() -> ifInitialized(condFalse, condTrue)));
271         assertEquals(initialized4, ifInitialized(condFalse, condFalse));
272         assertEquals(initialized4, execute(() -> ifInitialized(condFalse, condFalse)));
273 
274     }
275 
276     static final Tensor<Boolean> TRUE = Tensor.ofScalar(true);
277 
278     @CodeReflection
279     public Tensor<Float> forLoopAdd(Tensor<Long> max, Tensor<Float> initialValue) {
280         return Loop(max, TRUE, initialValue, (i, cond, v) -> new LoopResult<>(cond, Add(v, v)));
281     }
282 
283     @CodeReflection
284     public SingleValueTuple<Tensor<Float>> forLoopAddRecord(Tensor<Long> max, Tensor<Float> initialValue) {
285         return Loop(max, TRUE, new SingleValueTuple<>(initialValue), (i, cond, v) -> new LoopResult<>(cond, new SingleValueTuple<>(Add(v.val(), v.val()))));
286     }
287 
288     @Test
289     public void testForLoopAdd() throws Exception {
290         var expected = Tensor.ofFlat(0f, 8, 16, 24);
291         var value = Tensor.ofFlat(0f, 1, 2, 3);
292         var max = Tensor.ofScalar(3l);
293         assertEquals(expected, forLoopAdd(max, value));
294         assertEquals(expected, execute(() -> forLoopAdd(max, value)));
295         assertEquals(expected, execute(() -> forLoopAddRecord(max, value)).val());
296     }
297 
298     public record Tuple(Tensor<Long> a, Tensor<Float> b) {}
299 
300     @CodeReflection
301     public Tuple loop(Tensor<Boolean> b) {
302         var c1 = Constant(1l);
303         var c2 = Constant(1f);
304         var c3 = Constant(4l);
305         return Loop(c3, b, new Tuple(c1, c2), (i, cond, v) -> new LoopResult<>(Identity(cond), new Tuple(Add(v.a(), v.a()), Identity(Add(v.b(), v.b())))));
306     }
307 
308     @Test
309     public void testLoop() throws Exception {
310         var b = Tensor.ofScalar(true);
311         var res = execute(() -> loop(b));
312         assertEquals(Tensor.ofScalar(16l), res.a());
313         assertEquals(Tensor.ofScalar(16f), res.b());
314     }
315 
316     public record ArgRecord(Tensor<Float> a, Tensor<Float> b) {}
317 
318     @CodeReflection
319     public Tensor<Float> recordArgAdd(ArgRecord arg) {
320         return Add(arg.a(), arg.b());
321     }
322 
323     @Test
324     public void testRecordArgAdd() throws Exception {
325         var arg = new ArgRecord(Tensor.ofFlat(3f), Tensor.ofFlat(4f));
326         assertEquals(recordArgAdd(arg), execute(() -> recordArgAdd(arg)));
327     }
328 
329     @CodeReflection
330     public Tensor<Float> constantArrayArg(Tensor<Float>[] arg) {
331         return Identity(arg[1]);
332     }
333 
334     @Test
335     public void testConstantArrayArg() throws Exception {
336         Tensor<Float>[] arg = new Tensor[]{Tensor.ofFlat(2f), Tensor.ofFlat(3f)};
337         assertEquals(constantArrayArg(arg), execute(() -> constantArrayArg(arg)));
338     }
339 
340     static final Tensor<Float>[] INIT_1_2 = new Tensor[]{Tensor.ofFlat(1f), Tensor.ofFlat(2f)};
341 
342 
343     @CodeReflection
344     public Tensor<Float> constantArrayInit() {
345         return Identity(INIT_1_2[1]);
346     }
347 
348     @Test
349     public void testConstantArrayInit() throws Exception {
350         assertEquals(constantArrayInit(), execute(() -> constantArrayInit()));
351     }
352 
353     @CodeReflection
354     public Tensor<Float>[] constantArrayReturn(Tensor<Float> value) {
355         return new Tensor[]{Identity(value)};
356     }
357 
358     @Test
359     public void testConstantArrayReturn() throws Exception {
360         Tensor<Float> val = Tensor.ofFlat(3f);
361         assertEquals(constantArrayReturn(val), execute(() -> constantArrayReturn(val)));
362     }
363 
364     public record ConstantArrayWrap(Tensor<Float> key, Tensor<Float>[] values) {}
365 
366     @CodeReflection
367     public ConstantArrayWrap constantArrayInRecordReturn(Tensor<Float> key, Tensor<Float> value) {
368         return new ConstantArrayWrap(Identity(key), new Tensor[]{Identity(value)});
369     }
370 
371     @Test
372     public void testConstantArrayInRecordReturn() throws Exception {
373         Tensor<Float> key = Tensor.ofFlat(1f);
374         Tensor<Float> val = Tensor.ofFlat(3f);
375         assertEquals(constantArrayInRecordReturn(key, val).values(), execute(() -> constantArrayInRecordReturn(key, val)).values());
376     }
377 
378     @CodeReflection
379     public Tensor<Long>[] unrollingConstantArrayReturn() {
380         Tensor<Long>[] ret = new Tensor[5];
381         for (int i = 0; i < 5; i++) {
382             ret[i] = Constant((long)i);
383         }
384         return ret;
385     }
386 
387     @Test
388     public void testUnrollingConstantArrayReturn() throws Exception {
389         assertEquals(unrollingConstantArrayReturn(), execute(() -> unrollingConstantArrayReturn()));
390     }
391 
392     static void assertEquals(Tensor[] expected, Tensor[] actual) {
393         Assertions.assertEquals(expected.length, actual.length);
394         for (int i = 0; i < expected.length; i++) {
395             assertEquals(expected[i], actual[i]);
396         }
397     }
398 
399     static void assertEquals(Tensor expected, Tensor actual) {
400 
401         var expectedType = expected.elementType();
402         Assertions.assertSame(expectedType, actual.elementType());
403 
404         Assertions.assertArrayEquals(expected.shape(), actual.shape());
405 
406         switch (expectedType) {
407             case UINT8, INT8, BOOL, UINT4, INT4 ->
408                 Assertions.assertArrayEquals(expected.data().toArray(ValueLayout.JAVA_BYTE),
409                                              actual.data().toArray(ValueLayout.JAVA_BYTE));
410             case UINT16, INT16 ->
411                 Assertions.assertArrayEquals(expected.data().toArray(ValueLayout.JAVA_SHORT),
412                                              actual.data().toArray(ValueLayout.JAVA_SHORT));
413             case INT32, UINT32 ->
414                 Assertions.assertArrayEquals(expected.data().toArray(ValueLayout.JAVA_INT),
415                                              actual.data().toArray(ValueLayout.JAVA_INT));
416             case INT64, UINT64 ->
417                 Assertions.assertArrayEquals(expected.data().toArray(ValueLayout.JAVA_LONG),
418                                              actual.data().toArray(ValueLayout.JAVA_LONG));
419             case STRING ->
420                 Assertions.assertEquals(expected.data().getString(0), actual.data().getString(0));
421             case FLOAT ->
422                 Assertions.assertArrayEquals(expected.data().toArray(ValueLayout.JAVA_FLOAT),
423                                              actual.data().toArray(ValueLayout.JAVA_FLOAT));
424             case DOUBLE ->
425                 Assertions.assertArrayEquals(expected.data().toArray(ValueLayout.JAVA_DOUBLE),
426                                              actual.data().toArray(ValueLayout.JAVA_DOUBLE));
427             default ->
428                 throw new UnsupportedOperationException("Unsupported tensor element type " + expectedType);
429         }
430     }
431 }