1 /*
2 * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25
26 package oracle.code.onnx;
27
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.util.Locale;
31
32 /*
33 class DataType(enum.IntEnum):
34 """Enum for the data types of ONNX tensors, defined in ``onnx.TensorProto``."""
35
36 # NOTE: Naming: It is tempting to use shorter and more modern names like f32, i64,
37 # but we should stick to the names used in the ONNX spec for consistency.
38 UNDEFINED = 0
39 FLOAT = 1
40 UINT8 = 2
41 INT8 = 3
42 UINT16 = 4
43 INT16 = 5
44 INT32 = 6
45 INT64 = 7
46 STRING = 8
47 BOOL = 9
48 FLOAT16 = 10
49 DOUBLE = 11
50 UINT32 = 12
51 UINT64 = 13
52 COMPLEX64 = 14
53 COMPLEX128 = 15
54 BFLOAT16 = 16
55 FLOAT8E4M3FN = 17
56 FLOAT8E4M3FNUZ = 18
57 FLOAT8E5M2 = 19
58 FLOAT8E5M2FNUZ = 20
59 UINT4 = 21
60 INT4 = 22
61 FLOAT4E2M1 = 23
62 */
63
64 public class Tensor<T> extends OnnxNumber {
65
66 public static final long[] SCALAR_SHAPE = new long[0];
67
68 public static Tensor<Boolean> ofScalar(boolean b) {
69 return ofScalar(Arena.ofAuto(), b);
70 }
71
72 public static Tensor<Boolean> ofScalar(Arena arena, boolean b) {
73 return new Tensor(arena, arena.allocateFrom(ValueLayout.JAVA_BYTE, b ? (byte)1 : 0), ElementType.BOOL, SCALAR_SHAPE);
74 }
75
76 public static Tensor<Byte> ofScalar(byte b) {
77 return ofShape(SCALAR_SHAPE, b);
78 }
79
80 public static Tensor<Byte> ofScalar(Arena arena, byte b) {
81 return ofShape(arena, SCALAR_SHAPE, b);
82 }
83
84 public static Tensor<Long> ofScalar(long l) {
85 return ofShape(SCALAR_SHAPE, l);
86 }
87
88 public static Tensor<Long> ofScalar(Arena arena, long l) {
89 return ofShape(arena, SCALAR_SHAPE, l);
90 }
91
92 public static Tensor<Float> ofScalar(float f) {
93 return ofShape(SCALAR_SHAPE, f);
94 }
95
96 public static Tensor<Float> ofScalar(Arena arena, float f) {
97 return ofShape(arena, SCALAR_SHAPE, f);
98 }
99
100
101 public static Tensor<Byte> ofFlat(byte... values) {
102 return ofShape(new long[]{values.length}, values);
103 }
104
105 public static Tensor<Byte> ofFlat(Arena arena, byte... values) {
106 return ofShape(arena, new long[]{values.length}, values);
107 }
108
109 public static Tensor<Long> ofFlat(long... values) {
110 return ofShape(new long[]{values.length}, values);
111 }
112
113 public static Tensor<Long> ofFlat(Arena arena, long... values) {
114 return ofShape(arena, new long[]{values.length}, values);
115 }
116
117 public static Tensor<Float> ofFlat(float... values) {
118 return ofShape(new long[]{values.length}, values);
119 }
120
121 public static Tensor<Float> ofFlat(Arena arena, float... values) {
122 return ofShape(arena, new long[]{values.length}, values);
123 }
124
125 public static Tensor<Byte> ofShape(long[] shape, byte... values) {
126 return ofShape(Arena.ofAuto(), shape, values);
127 }
128
129 public static Tensor<Byte> ofShape(Arena arena, long[] shape, byte... values) {
130 return new Tensor(arena, arena.allocateFrom(ValueLayout.JAVA_BYTE, values), ElementType.UINT8, shape);
131 }
132
133 public static Tensor<Long> ofShape(long[] shape, long... values) {
134 return ofShape(Arena.ofAuto(), shape, values);
135 }
136
137 public static Tensor<Long> ofShape(Arena arena, long[] shape, long... values) {
138 return new Tensor(arena, arena.allocateFrom(ValueLayout.JAVA_LONG, values), ElementType.INT64, shape);
139 }
140
141 public static Tensor<Float> ofShape(long[] shape, float... values) {
142 return ofShape(Arena.ofAuto(), shape, values);
143 }
144
145 public static Tensor<Float> ofShape(Arena arena, long[] shape, float... values) {
146 return new Tensor(arena, arena.allocateFrom(ValueLayout.JAVA_FLOAT, values), ElementType.FLOAT, shape);
147 }
148
149 public static <T> Tensor<T> ofShape(long[] shape, byte[] rawData, ElementType elementType) {
150 return ofShape(Arena.ofAuto(), shape, rawData, elementType);
151 }
152
153 public static <T> Tensor<T> ofShape(Arena arena, long[] shape, byte[] rawData, ElementType elementType) {
154 return new Tensor(arena, arena.allocateFrom(ValueLayout.JAVA_BYTE, rawData), elementType, shape);
155 }
156
157 // Mandatory reference to dataAddr to avoid its garbage colletion
158 private final MemorySegment dataAddr;
159 final MemorySegment tensorAddr;
160
161 public Tensor(Arena arena, MemorySegment dataAddr, ElementType type, long[] shape) {
162 this(dataAddr, OnnxRuntime.getInstance().createTensor(arena, dataAddr, type, shape));
163 }
164
165 Tensor(MemorySegment dataAddr, MemorySegment tensorAddr) {
166 this.dataAddr = dataAddr;
167 this.tensorAddr = tensorAddr;
168 }
169
170 public ElementType elementType() {
171 return OnnxRuntime.getInstance().tensorElementType(tensorAddr);
172 }
173
174 public long[] shape() {
175 return OnnxRuntime.getInstance().tensorShape(tensorAddr);
176 }
177
178 public MemorySegment data() {
179 return dataAddr;
180 }
181
182 public enum ElementType {
183 FLOAT(1, float.class),
184 UINT8(2, byte.class),
185 INT8(3, byte.class),
186 UINT16(4, short.class),
187 INT16(5, short.class),
188 INT32(6, int.class),
189 INT64(7, long.class),
190 STRING(8, String.class),
191 BOOL(9, boolean.class),
192 FLOAT16(10, Object.class),
193 DOUBLE(11, double.class),
194 UINT32(12, int.class),
195 UINT64(13, long.class),
196 COMPLEX64(14, Object.class),
197 COMPLEX128(15, Object.class),
198 BFLOAT16(16, Object.class),
199 FLOAT8E4M3FN(17, Object.class),
200 FLOAT8E4M3FNUZ(18, Object.class),
201 FLOAT8E5M2(19, Object.class),
202 FLOAT8E5M2FNUZ(20, Object.class),
203 UINT4(21, Object.class),
204 INT4(22, Object.class),
205 FLOAT4E2M1(23, Object.class);
206
207 final int id;
208 final Class<?> type;
209
210 ElementType(int id, Class<?> type) {
211 this.id = id;
212 this.type = type;
213 }
214
215 public Class<?> type() {
216 return type;
217 }
218
219 public String onnxName() {
220 return name().toLowerCase();
221 }
222
223 public int bitSize() {
224 return switch (this) {
225 case INT4, UINT4, FLOAT4E2M1 -> 4;
226 case UINT8, INT8, BOOL, FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ -> 8;
227 case UINT16, INT16, FLOAT16, BFLOAT16 -> 16;
228 case UINT32, INT32, FLOAT -> 32;
229 case UINT64, INT64, DOUBLE, COMPLEX64 -> 64;
230 case COMPLEX128 -> 128;
231 case STRING -> -1;
232 };
233 }
234
235 public static ElementType fromOnnxName(String name) {
236 return ElementType.valueOf(name.toUpperCase());
237 }
238
239 public static ElementType fromOnnxId(int id) {
240 return values()[id - 1];
241 }
242 }
243 }