1 /*
2 * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25
26 package oracle.code.onnx;
27
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.util.List;
import java.util.Locale;
import java.util.stream.LongStream;
33
34 /*
35 class DataType(enum.IntEnum):
36 """Enum for the data types of ONNX tensors, defined in ``onnx.TensorProto``."""
37
38 # NOTE: Naming: It is tempting to use shorter and more modern names like f32, i64,
39 # but we should stick to the names used in the ONNX spec for consistency.
40 UNDEFINED = 0
41 FLOAT = 1
42 UINT8 = 2
43 INT8 = 3
44 UINT16 = 4
45 INT16 = 5
46 INT32 = 6
47 INT64 = 7
48 STRING = 8
49 BOOL = 9
50 FLOAT16 = 10
51 DOUBLE = 11
52 UINT32 = 12
53 UINT64 = 13
54 COMPLEX64 = 14
55 COMPLEX128 = 15
56 BFLOAT16 = 16
57 FLOAT8E4M3FN = 17
58 FLOAT8E4M3FNUZ = 18
59 FLOAT8E5M2 = 19
60 FLOAT8E5M2FNUZ = 20
61 UINT4 = 21
62 INT4 = 22
63 FLOAT4E2M1 = 23
64 */
65
66 public class Tensor<T> extends OnnxNumber {
67
68 public static final long[] SCALAR_SHAPE = new long[0];
69
70 public static Tensor<Boolean> ofScalar(boolean b) {
71 return ofScalar(Arena.ofAuto(), b);
72 }
73
74 public static Tensor<Boolean> ofScalar(Arena arena, boolean b) {
75 return new Tensor(arena, arena.allocateFrom(ValueLayout.JAVA_BYTE, b ? (byte)1 : 0), ElementType.BOOL, SCALAR_SHAPE);
76 }
77
78 public static Tensor<Byte> ofScalar(byte b) {
79 return ofShape(SCALAR_SHAPE, b);
80 }
81
82 public static Tensor<Byte> ofScalar(Arena arena, byte b) {
83 return ofShape(arena, SCALAR_SHAPE, b);
84 }
85
86 public static Tensor<Long> ofScalar(long l) {
87 return ofShape(SCALAR_SHAPE, l);
88 }
89
90 public static Tensor<Long> ofScalar(Arena arena, long l) {
91 return ofShape(arena, SCALAR_SHAPE, l);
92 }
93
94 public static Tensor<Float> ofScalar(float f) {
95 return ofShape(SCALAR_SHAPE, f);
96 }
97
98 public static Tensor<Float> ofScalar(Arena arena, float f) {
99 return ofShape(arena, SCALAR_SHAPE, f);
100 }
101
102 public static Tensor<Byte> ofFlat(byte... values) {
103 return ofShape(new long[]{values.length}, values);
104 }
105
106 public static Tensor<Byte> ofFlat(Arena arena, byte... values) {
107 return ofShape(arena, new long[]{values.length}, values);
108 }
109
110 public static Tensor<Integer> ofFlat(int... values) {
111 return ofShape(new long[]{values.length}, values);
112 }
113
114 public static Tensor<Integer> ofFlat(Arena arena, int... values) {
115 return ofShape(arena, new long[]{values.length}, values);
116 }
117
118 public static Tensor<Long> ofFlat(long... values) {
119 return ofShape(new long[]{values.length}, values);
120 }
121
122 public static Tensor<Long> ofFlat(Arena arena, long... values) {
123 return ofShape(arena, new long[]{values.length}, values);
124 }
125
126 public static Tensor<Float> ofFlat(float... values) {
127 return ofShape(new long[]{values.length}, values);
128 }
129
130 public static Tensor<Float> ofFlat(Arena arena, float... values) {
131 return ofShape(arena, new long[]{values.length}, values);
132 }
133
134 public static Tensor<String> ofFlat(String... s) {
135 return ofShape(new long[]{s.length}, s);
136 }
137
138 public static Tensor<String> ofFlat(Arena arena, String... s) {
139 return ofShape(arena, new long[]{s.length}, s);
140 }
141
142 public static Tensor<Byte> ofShape(long[] shape, byte... values) {
143 return ofShape(Arena.ofAuto(), shape, values);
144 }
145
146 public static Tensor<Byte> ofShape(Arena arena, long[] shape, byte... values) {
147 return new Tensor(arena, arena.allocateFrom(ValueLayout.JAVA_BYTE, values), ElementType.UINT8, shape);
148 }
149
150 public static Tensor<Integer> ofShape(long[] shape, int... values) {
151 return ofShape(Arena.ofAuto(), shape, values);
152 }
153
154 public static Tensor<Integer> ofShape(Arena arena, long[] shape, int... values) {
155 return new Tensor(arena, arena.allocateFrom(ValueLayout.JAVA_INT, values), ElementType.INT32, shape);
156 }
157
158 public static Tensor<Long> ofShape(long[] shape, long... values) {
159 return ofShape(Arena.ofAuto(), shape, values);
160 }
161
162 public static Tensor<Long> ofShape(Arena arena, long[] shape, long... values) {
163 return new Tensor(arena, arena.allocateFrom(ValueLayout.JAVA_LONG, values), ElementType.INT64, shape);
164 }
165
166 public static Tensor<Float> ofShape(long[] shape, float... values) {
167 return ofShape(Arena.ofAuto(), shape, values);
168 }
169
170 public static Tensor<Float> ofShape(Arena arena, long[] shape, float... values) {
171 return new Tensor(arena, arena.allocateFrom(ValueLayout.JAVA_FLOAT, values), ElementType.FLOAT, shape);
172 }
173
174 public static <T> Tensor<T> ofShape(long[] shape, byte[] rawData, ElementType elementType) {
175 return ofShape(Arena.ofAuto(), shape, rawData, elementType);
176 }
177
178 public static <T> Tensor<T> ofShape(Arena arena, long[] shape, byte[] rawData, ElementType elementType) {
179 return new Tensor(arena, arena.allocateFrom(ValueLayout.JAVA_BYTE, rawData), elementType, shape);
180 }
181
182 public static Tensor<String> ofShape(long[] shape, String... values) {
183 return ofShape(Arena.ofAuto(), shape, values);
184 }
185
186 public static Tensor<String> ofShape(Arena arena, long[] shape, String... values) {
187 MemorySegment tensorAddr = OnnxRuntime.getInstance().createStringTensor(arena, values, shape);
188 return new Tensor(OnnxRuntime.getInstance().tensorData(tensorAddr), tensorAddr);
189 }
190
191 // Mandatory reference to dataAddr to avoid its garbage colletion
192 private final MemorySegment dataAddr;
193 final MemorySegment tensorAddr;
194
195 public Tensor(Arena arena, MemorySegment dataAddr, ElementType type, long[] shape) {
196 this(dataAddr, OnnxRuntime.getInstance().createTensor(arena, dataAddr, type, shape));
197 }
198
199 Tensor(MemorySegment dataAddr, MemorySegment tensorAddr) {
200 this.dataAddr = dataAddr;
201 this.tensorAddr = tensorAddr;
202 }
203
204 public ElementType elementType() {
205 return OnnxRuntime.getInstance().tensorElementType(tensorAddr);
206 }
207
208 public long[] shape() {
209 return OnnxRuntime.getInstance().tensorShape(tensorAddr);
210 }
211
212 public MemorySegment data() {
213 return dataAddr;
214 }
215
216 public List<String> dataAsStrings() {
217 if (elementType() != ElementType.STRING) {
218 throw new IllegalStateException("Not a String tensor.");
219 }
220 return LongStream.range(0, OnnxRuntime.getInstance().tensorShapeElementCount(tensorAddr))
221 .mapToObj(i -> OnnxRuntime.getInstance().stringTensorElement(tensorAddr, i)).toList();
222 }
223
224 public enum ElementType {
225 FLOAT(1, float.class),
226 UINT8(2, byte.class),
227 INT8(3, byte.class),
228 UINT16(4, short.class),
229 INT16(5, short.class),
230 INT32(6, int.class),
231 INT64(7, long.class),
232 STRING(8, String.class),
233 BOOL(9, boolean.class),
234 FLOAT16(10, Object.class),
235 DOUBLE(11, double.class),
236 UINT32(12, int.class),
237 UINT64(13, long.class),
238 COMPLEX64(14, Object.class),
239 COMPLEX128(15, Object.class),
240 BFLOAT16(16, Object.class),
241 FLOAT8E4M3FN(17, Object.class),
242 FLOAT8E4M3FNUZ(18, Object.class),
243 FLOAT8E5M2(19, Object.class),
244 FLOAT8E5M2FNUZ(20, Object.class),
245 UINT4(21, Object.class),
246 INT4(22, Object.class),
247 FLOAT4E2M1(23, Object.class);
248
249 final int id;
250 final Class<?> type;
251
252 ElementType(int id, Class<?> type) {
253 this.id = id;
254 this.type = type;
255 }
256
257 public Class<?> type() {
258 return type;
259 }
260
261 public String onnxName() {
262 return name().toLowerCase();
263 }
264
265 public int bitSize() {
266 return switch (this) {
267 case INT4, UINT4, FLOAT4E2M1 -> 4;
268 case UINT8, INT8, BOOL, FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ -> 8;
269 case UINT16, INT16, FLOAT16, BFLOAT16 -> 16;
270 case UINT32, INT32, FLOAT -> 32;
271 case UINT64, INT64, DOUBLE, COMPLEX64 -> 64;
272 case COMPLEX128 -> 128;
273 case STRING -> 192; // @@@ a magic number?
274 };
275 }
276
277 public static ElementType fromOnnxName(String name) {
278 return ElementType.valueOf(name.toUpperCase());
279 }
280
281 public static ElementType fromOnnxId(int id) {
282 return values()[id - 1];
283 }
284 }
285 }