1 /*
  2  * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 
 26 package oracle.code.onnx;
 27 
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.util.Locale;
 31 
 32     /*
 33 class DataType(enum.IntEnum):
 34     """Enum for the data types of ONNX tensors, defined in ``onnx.TensorProto``."""
 35 
 36     # NOTE: Naming: It is tempting to use shorter and more modern names like f32, i64,
 37     # but we should stick to the names used in the ONNX spec for consistency.
 38     UNDEFINED = 0
 39     FLOAT = 1
 40     UINT8 = 2
 41     INT8 = 3
 42     UINT16 = 4
 43     INT16 = 5
 44     INT32 = 6
 45     INT64 = 7
 46     STRING = 8
 47     BOOL = 9
 48     FLOAT16 = 10
 49     DOUBLE = 11
 50     UINT32 = 12
 51     UINT64 = 13
 52     COMPLEX64 = 14
 53     COMPLEX128 = 15
 54     BFLOAT16 = 16
 55     FLOAT8E4M3FN = 17
 56     FLOAT8E4M3FNUZ = 18
 57     FLOAT8E5M2 = 19
 58     FLOAT8E5M2FNUZ = 20
 59     UINT4 = 21
 60     INT4 = 22
 61     FLOAT4E2M1 = 23
 62      */
 63 
 64 public class Tensor<T> extends OnnxNumber {
 65 
 66     public static final long[] SCALAR_SHAPE = new long[0];
 67 
 68     public static Tensor<Boolean> ofScalar(boolean b) {
 69         return ofScalar(Arena.ofAuto(), b);
 70     }
 71 
 72     public static Tensor<Boolean> ofScalar(Arena arena, boolean b) {
 73         return new Tensor(arena, arena.allocateFrom(ValueLayout.JAVA_BYTE, b ? (byte)1 : 0), ElementType.BOOL, SCALAR_SHAPE);
 74     }
 75 
 76     public static Tensor<Byte> ofScalar(byte b) {
 77         return ofShape(SCALAR_SHAPE, b);
 78     }
 79 
 80     public static Tensor<Byte> ofScalar(Arena arena, byte b) {
 81         return ofShape(arena, SCALAR_SHAPE, b);
 82     }
 83 
 84     public static Tensor<Long> ofScalar(long l) {
 85         return ofShape(SCALAR_SHAPE, l);
 86     }
 87 
 88     public static Tensor<Long> ofScalar(Arena arena, long l) {
 89         return ofShape(arena, SCALAR_SHAPE, l);
 90     }
 91 
 92     public static Tensor<Float> ofScalar(float f) {
 93         return ofShape(SCALAR_SHAPE, f);
 94     }
 95 
 96     public static Tensor<Float> ofScalar(Arena arena, float f) {
 97         return ofShape(arena, SCALAR_SHAPE, f);
 98     }
 99 
100 
101     public static Tensor<Byte> ofFlat(byte... values) {
102         return ofShape(new long[]{values.length}, values);
103     }
104 
105     public static Tensor<Byte> ofFlat(Arena arena, byte... values) {
106         return ofShape(arena, new long[]{values.length}, values);
107     }
108 
109     public static Tensor<Long> ofFlat(long... values) {
110         return ofShape(new long[]{values.length}, values);
111     }
112 
113     public static Tensor<Long> ofFlat(Arena arena, long... values) {
114         return ofShape(arena, new long[]{values.length}, values);
115     }
116 
117     public static Tensor<Float> ofFlat(float... values) {
118         return ofShape(new long[]{values.length}, values);
119     }
120 
121     public static Tensor<Float> ofFlat(Arena arena, float... values) {
122         return ofShape(arena, new long[]{values.length}, values);
123     }
124 
125     public static Tensor<Byte> ofShape(long[] shape, byte... values) {
126         return ofShape(Arena.ofAuto(), shape, values);
127     }
128 
129     public static Tensor<Byte> ofShape(Arena arena, long[] shape, byte... values) {
130         return new Tensor(arena, arena.allocateFrom(ValueLayout.JAVA_BYTE, values), ElementType.UINT8, shape);
131     }
132 
133     public static Tensor<Long> ofShape(long[] shape, long... values) {
134         return ofShape(Arena.ofAuto(), shape, values);
135     }
136 
137     public static Tensor<Long> ofShape(Arena arena, long[] shape, long... values) {
138         return new Tensor(arena, arena.allocateFrom(ValueLayout.JAVA_LONG, values), ElementType.INT64, shape);
139     }
140 
141     public static Tensor<Float> ofShape(long[] shape, float... values) {
142         return ofShape(Arena.ofAuto(), shape, values);
143     }
144 
145     public static Tensor<Float> ofShape(Arena arena, long[] shape, float... values) {
146         return new Tensor(arena, arena.allocateFrom(ValueLayout.JAVA_FLOAT, values), ElementType.FLOAT, shape);
147     }
148 
149     public static <T> Tensor<T> ofShape(long[] shape, byte[] rawData, ElementType elementType) {
150         return ofShape(Arena.ofAuto(), shape, rawData, elementType);
151     }
152 
153     public static <T> Tensor<T> ofShape(Arena arena, long[] shape, byte[] rawData, ElementType elementType) {
154         return new Tensor(arena, arena.allocateFrom(ValueLayout.JAVA_BYTE, rawData), elementType, shape);
155     }
156 
157     // Mandatory reference to dataAddr to avoid its garbage colletion
158     private final MemorySegment dataAddr;
159     final MemorySegment tensorAddr;
160 
161     public Tensor(Arena arena, MemorySegment dataAddr, ElementType type, long[] shape) {
162         this(dataAddr, OnnxRuntime.getInstance().createTensor(arena, dataAddr, type, shape));
163     }
164 
165     Tensor(MemorySegment dataAddr, MemorySegment tensorAddr) {
166         this.dataAddr = dataAddr;
167         this.tensorAddr = tensorAddr;
168     }
169 
170     public ElementType elementType() {
171         return OnnxRuntime.getInstance().tensorElementType(tensorAddr);
172     }
173 
174     public long[] shape() {
175         return OnnxRuntime.getInstance().tensorShape(tensorAddr);
176     }
177 
178     public MemorySegment data() {
179         return dataAddr;
180     }
181 
182     public enum ElementType {
183         FLOAT(1, float.class),
184         UINT8(2, byte.class),
185         INT8(3, byte.class),
186         UINT16(4, short.class),
187         INT16(5, short.class),
188         INT32(6, int.class),
189         INT64(7, long.class),
190         STRING(8, String.class),
191         BOOL(9, boolean.class),
192         FLOAT16(10, Object.class),
193         DOUBLE(11, double.class),
194         UINT32(12, int.class),
195         UINT64(13, long.class),
196         COMPLEX64(14, Object.class),
197         COMPLEX128(15, Object.class),
198         BFLOAT16(16, Object.class),
199         FLOAT8E4M3FN(17, Object.class),
200         FLOAT8E4M3FNUZ(18, Object.class),
201         FLOAT8E5M2(19, Object.class),
202         FLOAT8E5M2FNUZ(20, Object.class),
203         UINT4(21, Object.class),
204         INT4(22, Object.class),
205         FLOAT4E2M1(23, Object.class);
206 
207         final int id;
208         final Class<?> type;
209 
210         ElementType(int id, Class<?> type) {
211             this.id = id;
212             this.type = type;
213         }
214 
215         public Class<?> type() {
216             return type;
217         }
218 
219         public String onnxName() {
220             return name().toLowerCase();
221         }
222 
223         public int bitSize() {
224             return switch (this) {
225                 case INT4, UINT4, FLOAT4E2M1 -> 4;
226                 case UINT8, INT8, BOOL, FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ -> 8;
227                 case UINT16, INT16, FLOAT16, BFLOAT16 -> 16;
228                 case UINT32, INT32, FLOAT -> 32;
229                 case UINT64, INT64, DOUBLE, COMPLEX64 -> 64;
230                 case COMPLEX128 -> 128;
231                 case STRING -> -1;
232             };
233         }
234 
235         public static ElementType fromOnnxName(String name) {
236             return ElementType.valueOf(name.toUpperCase());
237         }
238 
239         public static ElementType fromOnnxId(int id) {
240             return values()[id - 1];
241         }
242     }
243 }