1 /*
2 * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25
26 package oracle.code.onnx.compiler;
27
28 import java.lang.invoke.MethodHandles;
29 import java.lang.reflect.GenericArrayType;
30 import java.lang.reflect.ParameterizedType;
31 import java.lang.reflect.RecordComponent;
32 import java.lang.reflect.Type;
33 import java.lang.reflect.TypeVariable;
34 import java.util.ArrayList;
35 import java.util.Arrays;
36 import java.util.HashMap;
37 import java.util.List;
38 import java.util.Map;
39 import java.util.Optional;
40 import java.util.Set;
41 import jdk.incubator.code.Op;
42 import jdk.incubator.code.CodeType;
43 import jdk.incubator.code.Value;
44 import jdk.incubator.code.dialect.core.CoreOp;
45 import jdk.incubator.code.dialect.core.CoreType;
46 import jdk.incubator.code.dialect.core.FunctionType;
47 import jdk.incubator.code.dialect.core.TupleType;
48 import jdk.incubator.code.dialect.java.ArrayType;
49 import jdk.incubator.code.dialect.java.ClassType;
50 import jdk.incubator.code.dialect.java.JavaOp;
51 import jdk.incubator.code.dialect.java.JavaType;
52 import jdk.incubator.code.dialect.java.MethodRef;
53 import oracle.code.onnx.Tensor;
54 import oracle.code.onnx.ir.OnnxType;
55
56 public class TypeConvertor {
57
58 static final JavaType TENSOR_CLASS = JavaType.type(Tensor.class);
59
60 final MethodHandles.Lookup l;
61 final Map<String, Integer> constantArraySizeMap; // RecordComponent cannot be used as a key!
62
63 TypeConvertor(MethodHandles.Lookup l) {
64 this.l = l;
65 this.constantArraySizeMap = new HashMap<>(); // @@@ initialize
66 }
67
68 void detectConstantArrays(CoreOp.FuncOp f) {
69 f.elements().forEach(ce -> {
70 if (ce instanceof JavaOp.NewOp no && no.resultType() instanceof ClassType recordType && isRecord(recordType)) {
71 Class<?> recordClass;
72 try {
73 recordClass = (Class<?>) recordType.rawType().resolve(l);
74 } catch (ReflectiveOperationException e) {
75 throw new RuntimeException(e);
76 }
77 var rcs = recordClass.getRecordComponents();
78 var ops = no.operands();
79 for (int i = 0; i < rcs.length; i++) {
80 RecordComponent rc = rcs[i];
81 Type type = rc.getGenericType();
82 if (type instanceof ParameterizedType pt && pt.getRawType().equals(Optional.class)) {
83 type = pt.getActualTypeArguments()[0];
84 }
85 if (type instanceof GenericArrayType) {
86 Value arr = OnnxTransformer.skipVars(ops.get(i));
87 if (arr instanceof Op.Result newArrayResult
88 && newArrayResult.op() instanceof JavaOp.NewOp newArrayOp
89 && newArrayOp.operands().getFirst() instanceof Op.Result constantResult
90 && constantResult.op() instanceof CoreOp.ConstantOp cop) {
91
92 // explicit constant array construction
93 constantArraySizeMap.put(rc.toString(), (Integer)cop.value());
94 } else {
95 // search for the highest array access index
96 scanUse(arr, rc.toString());
97 }
98 }
99 }
100 }
101 });
102 }
103
104 void scanUse(Value array, String rcKey) {
105 for (var use : array.uses()) {
106 if (use instanceof Op.Result or) {
107 switch (or.op()) {
108 case CoreOp.VarOp vo ->
109 scanUse(vo.result(), rcKey);
110 case CoreOp.VarAccessOp.VarLoadOp vlo ->
111 scanUse(vlo.result(), rcKey);
112 case JavaOp.ArrayAccessOp aao when aao.operands().get(1) instanceof Op.Result constR
113 && constR.op() instanceof CoreOp.ConstantOp cop ->
114 constantArraySizeMap.compute(rcKey, (_, i) -> Math.max((Integer)cop.value(), i == null ? 0 : i));
115 default -> {}
116 }
117 }
118 }
119 }
120
121 TupleType recordTypeToTupleType(ClassType recordType) {
122 Class<?> recordClass;
123 try {
124 recordClass = (Class<?>) recordType.rawType().resolve(l);
125 } catch (ReflectiveOperationException e) {
126 throw new RuntimeException(e);
127 }
128 assert recordClass.isRecord();
129
130 List<CodeType> tupleComponentTypes = new ArrayList<>();
131 for (RecordComponent rc : recordClass.getRecordComponents()) {
132 Type type = rc.getGenericType();
133 if (type instanceof ParameterizedType pt && pt.getRawType().equals(Optional.class)) {
134 type = pt.getActualTypeArguments()[0];
135 }
136 switch (type) {
137 case ParameterizedType pt -> {
138 Type elementType = pt.getActualTypeArguments()[0];
139 switch (elementType) {
140 case Class<?> _ -> {
141 tupleComponentTypes.add(convertType(JavaType.type(pt)));
142 }
143 case TypeVariable<?> tv -> {
144 // Resolve type variable
145 JavaType e = null;
146 for (int j = 0; j < recordClass.getTypeParameters().length; j++) {
147 if (recordClass.getTypeParameters()[j].getName().equals(tv.getName())) {
148 e = recordType.typeArguments().get(j);
149 break;
150 }
151 }
152 tupleComponentTypes.add(convertType(JavaType.parameterized(JavaType.type(Tensor.class), e)));
153 }
154 default -> throw new IllegalStateException("Unexpected value: " + elementType);
155 }
156 }
157 case TypeVariable tv -> {
158 // Resolve type variable
159 JavaType e = null;
160 for (int j = 0; j < recordClass.getTypeParameters().length; j++) {
161 if (recordClass.getTypeParameters()[j].getName().equals(tv.getName())) {
162 e = recordType.typeArguments().get(j);
163 break;
164 }
165 }
166 tupleComponentTypes.add(convertType(e));
167 }
168 case GenericArrayType gat -> {
169 var cType = convertType(JavaType.type(gat.getGenericComponentType()));
170 Integer size = constantArraySizeMap.get(rc.toString());
171 var tContent = new CodeType[size];
172 Arrays.fill(tContent, cType);
173 tupleComponentTypes.add(CoreType.tupleType(tContent));
174 }
175 default -> throw new IllegalStateException("Unexpected value: " + rc.getGenericType());
176 }
177 }
178
179 return CoreType.tupleType(tupleComponentTypes);
180 }
181
182 boolean isRecord(CodeType type) {
183 try {
184 return type instanceof ClassType ct &&
185 ct.erasure().resolve(l) instanceof Class c &&
186 c.isRecord();
187 } catch (ReflectiveOperationException e) {
188 throw new RuntimeException(e);
189 }
190 }
191
192
193 Integer recordComponentAccessToTupleIndex(MethodRef ref) {
194 if (ref.refType() instanceof ClassType ct) {
195 Class<?> refClass;
196 try {
197 refClass = (Class<?>) ct.resolve(l);
198 } catch (ReflectiveOperationException e) {
199 throw new RuntimeException(e);
200 }
201
202 if (refClass.isRecord()) {
203 RecordComponent[] recordComponents = refClass.getRecordComponents();
204 for (int i = 0; i < recordComponents.length; i++) {
205 if (recordComponents[i].getName().equals(ref.name())) {
206 return i;
207 }
208 }
209 throw new InternalError();
210 }
211 }
212 return null;
213 }
214
215 FunctionType convertType(FunctionType t) {
216 return CoreType.functionType(convertType(t.returnType()), t.parameterTypes().stream().map(this::convertType).toList());
217 }
218
219 FunctionType convertType(CoreOp.FuncOp fo) {
220 return CoreType.functionType(convertType(fo.body().entryBlock().terminatingOp().operands().getFirst()), fo.parameters().stream().map(this::convertType).toList());
221 }
222
223 CodeType convertType(Value value) {
224 // convert 1-dimensional constantly accessed constant arrays into tuples
225 if (value.type() instanceof ArrayType at && at.dimensions() == 1) {
226 int size = countConstantArraySize(value.uses());
227 if (size >= 0) {
228 var targs = new CodeType[size];
229 Arrays.fill(targs, convertType(at.componentType()));
230 return CoreType.tupleType(targs);
231 }
232 }
233 return convertType(value.type());
234 }
235
236 static int countConstantArraySize(Set<Op.Result> uses) {
237 int size = 0;
238 for (var use : uses) {
239 int s = switch (use.op()) {
240 case JavaOp.ArrayAccessOp aao when aao.operands().get(1) instanceof Op.Result or && or.op() instanceof CoreOp.ConstantOp co ->
241 (Integer)co.value() + 1;
242 case CoreOp.VarOp _, CoreOp.VarAccessOp.VarLoadOp _ ->
243 countConstantArraySize(use.op().result().uses());
244 default -> -1;
245 };
246 if (s < 0) return -1;
247 size = Integer.max(size, s);
248 }
249 return size;
250 }
251
252 // @@@ Map of Java tensor types to ONNX tensor types
253 // @@@ Shape??
254 CodeType convertType(CodeType type) {
255 if (type instanceof ClassType ct) {
256 if (ct.rawType().equals(TENSOR_CLASS)) {
257 JavaType elementType = ct.typeArguments().getFirst();
258 if (elementType.equals(JavaType.J_L_INTEGER)) {
259 return OnnxType.TENSOR_INT32;
260 } else if (elementType.equals(JavaType.J_L_FLOAT)) {
261 return OnnxType.TENSOR_FLOAT32;
262 } else if (elementType.equals(JavaType.J_L_LONG)) {
263 return OnnxType.TENSOR_INT64;
264 } else if (elementType.equals(JavaType.J_L_BYTE)) {
265 return OnnxType.TENSOR_UINT8;
266 } else if (elementType.equals(JavaType.J_L_BOOLEAN)) {
267 return OnnxType.TENSOR_BOOL;
268 } else if (elementType.equals(JavaType.J_L_STRING)) {
269 return OnnxType.TENSOR_STRING;
270 }
271 } else if (isRecord(type)) {
272 return recordTypeToTupleType(ct);
273 }
274 }
275 return type;
276 }
277 }