1 /*
2 * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25
26 package oracle.code.onnx.compiler;
27
28 import java.lang.invoke.MethodHandles;
29 import java.lang.reflect.GenericArrayType;
30 import java.lang.reflect.ParameterizedType;
31 import java.lang.reflect.RecordComponent;
32 import java.lang.reflect.Type;
33 import java.lang.reflect.TypeVariable;
34 import java.util.ArrayList;
35 import java.util.Arrays;
36 import java.util.HashMap;
37 import java.util.List;
38 import java.util.Map;
39 import java.util.Optional;
40 import java.util.Set;
41 import jdk.incubator.code.Op;
42 import jdk.incubator.code.TypeElement;
43 import jdk.incubator.code.Value;
44 import jdk.incubator.code.dialect.core.CoreOp;
45 import jdk.incubator.code.dialect.core.CoreType;
46 import jdk.incubator.code.dialect.core.FunctionType;
47 import jdk.incubator.code.dialect.core.TupleType;
48 import jdk.incubator.code.dialect.java.ArrayType;
49 import jdk.incubator.code.dialect.java.ClassType;
50 import jdk.incubator.code.dialect.java.JavaOp;
51 import jdk.incubator.code.dialect.java.JavaType;
52 import jdk.incubator.code.dialect.java.MethodRef;
53 import oracle.code.onnx.Tensor;
54 import oracle.code.onnx.ir.OnnxType;
55
56 public class TypeConvertor {
57
/** Code-model type of {@link Tensor}; compared against the raw type of class types. */
static final JavaType TENSOR_CLASS = JavaType.type(Tensor.class);

/** Lookup handed to every {@code resolve} call made by this convertor. */
final MethodHandles.Lookup l;

/**
 * Integer recorded per array-typed record component, keyed by
 * {@code RecordComponent.toString()}. Written by {@code detectConstantArrays} and
 * {@code scanUse}, read by {@code recordTypeToTupleType} as the tuple length.
 * RecordComponent cannot be used as a key!
 */
final Map<String, Integer> constantArraySizeMap = new HashMap<>(); // @@@ initialize

/**
 * Creates a convertor bound to the given lookup.
 *
 * @param l lookup used to resolve code-model types to reflective types
 */
