1 /*
  2  * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 
 26 package oracle.code.onnx;
 27 
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.util.List;
import java.util.Locale;
import java.util.stream.LongStream;
 33 
 34     /*
 35 class DataType(enum.IntEnum):
 36     """Enum for the data types of ONNX tensors, defined in ``onnx.TensorProto``."""
 37 
 38     # NOTE: Naming: It is tempting to use shorter and more modern names like f32, i64,
 39     # but we should stick to the names used in the ONNX spec for consistency.
 40     UNDEFINED = 0
 41     FLOAT = 1
 42     UINT8 = 2
 43     INT8 = 3
 44     UINT16 = 4
 45     INT16 = 5
 46     INT32 = 6
 47     INT64 = 7
 48     STRING = 8
 49     BOOL = 9
 50     FLOAT16 = 10
 51     DOUBLE = 11
 52     UINT32 = 12
 53     UINT64 = 13
 54     COMPLEX64 = 14
 55     COMPLEX128 = 15
 56     BFLOAT16 = 16
 57     FLOAT8E4M3FN = 17
 58     FLOAT8E4M3FNUZ = 18
 59     FLOAT8E5M2 = 19
 60     FLOAT8E5M2FNUZ = 20
 61     UINT4 = 21
 62     INT4 = 22
 63     FLOAT4E2M1 = 23
 64      */
 65 
 66 public class Tensor<T> extends OnnxNumber {
 67 
 68     public static final long[] SCALAR_SHAPE = new long[0];
 69 
 70     public static Tensor<Boolean> ofScalar(boolean b) {
 71         return ofScalar(Arena.ofAuto(), b);
 72     }
 73 
 74     public static Tensor<Boolean> ofScalar(Arena arena, boolean b) {
 75         return new Tensor(arena, arena.allocateFrom(ValueLayout.JAVA_BYTE, b ? (byte)1 : 0), ElementType.BOOL, SCALAR_SHAPE);
 76     }
 77 
 78     public static Tensor<Byte> ofScalar(byte b) {
 79         return ofShape(SCALAR_SHAPE, b);
 80     }
 81 
 82     public static Tensor<Byte> ofScalar(Arena arena, byte b) {
 83         return ofShape(arena, SCALAR_SHAPE, b);
 84     }
 85 
 86     public static Tensor<Long> ofScalar(long l) {
 87         return ofShape(SCALAR_SHAPE, l);
 88     }
 89 
 90     public static Tensor<Long> ofScalar(Arena arena, long l) {
 91         return ofShape(arena, SCALAR_SHAPE, l);
 92     }
 93 
 94     public static Tensor<Float> ofScalar(float f) {
 95         return ofShape(SCALAR_SHAPE, f);
 96     }
 97 
 98     public static Tensor<Float> ofScalar(Arena arena, float f) {
 99         return ofShape(arena, SCALAR_SHAPE, f);
100     }
101 
102     public static Tensor<Byte> ofFlat(byte... values) {
103         return ofShape(new long[]{values.length}, values);
104     }
105 
106     public static Tensor<Byte> ofFlat(Arena arena, byte... values) {
107         return ofShape(arena, new long[]{values.length}, values);
108     }
109 
110     public static Tensor<Integer> ofFlat(int... values) {
111         return ofShape(new long[]{values.length}, values);
112     }
113 
114     public static Tensor<Integer> ofFlat(Arena arena, int... values) {
115         return ofShape(arena, new long[]{values.length}, values);
116     }
117 
118     public static Tensor<Long> ofFlat(long... values) {
119         return ofShape(new long[]{values.length}, values);
120     }
121 
122     public static Tensor<Long> ofFlat(Arena arena, long... values) {
123         return ofShape(arena, new long[]{values.length}, values);
124     }
125 
126     public static Tensor<Float> ofFlat(float... values) {
127         return ofShape(new long[]{values.length}, values);
128     }
129 
130     public static Tensor<Float> ofFlat(Arena arena, float... values) {
131         return ofShape(arena, new long[]{values.length}, values);
132     }
133 
134     public static Tensor<String> ofFlat(String... s) {
135         return ofShape(new long[]{s.length}, s);
136     }
137 
138     public static Tensor<String> ofFlat(Arena arena, String... s) {
139         return ofShape(arena, new long[]{s.length}, s);
140     }
141 
142     public static Tensor<Byte> ofShape(long[] shape, byte... values) {
143         return ofShape(Arena.ofAuto(), shape, values);
144     }
145 
146     public static Tensor<Byte> ofShape(Arena arena, long[] shape, byte... values) {
147         return new Tensor(arena, arena.allocateFrom(ValueLayout.JAVA_BYTE, values), ElementType.UINT8, shape);
148     }
149 
150     public static Tensor<Integer> ofShape(long[] shape, int... values) {
151         return ofShape(Arena.ofAuto(), shape, values);
152     }
153 
154     public static Tensor<Integer> ofShape(Arena arena, long[] shape, int... values) {
155         return new Tensor(arena, arena.allocateFrom(ValueLayout.JAVA_INT, values), ElementType.INT32, shape);
156     }
157 
158     public static Tensor<Long> ofShape(long[] shape, long... values) {
159         return ofShape(Arena.ofAuto(), shape, values);
160     }
161 
162     public static Tensor<Long> ofShape(Arena arena, long[] shape, long... values) {
163         return new Tensor(arena, arena.allocateFrom(ValueLayout.JAVA_LONG, values), ElementType.INT64, shape);
164     }
165 
166     public static Tensor<Float> ofShape(long[] shape, float... values) {
167         return ofShape(Arena.ofAuto(), shape, values);
168     }
169 
170     public static Tensor<Float> ofShape(Arena arena, long[] shape, float... values) {
171         return new Tensor(arena, arena.allocateFrom(ValueLayout.JAVA_FLOAT, values), ElementType.FLOAT, shape);
172     }
173 
174     public static <T> Tensor<T> ofShape(long[] shape, byte[] rawData, ElementType elementType) {
175         return ofShape(Arena.ofAuto(), shape, rawData, elementType);
176     }
177 
178     public static <T> Tensor<T> ofShape(Arena arena, long[] shape, byte[] rawData, ElementType elementType) {
179         return new Tensor(arena, arena.allocateFrom(ValueLayout.JAVA_BYTE, rawData), elementType, shape);
180     }
181 
182     public static Tensor<String> ofShape(long[] shape, String... values) {
183         return ofShape(Arena.ofAuto(), shape, values);
184     }
185 
186     public static Tensor<String> ofShape(Arena arena, long[] shape, String... values) {
187         MemorySegment tensorAddr = OnnxRuntime.getInstance().createStringTensor(arena, values, shape);
188         return new Tensor(OnnxRuntime.getInstance().tensorData(tensorAddr), tensorAddr);
189     }
190 
191     // Mandatory reference to dataAddr to avoid its garbage colletion
192     private final MemorySegment dataAddr;
193     final MemorySegment tensorAddr;
194 
195     public Tensor(Arena arena, MemorySegment dataAddr, ElementType type, long[] shape) {
196         this(dataAddr, OnnxRuntime.getInstance().createTensor(arena, dataAddr, type, shape));
197     }
198 
199     Tensor(MemorySegment dataAddr, MemorySegment tensorAddr) {
200         this.dataAddr = dataAddr;
201         this.tensorAddr = tensorAddr;
202     }
203 
204     public ElementType elementType() {
205         return OnnxRuntime.getInstance().tensorElementType(tensorAddr);
206     }
207 
208     public long[] shape() {
209         return OnnxRuntime.getInstance().tensorShape(tensorAddr);
210     }
211 
212     public MemorySegment data() {
213         return dataAddr;
214     }
215 
216     public List<String> dataAsStrings() {
217         if (elementType() != ElementType.STRING) {
218             throw new IllegalStateException("Not a String tensor.");
219         }
220         return LongStream.range(0, OnnxRuntime.getInstance().tensorShapeElementCount(tensorAddr))
221                 .mapToObj(i -> OnnxRuntime.getInstance().stringTensorElement(tensorAddr, i)).toList();
222     }
223 
224     public enum ElementType {
225         FLOAT(1, float.class),
226         UINT8(2, byte.class),
227         INT8(3, byte.class),
228         UINT16(4, short.class),
229         INT16(5, short.class),
230         INT32(6, int.class),
231         INT64(7, long.class),
232         STRING(8, String.class),
233         BOOL(9, boolean.class),
234         FLOAT16(10, Object.class),
235         DOUBLE(11, double.class),
236         UINT32(12, int.class),
237         UINT64(13, long.class),
238         COMPLEX64(14, Object.class),
239         COMPLEX128(15, Object.class),
240         BFLOAT16(16, Object.class),
241         FLOAT8E4M3FN(17, Object.class),
242         FLOAT8E4M3FNUZ(18, Object.class),
243         FLOAT8E5M2(19, Object.class),
244         FLOAT8E5M2FNUZ(20, Object.class),
245         UINT4(21, Object.class),
246         INT4(22, Object.class),
247         FLOAT4E2M1(23, Object.class);
248 
249         final int id;
250         final Class<?> type;
251 
252         ElementType(int id, Class<?> type) {
253             this.id = id;
254             this.type = type;
255         }
256 
257         public Class<?> type() {
258             return type;
259         }
260 
261         public String onnxName() {
262             return name().toLowerCase();
263         }
264 
265         public int bitSize() {
266             return switch (this) {
267                 case INT4, UINT4, FLOAT4E2M1 -> 4;
268                 case UINT8, INT8, BOOL, FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ -> 8;
269                 case UINT16, INT16, FLOAT16, BFLOAT16 -> 16;
270                 case UINT32, INT32, FLOAT -> 32;
271                 case UINT64, INT64, DOUBLE, COMPLEX64 -> 64;
272                 case COMPLEX128 -> 128;
273                 case STRING -> 192; // @@@ a magic number?
274             };
275         }
276 
277         public static ElementType fromOnnxName(String name) {
278             return ElementType.valueOf(name.toUpperCase());
279         }
280 
281         public static ElementType fromOnnxId(int id) {
282             return values()[id - 1];
283         }
284     }
285 }