1 /*
2 * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation.
8 *
9 * This code is distributed in the hope that it will be useful, but WITHOUT
10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
12 * version 2 for more details (a copy is included in the LICENSE file that
13 * accompanied this code).
14 *
15 * You should have received a copy of the GNU General Public License version
16 * 2 along with this work; if not, write to the Free Software Foundation,
17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18 *
19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20 * or visit www.oracle.com if you need additional information or have any
21 * questions.
22 */
23
24 package oracle.code.onnx;
25
26 import java.lang.foreign.ValueLayout;
27 import java.util.List;
28 import java.util.Optional;
29 import jdk.incubator.code.CodeReflection;
30 import org.junit.jupiter.api.Assertions;
31 import org.junit.jupiter.api.Test;
32
33 import static java.util.Optional.empty;
34 import static oracle.code.onnx.OnnxOperators.*;
35 import static oracle.code.onnx.OnnxRuntime.execute;
36
37 public class SimpleTest {
38
39 @CodeReflection
40 public Tensor<Float> add(Tensor<Float> a, Tensor<Float> b) {
41 return Add(a, b);
42 }
43
44 @Test
45 public void testAdd() throws Exception {
46 var a = Tensor.ofFlat(1f, 2, 3);
47 assertEquals(
48 add(a, a),
49 execute(() -> add(a, a)));
50 }
51
52 @CodeReflection
53 public Tensor<Float> sub(Tensor<Float> a, Tensor<Float> b) {
54 return Sub(a, b);
55 }
56
57 @Test
58 public void testSub() throws Exception {
59 var b = Tensor.ofFlat(6f, 5, 4);
60 var a = Tensor.ofFlat(1f, 2, 3);
61 assertEquals(
62 sub(a, b),
63 execute(() -> sub(a, b)));
64 }
65
66 @CodeReflection
67 public Tensor<Float> fconstant() {
68 return Constant(-1f);
69 }
70
71 @Test
72 public void testFconstant() throws Exception {
73 // tests the numbers are encoded correctly
74 var expected = Tensor.ofScalar(-1f);
75 assertEquals(expected, fconstant());
76 assertEquals(expected, execute(() -> fconstant()));
77 }
78
79 @CodeReflection
80 public Tensor<Float> fconstants() {
81 return Constant(new float[]{-1f, 0, 1, Float.MIN_VALUE, Float.MAX_VALUE});
82 }
83
84 @Test
85 public void testFconstants() throws Exception {
86 // tests the numbers are encoded correctly
87 var expected = Tensor.ofFlat(-1f, 0, 1, Float.MIN_VALUE, Float.MAX_VALUE);
88 assertEquals(expected, fconstants());
89 assertEquals(expected, execute(() -> fconstants()));
90 }
91
92 @CodeReflection
93 public Tensor<Long> lconstant() {
94 return Constant(-1l);
95 }
96
97 @Test
98 public void testLconstant() throws Exception {
99 // tests the numbers are encoded correctly
100 var expected = Tensor.ofScalar(-1l);
101 assertEquals(expected, lconstant());
102 assertEquals(expected, execute(() -> lconstant()));
103 }
104
105 @CodeReflection
106 public Tensor<Long> lconstants() {
107 return Constant(new long[]{-1, 0, 1, Long.MIN_VALUE, Long.MAX_VALUE});
108 }
109
110 @Test
111 public void testLconstants() throws Exception {
112 // tests the numbers are encoded correctly
113 var expected = Tensor.ofFlat(-1l, 0, 1, Long.MIN_VALUE, Long.MAX_VALUE);
114 assertEquals(expected, lconstants());
115 assertEquals(expected, execute(() -> lconstants()));
116 }
117
118 @CodeReflection
119 public Tensor<Long> reshapeAndShape(Tensor<Float> data, Tensor<Long> shape) {
120 return Shape(Reshape(data, shape, empty()), empty(), empty());
121 }
122
123 @Test
124 public void testReshapeAndShape() throws Exception {
125 var data = Tensor.ofFlat(1f, 2, 3, 4, 5, 6, 7, 8);
126 var shape = Tensor.ofFlat(2l, 2, 2);
127 assertEquals(
128 reshapeAndShape(data, shape),
129 execute(() -> reshapeAndShape(data, shape)));
130 }
131
132 @CodeReflection
133 public Tensor<Long> indicesOfMaxPool(Tensor<Float> x) {
134 // testing secondary output
135 return MaxPool(x, empty(), empty(), empty(), empty(), empty(), empty(), new long[]{2}).Indices();
136 }
137
138 @Test
139 public void testIndicesOfMaxPool() throws Exception {
140 var x = Tensor.ofShape(new long[]{2, 2, 2}, 1f, 2, 3, 4, 5, 6, 7, 8);
141 assertEquals(
142 indicesOfMaxPool(x),
143 execute(() -> indicesOfMaxPool(x)));
144 }
145
146 @CodeReflection
147 public Tensor<Float> concat(Tensor<Float> input1, Tensor<Float> input2) {
148 return Concat(List.of(input1, input2), 0);
149 }
150
151 @Test
152 public void testConcat() throws Exception {
153 var input1 = Tensor.ofFlat(1f, 2, 3);
154 var input2 = Tensor.ofFlat(4f, 5);
155 assertEquals(
156 concat(input1, input2),
157 execute(()-> concat(input1, input2)));
158 }
159
160 @CodeReflection
161 public Tensor<Float> split(Tensor<Float> input, Tensor<Long> split) {
162 return Split(input, Optional.of(split), empty(), empty()).get(0);
163 }
164
165 @Test
166 public void testSplit() throws Exception {
167 var input = Tensor.ofFlat(1f, 2, 3, 4, 5);
168 var split = Tensor.ofFlat(5l);
169 assertEquals(
170 split(input, split),
171 execute(()-> split(input, split)));
172 }
173
174 @CodeReflection
175 public Tensor<Float> ifConst(Tensor<Boolean> cond) {
176 return If(cond, () -> List.of(Constant(1f)), () -> List.of(Constant(-1f))).get(0);
177 }
178
179 @CodeReflection
180 public List<Tensor<Float>> ifConstList(Tensor<Boolean> cond) {
181 return If(cond, () -> List.of(Constant(1f)), () -> List.of(Constant(-1f)));
182 }
183
184 public record SingleValueTuple<T>(T val) {}
185
186 @CodeReflection
187 public SingleValueTuple<Tensor<Float>> ifConstRecord(Tensor<Boolean> cond) {
188 return If(cond, () -> new SingleValueTuple<>(Constant(1f)), () -> new SingleValueTuple<>(Constant(-1f)));
189 }
190
191 @Test
192 public void testIfConst() throws Exception {
193 var condFalse = Tensor.ofScalar(false);
194 var expFalse = Tensor.ofScalar(-1f);
195 var condTrue = Tensor.ofScalar(true);
196 var expTrue = Tensor.ofScalar(1f);
197
198 assertEquals(expFalse, ifConst(condFalse));
199 assertEquals(expFalse, execute(() -> ifConst(condFalse)));
200
201 assertEquals(expTrue, ifConst(condTrue));
202 assertEquals(expTrue, execute(() -> ifConst(condTrue)));
203
204 assertEquals(expFalse, execute(() -> ifConstList(condFalse)).get(0));
205 assertEquals(expTrue, execute(() -> ifConstList(condTrue)).get(0));
206
207 assertEquals(expFalse, execute(() -> ifConstRecord(condFalse)).val());
208 assertEquals(expTrue, execute(() -> ifConstRecord(condTrue)).val());
209 }
210
211 @CodeReflection
212 public Tensor<Float> ifCapture(Tensor<Boolean> cond, Tensor<Float> trueValue) {
213 var falseValue = Constant(-1f);
214 return If(cond, () -> Identity(trueValue), () -> Identity(falseValue));
215 }
216
217 @Test
218 public void testIfCapture() throws Exception {
219 var condFalse = Tensor.ofScalar(false);
220 var expFalse = Tensor.ofScalar(-1f);
221 var condTrue = Tensor.ofScalar(true);
222 var expTrue = Tensor.ofScalar(1f);
223
224 assertEquals(expFalse, ifCapture(condFalse, expTrue));
225 assertEquals(expFalse, execute(() -> ifCapture(condFalse, expTrue)));
226
227 assertEquals(expTrue, ifCapture(condTrue, expTrue));
228 assertEquals(expTrue, execute(() -> ifCapture(condTrue, expTrue)));
229 }
230
231 final Tensor<Float> initialized = Tensor.ofFlat(42f);
232
233 @CodeReflection
234 public Tensor<Float> initialized() {
235 return Identity(initialized);
236 }
237
238 @Test
239 public void testInitialized() throws Exception {
240
241 assertEquals(initialized(),
242 execute(() -> initialized()));
243 }
244
245 final Tensor<Float> initialized2 = Tensor.ofFlat(33f);
246 final Tensor<Float> initialized3 = Tensor.ofFlat(-1f);
247 final Tensor<Float> initialized4 = Tensor.ofFlat(-99f);
248
249 @CodeReflection
250 public Tensor<Float> ifInitialized(Tensor<Boolean> cond1, Tensor<Boolean> cond2) {
251 return If(cond1,
252 () -> If(cond2,
253 () -> List.of(Identity(initialized)),
254 () -> List.of(Identity(initialized2))),
255 () -> If(cond2,
256 () -> List.of(Identity(initialized3)),
257 () -> List.of(Identity(initialized4)))).get(0);
258 }
259
260 @Test
261 public void testIfInitialized() throws Exception {
262 var condFalse = Tensor.ofScalar(false);
263 var condTrue = Tensor.ofScalar(true);
264
265 assertEquals(initialized, ifInitialized(condTrue, condTrue));
266 assertEquals(initialized, execute(() -> ifInitialized(condTrue, condTrue)));
267 assertEquals(initialized2, ifInitialized(condTrue, condFalse));
268 assertEquals(initialized2, execute(() -> ifInitialized(condTrue, condFalse)));
269 assertEquals(initialized3, ifInitialized(condFalse, condTrue));
270 assertEquals(initialized3, execute(() -> ifInitialized(condFalse, condTrue)));
271 assertEquals(initialized4, ifInitialized(condFalse, condFalse));
272 assertEquals(initialized4, execute(() -> ifInitialized(condFalse, condFalse)));
273
274 }
275
276 static final Tensor<Boolean> TRUE = Tensor.ofScalar(true);
277
278 @CodeReflection
279 public Tensor<Float> forLoopAdd(Tensor<Long> max, Tensor<Float> initialValue) {
280 return Loop(max, TRUE, initialValue, (i, cond, v) -> new LoopResult<>(cond, Add(v, v)));
281 }
282
283 @CodeReflection
284 public SingleValueTuple<Tensor<Float>> forLoopAddRecord(Tensor<Long> max, Tensor<Float> initialValue) {
285 return Loop(max, TRUE, new SingleValueTuple<>(initialValue), (i, cond, v) -> new LoopResult<>(cond, new SingleValueTuple<>(Add(v.val(), v.val()))));
286 }
287
288 @Test
289 public void testForLoopAdd() throws Exception {
290 var expected = Tensor.ofFlat(0f, 8, 16, 24);
291 var value = Tensor.ofFlat(0f, 1, 2, 3);
292 var max = Tensor.ofScalar(3l);
293 assertEquals(expected, forLoopAdd(max, value));
294 assertEquals(expected, execute(() -> forLoopAdd(max, value)));
295 assertEquals(expected, execute(() -> forLoopAddRecord(max, value)).val());
296 }
297
298 public record Tuple(Tensor<Long> a, Tensor<Float> b) {}
299
300 @CodeReflection
301 public Tuple loop(Tensor<Boolean> b) {
302 var c1 = Constant(1l);
303 var c2 = Constant(1f);
304 var c3 = Constant(4l);
305 return Loop(c3, b, new Tuple(c1, c2), (i, cond, v) -> new LoopResult<>(Identity(cond), new Tuple(Add(v.a(), v.a()), Identity(Add(v.b(), v.b())))));
306 }
307
308 @Test
309 public void testLoop() throws Exception {
310 var b = Tensor.ofScalar(true);
311 var res = execute(() -> loop(b));
312 assertEquals(Tensor.ofScalar(16l), res.a());
313 assertEquals(Tensor.ofScalar(16f), res.b());
314 }
315
316 public record ArgRecord(Tensor<Float> a, Tensor<Float> b) {}
317
318 @CodeReflection
319 public Tensor<Float> recordArgAdd(ArgRecord arg) {
320 return Add(arg.a(), arg.b());
321 }
322
323 @Test
324 public void testRecordArgAdd() throws Exception {
325 var arg = new ArgRecord(Tensor.ofFlat(3f), Tensor.ofFlat(4f));
326 assertEquals(recordArgAdd(arg), execute(() -> recordArgAdd(arg)));
327 }
328
329 @CodeReflection
330 public Tensor<Float> constantArrayArg(Tensor<Float>[] arg) {
331 return Identity(arg[1]);
332 }
333
334 @Test
335 public void testConstantArrayArg() throws Exception {
336 Tensor<Float>[] arg = new Tensor[]{Tensor.ofFlat(2f), Tensor.ofFlat(3f)};
337 assertEquals(constantArrayArg(arg), execute(() -> constantArrayArg(arg)));
338 }
339
340 static final Tensor<Float>[] INIT_1_2 = new Tensor[]{Tensor.ofFlat(1f), Tensor.ofFlat(2f)};
341
342
343 @CodeReflection
344 public Tensor<Float> constantArrayInit() {
345 return Identity(INIT_1_2[1]);
346 }
347
348 @Test
349 public void testConstantArrayInit() throws Exception {
350 assertEquals(constantArrayInit(), execute(() -> constantArrayInit()));
351 }
352
353 @CodeReflection
354 public Tensor<Float>[] constantArrayReturn(Tensor<Float> value) {
355 return new Tensor[]{Identity(value)};
356 }
357
358 @Test
359 public void testConstantArrayReturn() throws Exception {
360 Tensor<Float> val = Tensor.ofFlat(3f);
361 assertEquals(constantArrayReturn(val), execute(() -> constantArrayReturn(val)));
362 }
363
364 public record ConstantArrayWrap(Tensor<Float> key, Tensor<Float>[] values) {}
365
366 @CodeReflection
367 public ConstantArrayWrap constantArrayInRecordReturn(Tensor<Float> key, Tensor<Float> value) {
368 return new ConstantArrayWrap(Identity(key), new Tensor[]{Identity(value)});
369 }
370
371 @Test
372 public void testConstantArrayInRecordReturn() throws Exception {
373 Tensor<Float> key = Tensor.ofFlat(1f);
374 Tensor<Float> val = Tensor.ofFlat(3f);
375 assertEquals(constantArrayInRecordReturn(key, val).values(), execute(() -> constantArrayInRecordReturn(key, val)).values());
376 }
377
378 @CodeReflection
379 public Tensor<Long>[] unrollingConstantArrayReturn() {
380 Tensor<Long>[] ret = new Tensor[5];
381 for (int i = 0; i < 5; i++) {
382 ret[i] = Constant((long)i);
383 }
384 return ret;
385 }
386
387 @Test
388 public void testUnrollingConstantArrayReturn() throws Exception {
389 assertEquals(unrollingConstantArrayReturn(), execute(() -> unrollingConstantArrayReturn()));
390 }
391
392 static void assertEquals(Tensor[] expected, Tensor[] actual) {
393 Assertions.assertEquals(expected.length, actual.length);
394 for (int i = 0; i < expected.length; i++) {
395 assertEquals(expected[i], actual[i]);
396 }
397 }
398
399 static void assertEquals(Tensor expected, Tensor actual) {
400
401 var expectedType = expected.elementType();
402 Assertions.assertSame(expectedType, actual.elementType());
403
404 Assertions.assertArrayEquals(expected.shape(), actual.shape());
405
406 switch (expectedType) {
407 case UINT8, INT8, BOOL, UINT4, INT4 ->
408 Assertions.assertArrayEquals(expected.data().toArray(ValueLayout.JAVA_BYTE),
409 actual.data().toArray(ValueLayout.JAVA_BYTE));
410 case UINT16, INT16 ->
411 Assertions.assertArrayEquals(expected.data().toArray(ValueLayout.JAVA_SHORT),
412 actual.data().toArray(ValueLayout.JAVA_SHORT));
413 case INT32, UINT32 ->
414 Assertions.assertArrayEquals(expected.data().toArray(ValueLayout.JAVA_INT),
415 actual.data().toArray(ValueLayout.JAVA_INT));
416 case INT64, UINT64 ->
417 Assertions.assertArrayEquals(expected.data().toArray(ValueLayout.JAVA_LONG),
418 actual.data().toArray(ValueLayout.JAVA_LONG));
419 case STRING ->
420 Assertions.assertEquals(expected.data().getString(0), actual.data().getString(0));
421 case FLOAT ->
422 Assertions.assertArrayEquals(expected.data().toArray(ValueLayout.JAVA_FLOAT),
423 actual.data().toArray(ValueLayout.JAVA_FLOAT));
424 case DOUBLE ->
425 Assertions.assertArrayEquals(expected.data().toArray(ValueLayout.JAVA_DOUBLE),
426 actual.data().toArray(ValueLayout.JAVA_DOUBLE));
427 default ->
428 throw new UnsupportedOperationException("Unsupported tensor element type " + expectedType);
429 }
430 }
431 }