TypeConvertor(MethodHandles.Lookup l) {
    this.l = l;
}
67
68 void detectConstantArrays(CoreOp.FuncOp f) {
69 f.traverse(null, (_, ce) -> {
70 if (ce instanceof JavaOp.NewOp no && no.resultType() instanceof ClassType recordType && isRecord(recordType)) {
71 Class<?> recordClass;
72 try {
73 recordClass = (Class<?>) recordType.rawType().resolve(l);
74 } catch (ReflectiveOperationException e) {
75 throw new RuntimeException(e);
76 }
77 var rcs = recordClass.getRecordComponents();
78 var ops = no.operands();
79 for (int i = 0; i < rcs.length; i++) {
80 RecordComponent rc = rcs[i];
81 Type type = rc.getGenericType();
82 if (type instanceof ParameterizedType pt && pt.getRawType().equals(Optional.class)) {
83 type = pt.getActualTypeArguments()[0];
84 }
85 if (type instanceof GenericArrayType) {
86 Value arr = OnnxTransformer.skipVars(ops.get(i));
87 if (arr instanceof Op.Result newArrayResult
88 && newArrayResult.op() instanceof JavaOp.NewOp newArrayOp
89 && newArrayOp.operands().getFirst() instanceof Op.Result constantResult
90 && constantResult.op() instanceof CoreOp.ConstantOp cop) {
91
92 // explicit constant array construction
93 constantArraySizeMap.put(rc.toString(), (Integer)cop.value());
94 } else {
95 // search for the highest array access index
96 scanUse(arr, rc.toString());
97 }
98 }
99 }
100 }
101 return null;
102 });
103 }
104
105 void scanUse(Value array, String rcKey) {
106 for (var use : array.uses()) {
107 if (use instanceof Op.Result or) {
108 switch (or.op()) {
109 case CoreOp.VarOp vo ->
110 scanUse(vo.result(), rcKey);
111 case CoreOp.VarAccessOp.VarLoadOp vlo ->
112 scanUse(vlo.result(), rcKey);
113 case JavaOp.ArrayAccessOp aao when aao.operands().get(1) instanceof Op.Result constR
114 && constR.op() instanceof CoreOp.ConstantOp cop ->
115 constantArraySizeMap.compute(rcKey, (_, i) -> Math.max((Integer)cop.value(), i == null ? 0 : i));
116 default -> {}
117 }
118 }
119 }
120 }
121
122 TupleType recordTypeToTupleType(ClassType recordType) {
123 Class<?> recordClass;
124 try {
125 recordClass = (Class<?>) recordType.rawType().resolve(l);
126 } catch (ReflectiveOperationException e) {
127 throw new RuntimeException(e);
128 }
129 assert recordClass.isRecord();
130
131 List<TypeElement> tupleComponentTypes = new ArrayList<>();
132 for (RecordComponent rc : recordClass.getRecordComponents()) {
133 Type type = rc.getGenericType();
134 if (type instanceof ParameterizedType pt && pt.getRawType().equals(Optional.class)) {
135 type = pt.getActualTypeArguments()[0];
136 }
137 switch (type) {
138 case ParameterizedType pt -> {
139 Type elementType = pt.getActualTypeArguments()[0];
140 switch (elementType) {
141 case Class<?> _ -> {
142 tupleComponentTypes.add(convertType(JavaType.type(pt)));
143 }
144 case TypeVariable<?> tv -> {
145 // Resolve type variable
146 JavaType e = null;
147 for (int j = 0; j < recordClass.getTypeParameters().length; j++) {
148 if (recordClass.getTypeParameters()[j].getName().equals(tv.getName())) {
149 e = recordType.typeArguments().get(j);
150 break;
151 }
152 }
153 tupleComponentTypes.add(convertType(JavaType.parameterized(JavaType.type(Tensor.class), e)));
154 }
155 default -> throw new IllegalStateException("Unexpected value: " + elementType);
156 }
157 }
158 case TypeVariable tv -> {
159 // Resolve type variable
160 JavaType e = null;
161 for (int j = 0; j < recordClass.getTypeParameters().length; j++) {
162 if (recordClass.getTypeParameters()[j].getName().equals(tv.getName())) {
163 e = recordType.typeArguments().get(j);
164 break;
165 }
166 }
167 tupleComponentTypes.add(convertType(e));
168 }
169 case GenericArrayType gat -> {
170 var cType = convertType(JavaType.type(gat.getGenericComponentType()));
171 Integer size = constantArraySizeMap.get(rc.toString());
172 var tContent = new TypeElement[size];
173 Arrays.fill(tContent, cType);
174 tupleComponentTypes.add(CoreType.tupleType(tContent));
175 }
176 default -> throw new IllegalStateException("Unexpected value: " + rc.getGenericType());
177 }
178 }
179
180 return CoreType.tupleType(tupleComponentTypes);
181 }
182
183 boolean isRecord(TypeElement type) {
184 try {
185 return type instanceof ClassType ct &&
186 ct.erasure().resolve(l) instanceof Class c &&
187 c.isRecord();
188 } catch (ReflectiveOperationException e) {
189 throw new RuntimeException(e);
190 }
191 }
192
193
194 Integer recordComponentAccessToTupleIndex(MethodRef ref) {
195 if (ref.refType() instanceof ClassType ct) {
196 Class<?> refClass;
197 try {
198 refClass = (Class<?>) ct.resolve(l);
199 } catch (ReflectiveOperationException e) {
200 throw new RuntimeException(e);
201 }
202
203 if (refClass.isRecord()) {
204 RecordComponent[] recordComponents = refClass.getRecordComponents();
205 for (int i = 0; i < recordComponents.length; i++) {
206 if (recordComponents[i].getName().equals(ref.name())) {
207 return i;
208 }
209 }
210 throw new InternalError();
211 }
212 }
213 return null;
214 }
215
216 FunctionType convertType(FunctionType t) {
217 return CoreType.functionType(convertType(t.returnType()), t.parameterTypes().stream().map(this::convertType).toList());
218 }
219
220 FunctionType convertType(CoreOp.FuncOp fo) {
221 return CoreType.functionType(convertType(fo.body().entryBlock().terminatingOp().operands().getFirst()), fo.parameters().stream().map(this::convertType).toList());
222 }
223
224 TypeElement convertType(Value value) {
225 // convert 1-dimensional constantly accessed constant arrays into tuples
226 if (value.type() instanceof ArrayType at && at.dimensions() == 1) {
227 int size = countConstantArraySize(value.uses());
228 if (size >= 0) {
229 var targs = new TypeElement[size];
230 Arrays.fill(targs, convertType(at.componentType()));
231 return CoreType.tupleType(targs);
232 }
233 }
234 return convertType(value.type());
235 }
236
237 static int countConstantArraySize(Set<Op.Result> uses) {
238 int size = 0;
239 for (var use : uses) {
240 int s = switch (use.op()) {
241 case JavaOp.ArrayAccessOp aao when aao.operands().get(1) instanceof Op.Result or && or.op() instanceof CoreOp.ConstantOp co ->
242 (Integer)co.value() + 1;
243 case CoreOp.VarOp _, CoreOp.VarAccessOp.VarLoadOp _ ->
244 countConstantArraySize(use.op().result().uses());
245 default -> -1;
246 };
247 if (s < 0) return -1;
248 size = Integer.max(size, s);
249 }
250 return size;
251 }
252
253 // @@@ Map of Java tensor types to ONNX tensor types
254 // @@@ Shape??
255 TypeElement convertType(TypeElement type) {
256 if (type instanceof ClassType ct) {
257 if (ct.rawType().equals(TENSOR_CLASS)) {
258 JavaType elementType = ct.typeArguments().getFirst();
259 if (elementType.equals(JavaType.J_L_INTEGER)) {
260 return OnnxType.TENSOR_INT32;
261 } else if (elementType.equals(JavaType.J_L_FLOAT)) {
262 return OnnxType.TENSOR_FLOAT32;
263 } else if (elementType.equals(JavaType.J_L_LONG)) {
264 return OnnxType.TENSOR_INT64;
265 } else if (elementType.equals(JavaType.J_L_BYTE)) {
266 return OnnxType.TENSOR_UINT8;
267 } else if (elementType.equals(JavaType.J_L_BOOLEAN)) {
268 return OnnxType.TENSOR_BOOL;
269 }
270 } else if (isRecord(type)) {
271 return recordTypeToTupleType(ct);
272 }
273 }
274 return type;
275 }
276 }