1 /*
  2  * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 
 26 package oracle.code.onnx.compiler;
 27 
 28 import java.lang.invoke.MethodHandles;
 29 import java.lang.reflect.GenericArrayType;
 30 import java.lang.reflect.ParameterizedType;
 31 import java.lang.reflect.RecordComponent;
 32 import java.lang.reflect.Type;
 33 import java.lang.reflect.TypeVariable;
 34 import java.util.ArrayList;
 35 import java.util.Arrays;
 36 import java.util.HashMap;
 37 import java.util.List;
 38 import java.util.Map;
 39 import java.util.Optional;
 40 import java.util.Set;
 41 import jdk.incubator.code.Op;
 42 import jdk.incubator.code.CodeType;
 43 import jdk.incubator.code.Value;
 44 import jdk.incubator.code.dialect.core.CoreOp;
 45 import jdk.incubator.code.dialect.core.CoreType;
 46 import jdk.incubator.code.dialect.core.FunctionType;
 47 import jdk.incubator.code.dialect.core.TupleType;
 48 import jdk.incubator.code.dialect.java.ArrayType;
 49 import jdk.incubator.code.dialect.java.ClassType;
 50 import jdk.incubator.code.dialect.java.JavaOp;
 51 import jdk.incubator.code.dialect.java.JavaType;
 52 import jdk.incubator.code.dialect.java.MethodRef;
 53 import oracle.code.onnx.Tensor;
 54 import oracle.code.onnx.ir.OnnxType;
 55 
 56 public class TypeConvertor {
 57 
 58     static final JavaType TENSOR_CLASS = JavaType.type(Tensor.class);
 59 
 60     final MethodHandles.Lookup l;
 61     final Map<String, Integer> constantArraySizeMap; // RecordComponent cannot be used as a key!
 62 
 63     TypeConvertor(MethodHandles.Lookup l) {
 64         this.l = l;
 65         this.constantArraySizeMap = new HashMap<>(); // @@@ initialize
 66     }
 67 
 68     void detectConstantArrays(CoreOp.FuncOp f) {
 69         f.elements().forEach(ce -> {
 70             if (ce instanceof JavaOp.NewOp no && no.resultType() instanceof ClassType recordType && isRecord(recordType)) {
 71                 Class<?> recordClass;
 72                 try {
 73                     recordClass = (Class<?>) recordType.rawType().resolve(l);
 74                 } catch (ReflectiveOperationException e) {
 75                     throw new RuntimeException(e);
 76                 }
 77                 var rcs = recordClass.getRecordComponents();
 78                 var ops = no.operands();
 79                 for (int i = 0; i < rcs.length; i++) {
 80                     RecordComponent rc  = rcs[i];
 81                     Type type = rc.getGenericType();
 82                     if (type instanceof ParameterizedType pt && pt.getRawType().equals(Optional.class)) {
 83                         type = pt.getActualTypeArguments()[0];
 84                     }
 85                     if (type instanceof GenericArrayType) {
 86                         Value arr = OnnxTransformer.skipVars(ops.get(i));
 87                         if (arr instanceof Op.Result newArrayResult
 88                                 && newArrayResult.op() instanceof JavaOp.NewOp newArrayOp
 89                                 && newArrayOp.operands().getFirst() instanceof Op.Result constantResult
 90                                 && constantResult.op() instanceof CoreOp.ConstantOp cop) {
 91 
 92                             // explicit constant array construction
 93                             constantArraySizeMap.put(rc.toString(), (Integer)cop.value());
 94                         } else {
 95                             // search for the highest array access index
 96                             scanUse(arr, rc.toString());
 97                         }
 98                     }
 99                 }
100             }
101         });
102     }
103 
104     void scanUse(Value array, String rcKey) {
105         for (var use : array.uses()) {
106             if (use instanceof Op.Result or) {
107                 switch (or.op()) {
108                     case CoreOp.VarOp vo ->
109                         scanUse(vo.result(), rcKey);
110                     case CoreOp.VarAccessOp.VarLoadOp vlo ->
111                         scanUse(vlo.result(), rcKey);
112                     case JavaOp.ArrayAccessOp aao when aao.operands().get(1) instanceof Op.Result constR
113                                                     && constR.op() instanceof CoreOp.ConstantOp cop ->
114                         constantArraySizeMap.compute(rcKey, (_, i) -> Math.max((Integer)cop.value(), i == null ? 0 : i));
115                     default -> {}
116                 }
117             }
118         }
119     }
120 
121     TupleType recordTypeToTupleType(ClassType recordType) {
122         Class<?> recordClass;
123         try {
124             recordClass = (Class<?>) recordType.rawType().resolve(l);
125         } catch (ReflectiveOperationException e) {
126             throw new RuntimeException(e);
127         }
128         assert recordClass.isRecord();
129 
130         List<CodeType> tupleComponentTypes = new ArrayList<>();
131         for (RecordComponent rc : recordClass.getRecordComponents()) {
132             Type type = rc.getGenericType();
133             if (type instanceof ParameterizedType pt && pt.getRawType().equals(Optional.class)) {
134                 type = pt.getActualTypeArguments()[0];
135             }
136             switch (type) {
137                 case ParameterizedType pt -> {
138                     Type elementType = pt.getActualTypeArguments()[0];
139                     switch (elementType) {
140                         case Class<?> _ -> {
141                             tupleComponentTypes.add(convertType(JavaType.type(pt)));
142                         }
143                         case TypeVariable<?> tv -> {
144                             // Resolve type variable
145                             JavaType e = null;
146                             for (int j = 0; j < recordClass.getTypeParameters().length; j++) {
147                                 if (recordClass.getTypeParameters()[j].getName().equals(tv.getName())) {
148                                     e = recordType.typeArguments().get(j);
149                                     break;
150                                 }
151                             }
152                             tupleComponentTypes.add(convertType(JavaType.parameterized(JavaType.type(Tensor.class), e)));
153                         }
154                         default -> throw new IllegalStateException("Unexpected value: " + elementType);
155                     }
156                 }
157                 case TypeVariable tv -> {
158                     // Resolve type variable
159                     JavaType e = null;
160                     for (int j = 0; j < recordClass.getTypeParameters().length; j++) {
161                         if (recordClass.getTypeParameters()[j].getName().equals(tv.getName())) {
162                             e = recordType.typeArguments().get(j);
163                             break;
164                         }
165                     }
166                     tupleComponentTypes.add(convertType(e));
167                 }
168                 case GenericArrayType gat -> {
169                     var cType = convertType(JavaType.type(gat.getGenericComponentType()));
170                     Integer size = constantArraySizeMap.get(rc.toString());
171                     var tContent = new CodeType[size];
172                     Arrays.fill(tContent, cType);
173                     tupleComponentTypes.add(CoreType.tupleType(tContent));
174                 }
175                 default -> throw new IllegalStateException("Unexpected value: " + rc.getGenericType());
176             }
177         }
178 
179         return CoreType.tupleType(tupleComponentTypes);
180     }
181 
182     boolean isRecord(CodeType type) {
183         try {
184             return type instanceof ClassType ct &&
185                     ct.erasure().resolve(l) instanceof Class c &&
186                     c.isRecord();
187         } catch (ReflectiveOperationException e) {
188             throw new RuntimeException(e);
189         }
190     }
191 
192 
193     Integer recordComponentAccessToTupleIndex(MethodRef ref) {
194         if (ref.refType() instanceof ClassType ct) {
195             Class<?> refClass;
196             try {
197                 refClass = (Class<?>) ct.resolve(l);
198             } catch (ReflectiveOperationException e) {
199                 throw new RuntimeException(e);
200             }
201 
202             if (refClass.isRecord()) {
203                 RecordComponent[] recordComponents = refClass.getRecordComponents();
204                 for (int i = 0; i < recordComponents.length; i++) {
205                     if (recordComponents[i].getName().equals(ref.name())) {
206                         return i;
207                     }
208                 }
209                 throw new InternalError();
210             }
211         }
212         return null;
213     }
214 
215     FunctionType convertType(FunctionType t) {
216         return CoreType.functionType(convertType(t.returnType()), t.parameterTypes().stream().map(this::convertType).toList());
217     }
218 
219     FunctionType convertType(CoreOp.FuncOp fo) {
220         return CoreType.functionType(convertType(fo.body().entryBlock().terminatingOp().operands().getFirst()), fo.parameters().stream().map(this::convertType).toList());
221     }
222 
223     CodeType convertType(Value value) {
224         // convert 1-dimensional constantly accessed constant arrays into tuples
225         if (value.type() instanceof ArrayType at && at.dimensions() == 1) {
226             int size = countConstantArraySize(value.uses());
227             if (size >= 0) {
228                 var targs = new CodeType[size];
229                 Arrays.fill(targs, convertType(at.componentType()));
230                 return CoreType.tupleType(targs);
231             }
232         }
233         return convertType(value.type());
234     }
235 
236     static int countConstantArraySize(Set<Op.Result> uses) {
237         int size = 0;
238         for (var use : uses) {
239             int s = switch (use.op()) {
240                 case JavaOp.ArrayAccessOp aao when aao.operands().get(1) instanceof Op.Result or && or.op() instanceof CoreOp.ConstantOp co ->
241                     (Integer)co.value() + 1;
242                 case CoreOp.VarOp _, CoreOp.VarAccessOp.VarLoadOp _ ->
243                     countConstantArraySize(use.op().result().uses());
244                 default -> -1;
245             };
246             if (s < 0) return -1;
247             size = Integer.max(size, s);
248         }
249         return size;
250     }
251 
252     // @@@ Map of Java tensor types to ONNX tensor types
253     // @@@ Shape??
254     CodeType convertType(CodeType type) {
255         if (type instanceof ClassType ct) {
256             if (ct.rawType().equals(TENSOR_CLASS)) {
257                 JavaType elementType = ct.typeArguments().getFirst();
258                 if (elementType.equals(JavaType.J_L_INTEGER)) {
259                     return OnnxType.TENSOR_INT32;
260                 } else if (elementType.equals(JavaType.J_L_FLOAT)) {
261                     return OnnxType.TENSOR_FLOAT32;
262                 } else if (elementType.equals(JavaType.J_L_LONG)) {
263                     return OnnxType.TENSOR_INT64;
264                 } else if (elementType.equals(JavaType.J_L_BYTE)) {
265                     return OnnxType.TENSOR_UINT8;
266                 } else if (elementType.equals(JavaType.J_L_BOOLEAN)) {
267                     return OnnxType.TENSOR_BOOL;
268                 } else if (elementType.equals(JavaType.J_L_STRING)) {
269                     return OnnxType.TENSOR_STRING;
270                 }
271             } else if (isRecord(type)) {
272                 return recordTypeToTupleType(ct);
273             }
274         }
275         return type;
276     }
277 }