1 /*
  2  * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 package oracle.code.onnx;
 25 
 26 import java.lang.foreign.ValueLayout;
 27 import java.util.List;
 28 import java.util.Optional;
 29 import jdk.incubator.code.Reflect;
 30 import org.junit.jupiter.api.Assertions;
 31 import org.junit.jupiter.api.Test;
 32 
 33 import static java.util.Optional.empty;
 34 import static oracle.code.onnx.OnnxOperators.*;
 35 import static oracle.code.onnx.OnnxRuntime.execute;
 36 
 37 public class SimpleTest {
 38 
 39     @Reflect
 40     public Tensor<Float> add(Tensor<Float> a, Tensor<Float> b) {
 41         return Add(a, b);
 42     }
 43 
 44     @Test
 45     @Reflect
 46     public void testAdd() throws Exception {
 47         var a = Tensor.ofFlat(1f, 2, 3);
 48         assertEquals(
 49                 add(a, a),
 50                 execute(() -> add(a, a)));
 51     }
 52 
 53     @Reflect
 54     public Tensor<Float> sub(Tensor<Float> a, Tensor<Float> b) {
 55         return Sub(a, b);
 56     }
 57 
 58     @Test
 59     @Reflect
 60     public void testSub() throws Exception {
 61         var b = Tensor.ofFlat(6f, 5, 4);
 62         var a = Tensor.ofFlat(1f, 2, 3);
 63         assertEquals(
 64                 sub(a, b),
 65                 execute(() -> sub(a, b)));
 66     }
 67 
 68     @Reflect
 69     public Tensor<Float> fconstant() {
 70         return Constant(-1f);
 71     }
 72 
 73     @Test
 74     @Reflect
 75     public void testFconstant() throws Exception {
 76         // tests the numbers are encoded correctly
 77         var expected = Tensor.ofScalar(-1f);
 78         assertEquals(expected, fconstant());
 79         assertEquals(expected, execute(() -> fconstant()));
 80     }
 81 
 82     @Reflect
 83     public Tensor<Float> fconstants() {
 84         return Constant(new float[]{-1f, 0, 1, Float.MIN_VALUE, Float.MAX_VALUE});
 85     }
 86 
 87     @Test
 88     @Reflect
 89     public void testFconstants() throws Exception {
 90         // tests the numbers are encoded correctly
 91         var expected = Tensor.ofFlat(-1f, 0, 1, Float.MIN_VALUE, Float.MAX_VALUE);
 92         assertEquals(expected, fconstants());
 93         assertEquals(expected, execute(() -> fconstants()));
 94     }
 95 
 96     @Reflect
 97     public Tensor<Long> lconstant() {
 98         return Constant(-1l);
 99     }
100 
101     @Test
102     @Reflect
103     public void testLconstant() throws Exception {
104         // tests the numbers are encoded correctly
105         var expected = Tensor.ofScalar(-1l);
106         assertEquals(expected, lconstant());
107         assertEquals(expected, execute(() -> lconstant()));
108     }
109 
110     @Reflect
111     public Tensor<Long> lconstants() {
112         return Constant(new long[]{-1, 0, 1, Long.MIN_VALUE, Long.MAX_VALUE});
113     }
114 
115     @Test
116     @Reflect
117     public void testLconstants() throws Exception {
118         // tests the numbers are encoded correctly
119         var expected = Tensor.ofFlat(-1l, 0, 1, Long.MIN_VALUE, Long.MAX_VALUE);
120         assertEquals(expected, lconstants());
121         assertEquals(expected, execute(() -> lconstants()));
122     }
123 
124     @Reflect
125     public Tensor<Long> reshapeAndShape(Tensor<Float> data, Tensor<Long> shape) {
126         return Shape(Reshape(data, shape, empty()), empty(), empty());
127     }
128 
129     @Test
130     @Reflect
131     public void testReshapeAndShape() throws Exception {
132         var data = Tensor.ofFlat(1f, 2, 3, 4, 5, 6, 7, 8);
133         var shape = Tensor.ofFlat(2l, 2, 2);
134         assertEquals(
135                 reshapeAndShape(data, shape),
136                 execute(() -> reshapeAndShape(data, shape)));
137     }
138 
139     @Reflect
140     public Tensor<Long> indicesOfMaxPool(Tensor<Float> x) {
141         // testing secondary output
142         return MaxPool(x, empty(), empty(), empty(), empty(), empty(), empty(),  new long[]{2}).Indices();
143     }
144 
145     @Test
146     @Reflect
147     public void testIndicesOfMaxPool() throws Exception {
148         var x = Tensor.ofShape(new long[]{2, 2, 2}, 1f, 2, 3, 4, 5, 6, 7, 8);
149         assertEquals(
150                 indicesOfMaxPool(x),
151                 execute(() -> indicesOfMaxPool(x)));
152     }
153 
154     @Reflect
155     public Tensor<Float> concat(Tensor<Float> input1, Tensor<Float> input2) {
156         return Concat(List.of(input1, input2), 0);
157     }
158 
159     @Test
160     @Reflect
161     public void testConcat() throws Exception {
162         var input1 = Tensor.ofFlat(1f, 2, 3);
163         var input2 = Tensor.ofFlat(4f, 5);
164         assertEquals(
165                 concat(input1, input2),
166                 execute(()-> concat(input1, input2)));
167     }
168 
169     @Reflect
170     public Tensor<Float> split(Tensor<Float> input, Tensor<Long> split) {
171         return Split(input, Optional.of(split), empty(), empty()).get(0);
172     }
173 
174     @Test
175     @Reflect
176     public void testSplit() throws Exception {
177         var input = Tensor.ofFlat(1f, 2, 3, 4, 5);
178         var split = Tensor.ofFlat(5l);
179         assertEquals(
180                 split(input, split),
181                 execute(()-> split(input, split)));
182     }
183 
184     @Reflect
185     public Tensor<Float> ifConst(Tensor<Boolean> cond) {
186         return If(cond, () -> List.of(Constant(1f)), () -> List.of(Constant(-1f))).get(0);
187     }
188 
189     @Reflect
190     public List<Tensor<Float>> ifConstList(Tensor<Boolean> cond) {
191         return If(cond, () -> List.of(Constant(1f)), () -> List.of(Constant(-1f)));
192     }
193 
194     public record SingleValueTuple<T>(T val) {}
195 
196     @Reflect
197     public SingleValueTuple<Tensor<Float>> ifConstRecord(Tensor<Boolean> cond) {
198         return If(cond, () -> new SingleValueTuple<>(Constant(1f)), () -> new SingleValueTuple<>(Constant(-1f)));
199     }
200 
201     @Test
202     @Reflect
203     public void testIfConst() throws Exception {
204         var condFalse = Tensor.ofScalar(false);
205         var expFalse = Tensor.ofScalar(-1f);
206         var condTrue = Tensor.ofScalar(true);
207         var expTrue = Tensor.ofScalar(1f);
208 
209         assertEquals(expFalse, ifConst(condFalse));
210         assertEquals(expFalse, execute(() -> ifConst(condFalse)));
211 
212         assertEquals(expTrue, ifConst(condTrue));
213         assertEquals(expTrue, execute(() -> ifConst(condTrue)));
214 
215         assertEquals(expFalse, execute(() -> ifConstList(condFalse)).get(0));
216         assertEquals(expTrue, execute(() -> ifConstList(condTrue)).get(0));
217 
218         assertEquals(expFalse, execute(() -> ifConstRecord(condFalse)).val());
219         assertEquals(expTrue, execute(() -> ifConstRecord(condTrue)).val());
220     }
221 
222     @Reflect
223     public Tensor<Float> ifCapture(Tensor<Boolean> cond, Tensor<Float> trueValue) {
224         var falseValue = Constant(-1f);
225         return If(cond, () -> Identity(trueValue), () -> Identity(falseValue));
226     }
227 
228     @Test
229     @Reflect
230     public void testIfCapture() throws Exception {
231         var condFalse = Tensor.ofScalar(false);
232         var expFalse = Tensor.ofScalar(-1f);
233         var condTrue = Tensor.ofScalar(true);
234         var expTrue = Tensor.ofScalar(1f);
235 
236         assertEquals(expFalse, ifCapture(condFalse, expTrue));
237         assertEquals(expFalse, execute(() -> ifCapture(condFalse, expTrue)));
238 
239         assertEquals(expTrue, ifCapture(condTrue, expTrue));
240         assertEquals(expTrue, execute(() -> ifCapture(condTrue, expTrue)));
241     }
242 
243     final Tensor<Float> initialized = Tensor.ofFlat(42f);
244 
245     @Reflect
246     public Tensor<Float> initialized() {
247         return Identity(initialized);
248     }
249 
250     @Test
251     @Reflect
252     public void testInitialized() throws Exception {
253 
254         assertEquals(initialized(),
255                      execute(() -> initialized()));
256     }
257 
258     final Tensor<Float> initialized2 = Tensor.ofFlat(33f);
259     final Tensor<Float> initialized3 = Tensor.ofFlat(-1f);
260     final Tensor<Float> initialized4 = Tensor.ofFlat(-99f);
261 
262     @Reflect
263     public Tensor<Float> ifInitialized(Tensor<Boolean> cond1, Tensor<Boolean> cond2) {
264         return If(cond1,
265                 () -> If(cond2,
266                         () -> List.of(Identity(initialized)),
267                         () -> List.of(Identity(initialized2))),
268                 () -> If(cond2,
269                         () -> List.of(Identity(initialized3)),
270                         () -> List.of(Identity(initialized4)))).get(0);
271     }
272 
273     @Test
274     @Reflect
275     public void testIfInitialized() throws Exception {
276         var condFalse = Tensor.ofScalar(false);
277         var condTrue = Tensor.ofScalar(true);
278 
279         assertEquals(initialized, ifInitialized(condTrue, condTrue));
280         assertEquals(initialized, execute(() -> ifInitialized(condTrue, condTrue)));
281         assertEquals(initialized2, ifInitialized(condTrue, condFalse));
282         assertEquals(initialized2, execute(() -> ifInitialized(condTrue, condFalse)));
283         assertEquals(initialized3, ifInitialized(condFalse, condTrue));
284         assertEquals(initialized3, execute(() -> ifInitialized(condFalse, condTrue)));
285         assertEquals(initialized4, ifInitialized(condFalse, condFalse));
286         assertEquals(initialized4, execute(() -> ifInitialized(condFalse, condFalse)));
287 
288     }
289 
290     static final Tensor<Boolean> TRUE = Tensor.ofScalar(true);
291 
292     @Reflect
293     public Tensor<Float> forLoopAdd(Tensor<Long> max, Tensor<Float> initialValue) {
294         return Loop(max, TRUE, initialValue, (i, cond, v) -> new LoopResult<>(cond, Add(v, v)));
295     }
296 
297     @Reflect
298     public SingleValueTuple<Tensor<Float>> forLoopAddRecord(Tensor<Long> max, Tensor<Float> initialValue) {
299         return Loop(max, TRUE, new SingleValueTuple<>(initialValue), (i, cond, v) -> new LoopResult<>(cond, new SingleValueTuple<>(Add(v.val(), v.val()))));
300     }
301 
302     @Test
303     @Reflect
304     public void testForLoopAdd() throws Exception {
305         var expected = Tensor.ofFlat(0f, 8, 16, 24);
306         var value = Tensor.ofFlat(0f, 1, 2, 3);
307         var max = Tensor.ofScalar(3l);
308         assertEquals(expected, forLoopAdd(max, value));
309         assertEquals(expected, execute(() -> forLoopAdd(max, value)));
310         assertEquals(expected, execute(() -> forLoopAddRecord(max, value)).val());
311     }
312 
313     public record Tuple(Tensor<Long> a, Tensor<Float> b) {}
314 
315     @Reflect
316     public Tuple loop(Tensor<Boolean> b) {
317         var c1 = Constant(1l);
318         var c2 = Constant(1f);
319         var c3 = Constant(4l);
320         return Loop(c3, b, new Tuple(c1, c2), (i, cond, v) -> new LoopResult<>(Identity(cond), new Tuple(Add(v.a(), v.a()), Identity(Add(v.b(), v.b())))));
321     }
322 
323     @Test
324     @Reflect
325     public void testLoop() throws Exception {
326         var b = Tensor.ofScalar(true);
327         var res = execute(() -> loop(b));
328         assertEquals(Tensor.ofScalar(16l), res.a());
329         assertEquals(Tensor.ofScalar(16f), res.b());
330     }
331 
332     public record ArgRecord(Tensor<Float> a, Tensor<Float> b) {}
333 
334     @Reflect
335     public Tensor<Float> recordArgAdd(ArgRecord arg) {
336         return Add(arg.a(), arg.b());
337     }
338 
339     @Test
340     @Reflect
341     public void testRecordArgAdd() throws Exception {
342         var arg = new ArgRecord(Tensor.ofFlat(3f), Tensor.ofFlat(4f));
343         assertEquals(recordArgAdd(arg), execute(() -> recordArgAdd(arg)));
344     }
345 
346     @Reflect
347     public Tensor<Float> constantArrayArg(Tensor<Float>[] arg) {
348         return Identity(arg[1]);
349     }
350 
351     @Test
352     @Reflect
353     public void testConstantArrayArg() throws Exception {
354         Tensor<Float>[] arg = new Tensor[]{Tensor.ofFlat(2f), Tensor.ofFlat(3f)};
355         assertEquals(constantArrayArg(arg), execute(() -> constantArrayArg(arg)));
356     }
357 
358     static final Tensor<Float>[] INIT_1_2 = new Tensor[]{Tensor.ofFlat(1f), Tensor.ofFlat(2f)};
359 
360 
361     @Reflect
362     public Tensor<Float> constantArrayInit() {
363         return Identity(INIT_1_2[1]);
364     }
365 
366     @Test
367     @Reflect
368     public void testConstantArrayInit() throws Exception {
369         assertEquals(constantArrayInit(), execute(() -> constantArrayInit()));
370     }
371 
372     @Reflect
373     public Tensor<Float>[] constantArrayReturn(Tensor<Float> value) {
374         return new Tensor[]{Identity(value)};
375     }
376 
377     @Test
378     @Reflect
379     public void testConstantArrayReturn() throws Exception {
380         Tensor<Float> val = Tensor.ofFlat(3f);
381         assertEquals(constantArrayReturn(val), execute(() -> constantArrayReturn(val)));
382     }
383 
384     public record ConstantArrayWrap(Tensor<Float> key, Tensor<Float>[] values) {}
385 
386     @Reflect
387     public ConstantArrayWrap constantArrayInRecordReturn(Tensor<Float> key, Tensor<Float> value) {
388         return new ConstantArrayWrap(Identity(key), new Tensor[]{Identity(value)});
389     }
390 
391     @Test
392     @Reflect
393     public void testConstantArrayInRecordReturn() throws Exception {
394         Tensor<Float> key = Tensor.ofFlat(1f);
395         Tensor<Float> val = Tensor.ofFlat(3f);
396         assertEquals(constantArrayInRecordReturn(key, val).values(), execute(() -> constantArrayInRecordReturn(key, val)).values());
397     }
398 
399     @Reflect
400     public Tensor<Long>[] unrollingConstantArrayReturn() {
401         Tensor<Long>[] ret = new Tensor[5];
402         for (int i = 0; i < 5; i++) {
403             ret[i] = Constant((long)i);
404         }
405         return ret;
406     }
407 
408     @Test
409     @Reflect
410     public void testUnrollingConstantArrayReturn() throws Exception {
411         assertEquals(unrollingConstantArrayReturn(), execute(() -> unrollingConstantArrayReturn()));
412     }
413 
414     static void assertEquals(Tensor[] expected, Tensor[] actual) {
415         Assertions.assertEquals(expected.length, actual.length);
416         for (int i = 0; i < expected.length; i++) {
417             assertEquals(expected[i], actual[i]);
418         }
419     }
420 
421     static void assertEquals(Tensor expected, Tensor actual) {
422 
423         var expectedType = expected.elementType();
424         Assertions.assertSame(expectedType, actual.elementType());
425 
426         Assertions.assertArrayEquals(expected.shape(), actual.shape());
427 
428         switch (expectedType) {
429             case UINT8, INT8, BOOL, UINT4, INT4 ->
430                 Assertions.assertArrayEquals(expected.data().toArray(ValueLayout.JAVA_BYTE),
431                                              actual.data().toArray(ValueLayout.JAVA_BYTE));
432             case UINT16, INT16 ->
433                 Assertions.assertArrayEquals(expected.data().toArray(ValueLayout.JAVA_SHORT),
434                                              actual.data().toArray(ValueLayout.JAVA_SHORT));
435             case INT32, UINT32 ->
436                 Assertions.assertArrayEquals(expected.data().toArray(ValueLayout.JAVA_INT),
437                                              actual.data().toArray(ValueLayout.JAVA_INT));
438             case INT64, UINT64 ->
439                 Assertions.assertArrayEquals(expected.data().toArray(ValueLayout.JAVA_LONG),
440                                              actual.data().toArray(ValueLayout.JAVA_LONG));
441             case STRING ->
442                 Assertions.assertEquals(expected.data().getString(0), actual.data().getString(0));
443             case FLOAT ->
444                 Assertions.assertArrayEquals(expected.data().toArray(ValueLayout.JAVA_FLOAT),
445                                              actual.data().toArray(ValueLayout.JAVA_FLOAT));
446             case DOUBLE ->
447                 Assertions.assertArrayEquals(expected.data().toArray(ValueLayout.JAVA_DOUBLE),
448                                              actual.data().toArray(ValueLayout.JAVA_DOUBLE));
449             default ->
450                 throw new UnsupportedOperationException("Unsupported tensor element type " + expectedType);
451         }
452     }
453 }