1 /*
2 * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation.
8 *
9 * This code is distributed in the hope that it will be useful, but WITHOUT
10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
12 * version 2 for more details (a copy is included in the LICENSE file that
13 * accompanied this code).
14 *
15 * You should have received a copy of the GNU General Public License version
16 * 2 along with this work; if not, write to the Free Software Foundation,
17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18 *
19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20 * or visit www.oracle.com if you need additional information or have any
21 * questions.
22 */
23
24 package oracle.code.onnx;
25
26 import java.lang.foreign.ValueLayout;
27 import java.util.List;
28 import java.util.Optional;
29 import jdk.incubator.code.Reflect;
30 import org.junit.jupiter.api.Assertions;
31 import org.junit.jupiter.api.Test;
32
33 import static java.util.Optional.empty;
34 import static oracle.code.onnx.OnnxOperators.*;
35 import static oracle.code.onnx.OnnxRuntime.execute;
36
37 public class SimpleTest {
38
39 @Reflect
40 public Tensor<Float> add(Tensor<Float> a, Tensor<Float> b) {
41 return Add(a, b);
42 }
43
44 @Test
45 @Reflect
46 public void testAdd() throws Exception {
47 var a = Tensor.ofFlat(1f, 2, 3);
48 assertEquals(
49 add(a, a),
50 execute(() -> add(a, a)));
51 }
52
53 @Reflect
54 public Tensor<Float> sub(Tensor<Float> a, Tensor<Float> b) {
55 return Sub(a, b);
56 }
57
58 @Test
59 @Reflect
60 public void testSub() throws Exception {
61 var b = Tensor.ofFlat(6f, 5, 4);
62 var a = Tensor.ofFlat(1f, 2, 3);
63 assertEquals(
64 sub(a, b),
65 execute(() -> sub(a, b)));
66 }
67
68 @Reflect
69 public Tensor<Float> fconstant() {
70 return Constant(-1f);
71 }
72
73 @Test
74 @Reflect
75 public void testFconstant() throws Exception {
76 // tests the numbers are encoded correctly
77 var expected = Tensor.ofScalar(-1f);
78 assertEquals(expected, fconstant());
79 assertEquals(expected, execute(() -> fconstant()));
80 }
81
82 @Reflect
83 public Tensor<Float> fconstants() {
84 return Constant(new float[]{-1f, 0, 1, Float.MIN_VALUE, Float.MAX_VALUE});
85 }
86
87 @Test
88 @Reflect
89 public void testFconstants() throws Exception {
90 // tests the numbers are encoded correctly
91 var expected = Tensor.ofFlat(-1f, 0, 1, Float.MIN_VALUE, Float.MAX_VALUE);
92 assertEquals(expected, fconstants());
93 assertEquals(expected, execute(() -> fconstants()));
94 }
95
96 @Reflect
97 public Tensor<Long> lconstant() {
98 return Constant(-1l);
99 }
100
101 @Test
102 @Reflect
103 public void testLconstant() throws Exception {
104 // tests the numbers are encoded correctly
105 var expected = Tensor.ofScalar(-1l);
106 assertEquals(expected, lconstant());
107 assertEquals(expected, execute(() -> lconstant()));
108 }
109
110 @Reflect
111 public Tensor<Long> lconstants() {
112 return Constant(new long[]{-1, 0, 1, Long.MIN_VALUE, Long.MAX_VALUE});
113 }
114
115 @Test
116 @Reflect
117 public void testLconstants() throws Exception {
118 // tests the numbers are encoded correctly
119 var expected = Tensor.ofFlat(-1l, 0, 1, Long.MIN_VALUE, Long.MAX_VALUE);
120 assertEquals(expected, lconstants());
121 assertEquals(expected, execute(() -> lconstants()));
122 }
123
124 @Reflect
125 public Tensor<Long> reshapeAndShape(Tensor<Float> data, Tensor<Long> shape) {
126 return Shape(Reshape(data, shape, empty()), empty(), empty());
127 }
128
129 @Test
130 @Reflect
131 public void testReshapeAndShape() throws Exception {
132 var data = Tensor.ofFlat(1f, 2, 3, 4, 5, 6, 7, 8);
133 var shape = Tensor.ofFlat(2l, 2, 2);
134 assertEquals(
135 reshapeAndShape(data, shape),
136 execute(() -> reshapeAndShape(data, shape)));
137 }
138
139 @Reflect
140 public Tensor<Long> indicesOfMaxPool(Tensor<Float> x) {
141 // testing secondary output
142 return MaxPool(x, empty(), empty(), empty(), empty(), empty(), empty(), new long[]{2}).Indices();
143 }
144
145 @Test
146 @Reflect
147 public void testIndicesOfMaxPool() throws Exception {
148 var x = Tensor.ofShape(new long[]{2, 2, 2}, 1f, 2, 3, 4, 5, 6, 7, 8);
149 assertEquals(
150 indicesOfMaxPool(x),
151 execute(() -> indicesOfMaxPool(x)));
152 }
153
154 @Reflect
155 public Tensor<Float> concat(Tensor<Float> input1, Tensor<Float> input2) {
156 return Concat(List.of(input1, input2), 0);
157 }
158
159 @Test
160 @Reflect
161 public void testConcat() throws Exception {
162 var input1 = Tensor.ofFlat(1f, 2, 3);
163 var input2 = Tensor.ofFlat(4f, 5);
164 assertEquals(
165 concat(input1, input2),
166 execute(()-> concat(input1, input2)));
167 }
168
169 @Reflect
170 public Tensor<Float> split(Tensor<Float> input, Tensor<Long> split) {
171 return Split(input, Optional.of(split), empty(), empty()).get(0);
172 }
173
174 @Test
175 @Reflect
176 public void testSplit() throws Exception {
177 var input = Tensor.ofFlat(1f, 2, 3, 4, 5);
178 var split = Tensor.ofFlat(5l);
179 assertEquals(
180 split(input, split),
181 execute(()-> split(input, split)));
182 }
183
184 @Reflect
185 public Tensor<Float> ifConst(Tensor<Boolean> cond) {
186 return If(cond, () -> List.of(Constant(1f)), () -> List.of(Constant(-1f))).get(0);
187 }
188
189 @Reflect
190 public List<Tensor<Float>> ifConstList(Tensor<Boolean> cond) {
191 return If(cond, () -> List.of(Constant(1f)), () -> List.of(Constant(-1f)));
192 }
193
194 public record SingleValueTuple<T>(T val) {}
195
196 @Reflect
197 public SingleValueTuple<Tensor<Float>> ifConstRecord(Tensor<Boolean> cond) {
198 return If(cond, () -> new SingleValueTuple<>(Constant(1f)), () -> new SingleValueTuple<>(Constant(-1f)));
199 }
200
201 @Test
202 @Reflect
203 public void testIfConst() throws Exception {
204 var condFalse = Tensor.ofScalar(false);
205 var expFalse = Tensor.ofScalar(-1f);
206 var condTrue = Tensor.ofScalar(true);
207 var expTrue = Tensor.ofScalar(1f);
208
209 assertEquals(expFalse, ifConst(condFalse));
210 assertEquals(expFalse, execute(() -> ifConst(condFalse)));
211
212 assertEquals(expTrue, ifConst(condTrue));
213 assertEquals(expTrue, execute(() -> ifConst(condTrue)));
214
215 assertEquals(expFalse, execute(() -> ifConstList(condFalse)).get(0));
216 assertEquals(expTrue, execute(() -> ifConstList(condTrue)).get(0));
217
218 assertEquals(expFalse, execute(() -> ifConstRecord(condFalse)).val());
219 assertEquals(expTrue, execute(() -> ifConstRecord(condTrue)).val());
220 }
221
222 @Reflect
223 public Tensor<Float> ifCapture(Tensor<Boolean> cond, Tensor<Float> trueValue) {
224 var falseValue = Constant(-1f);
225 return If(cond, () -> Identity(trueValue), () -> Identity(falseValue));
226 }
227
228 @Test
229 @Reflect
230 public void testIfCapture() throws Exception {
231 var condFalse = Tensor.ofScalar(false);
232 var expFalse = Tensor.ofScalar(-1f);
233 var condTrue = Tensor.ofScalar(true);
234 var expTrue = Tensor.ofScalar(1f);
235
236 assertEquals(expFalse, ifCapture(condFalse, expTrue));
237 assertEquals(expFalse, execute(() -> ifCapture(condFalse, expTrue)));
238
239 assertEquals(expTrue, ifCapture(condTrue, expTrue));
240 assertEquals(expTrue, execute(() -> ifCapture(condTrue, expTrue)));
241 }
242
243 final Tensor<Float> initialized = Tensor.ofFlat(42f);
244
245 @Reflect
246 public Tensor<Float> initialized() {
247 return Identity(initialized);
248 }
249
250 @Test
251 @Reflect
252 public void testInitialized() throws Exception {
253
254 assertEquals(initialized(),
255 execute(() -> initialized()));
256 }
257
258 final Tensor<Float> initialized2 = Tensor.ofFlat(33f);
259 final Tensor<Float> initialized3 = Tensor.ofFlat(-1f);
260 final Tensor<Float> initialized4 = Tensor.ofFlat(-99f);
261
262 @Reflect
263 public Tensor<Float> ifInitialized(Tensor<Boolean> cond1, Tensor<Boolean> cond2) {
264 return If(cond1,
265 () -> If(cond2,
266 () -> List.of(Identity(initialized)),
267 () -> List.of(Identity(initialized2))),
268 () -> If(cond2,
269 () -> List.of(Identity(initialized3)),
270 () -> List.of(Identity(initialized4)))).get(0);
271 }
272
273 @Test
274 @Reflect
275 public void testIfInitialized() throws Exception {
276 var condFalse = Tensor.ofScalar(false);
277 var condTrue = Tensor.ofScalar(true);
278
279 assertEquals(initialized, ifInitialized(condTrue, condTrue));
280 assertEquals(initialized, execute(() -> ifInitialized(condTrue, condTrue)));
281 assertEquals(initialized2, ifInitialized(condTrue, condFalse));
282 assertEquals(initialized2, execute(() -> ifInitialized(condTrue, condFalse)));
283 assertEquals(initialized3, ifInitialized(condFalse, condTrue));
284 assertEquals(initialized3, execute(() -> ifInitialized(condFalse, condTrue)));
285 assertEquals(initialized4, ifInitialized(condFalse, condFalse));
286 assertEquals(initialized4, execute(() -> ifInitialized(condFalse, condFalse)));
287
288 }
289
290 static final Tensor<Boolean> TRUE = Tensor.ofScalar(true);
291
292 @Reflect
293 public Tensor<Float> forLoopAdd(Tensor<Long> max, Tensor<Float> initialValue) {
294 return Loop(max, TRUE, initialValue, (i, cond, v) -> new LoopResult<>(cond, Add(v, v)));
295 }
296
297 @Reflect
298 public SingleValueTuple<Tensor<Float>> forLoopAddRecord(Tensor<Long> max, Tensor<Float> initialValue) {
299 return Loop(max, TRUE, new SingleValueTuple<>(initialValue), (i, cond, v) -> new LoopResult<>(cond, new SingleValueTuple<>(Add(v.val(), v.val()))));
300 }
301
302 @Test
303 @Reflect
304 public void testForLoopAdd() throws Exception {
305 var expected = Tensor.ofFlat(0f, 8, 16, 24);
306 var value = Tensor.ofFlat(0f, 1, 2, 3);
307 var max = Tensor.ofScalar(3l);
308 assertEquals(expected, forLoopAdd(max, value));
309 assertEquals(expected, execute(() -> forLoopAdd(max, value)));
310 assertEquals(expected, execute(() -> forLoopAddRecord(max, value)).val());
311 }
312
313 public record Tuple(Tensor<Long> a, Tensor<Float> b) {}
314
315 @Reflect
316 public Tuple loop(Tensor<Boolean> b) {
317 var c1 = Constant(1l);
318 var c2 = Constant(1f);
319 var c3 = Constant(4l);
320 return Loop(c3, b, new Tuple(c1, c2), (i, cond, v) -> new LoopResult<>(Identity(cond), new Tuple(Add(v.a(), v.a()), Identity(Add(v.b(), v.b())))));
321 }
322
323 @Test
324 @Reflect
325 public void testLoop() throws Exception {
326 var b = Tensor.ofScalar(true);
327 var res = execute(() -> loop(b));
328 assertEquals(Tensor.ofScalar(16l), res.a());
329 assertEquals(Tensor.ofScalar(16f), res.b());
330 }
331
332 public record ArgRecord(Tensor<Float> a, Tensor<Float> b) {}
333
334 @Reflect
335 public Tensor<Float> recordArgAdd(ArgRecord arg) {
336 return Add(arg.a(), arg.b());
337 }
338
339 @Test
340 @Reflect
341 public void testRecordArgAdd() throws Exception {
342 var arg = new ArgRecord(Tensor.ofFlat(3f), Tensor.ofFlat(4f));
343 assertEquals(recordArgAdd(arg), execute(() -> recordArgAdd(arg)));
344 }
345
346 @Reflect
347 public Tensor<Float> constantArrayArg(Tensor<Float>[] arg) {
348 return Identity(arg[1]);
349 }
350
351 @Test
352 @Reflect
353 public void testConstantArrayArg() throws Exception {
354 Tensor<Float>[] arg = new Tensor[]{Tensor.ofFlat(2f), Tensor.ofFlat(3f)};
355 assertEquals(constantArrayArg(arg), execute(() -> constantArrayArg(arg)));
356 }
357
358 static final Tensor<Float>[] INIT_1_2 = new Tensor[]{Tensor.ofFlat(1f), Tensor.ofFlat(2f)};
359
360
361 @Reflect
362 public Tensor<Float> constantArrayInit() {
363 return Identity(INIT_1_2[1]);
364 }
365
366 @Test
367 @Reflect
368 public void testConstantArrayInit() throws Exception {
369 assertEquals(constantArrayInit(), execute(() -> constantArrayInit()));
370 }
371
372 @Reflect
373 public Tensor<Float>[] constantArrayReturn(Tensor<Float> value) {
374 return new Tensor[]{Identity(value)};
375 }
376
377 @Test
378 @Reflect
379 public void testConstantArrayReturn() throws Exception {
380 Tensor<Float> val = Tensor.ofFlat(3f);
381 assertEquals(constantArrayReturn(val), execute(() -> constantArrayReturn(val)));
382 }
383
384 public record ConstantArrayWrap(Tensor<Float> key, Tensor<Float>[] values) {}
385
386 @Reflect
387 public ConstantArrayWrap constantArrayInRecordReturn(Tensor<Float> key, Tensor<Float> value) {
388 return new ConstantArrayWrap(Identity(key), new Tensor[]{Identity(value)});
389 }
390
391 @Test
392 @Reflect
393 public void testConstantArrayInRecordReturn() throws Exception {
394 Tensor<Float> key = Tensor.ofFlat(1f);
395 Tensor<Float> val = Tensor.ofFlat(3f);
396 assertEquals(constantArrayInRecordReturn(key, val).values(), execute(() -> constantArrayInRecordReturn(key, val)).values());
397 }
398
399 @Reflect
400 public Tensor<Long>[] unrollingConstantArrayReturn() {
401 Tensor<Long>[] ret = new Tensor[5];
402 for (int i = 0; i < 5; i++) {
403 ret[i] = Constant((long)i);
404 }
405 return ret;
406 }
407
408 @Test
409 @Reflect
410 public void testUnrollingConstantArrayReturn() throws Exception {
411 assertEquals(unrollingConstantArrayReturn(), execute(() -> unrollingConstantArrayReturn()));
412 }
413
414 static void assertEquals(Tensor[] expected, Tensor[] actual) {
415 Assertions.assertEquals(expected.length, actual.length);
416 for (int i = 0; i < expected.length; i++) {
417 assertEquals(expected[i], actual[i]);
418 }
419 }
420
421 static void assertEquals(Tensor expected, Tensor actual) {
422
423 var expectedType = expected.elementType();
424 Assertions.assertSame(expectedType, actual.elementType());
425
426 Assertions.assertArrayEquals(expected.shape(), actual.shape());
427
428 switch (expectedType) {
429 case UINT8, INT8, BOOL, UINT4, INT4 ->
430 Assertions.assertArrayEquals(expected.data().toArray(ValueLayout.JAVA_BYTE),
431 actual.data().toArray(ValueLayout.JAVA_BYTE));
432 case UINT16, INT16 ->
433 Assertions.assertArrayEquals(expected.data().toArray(ValueLayout.JAVA_SHORT),
434 actual.data().toArray(ValueLayout.JAVA_SHORT));
435 case INT32, UINT32 ->
436 Assertions.assertArrayEquals(expected.data().toArray(ValueLayout.JAVA_INT),
437 actual.data().toArray(ValueLayout.JAVA_INT));
438 case INT64, UINT64 ->
439 Assertions.assertArrayEquals(expected.data().toArray(ValueLayout.JAVA_LONG),
440 actual.data().toArray(ValueLayout.JAVA_LONG));
441 case STRING ->
442 Assertions.assertEquals(expected.data().getString(0), actual.data().getString(0));
443 case FLOAT ->
444 Assertions.assertArrayEquals(expected.data().toArray(ValueLayout.JAVA_FLOAT),
445 actual.data().toArray(ValueLayout.JAVA_FLOAT));
446 case DOUBLE ->
447 Assertions.assertArrayEquals(expected.data().toArray(ValueLayout.JAVA_DOUBLE),
448 actual.data().toArray(ValueLayout.JAVA_DOUBLE));
449 default ->
450 throw new UnsupportedOperationException("Unsupported tensor element type " + expectedType);
451 }
452 }
453 }