1 /*
  2  * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 
 26 package oracle.code.onnx.compiler;
 27 
 28 import java.lang.invoke.MethodHandles;
 29 import java.lang.reflect.GenericArrayType;
 30 import java.lang.reflect.ParameterizedType;
 31 import java.lang.reflect.RecordComponent;
 32 import java.lang.reflect.Type;
 33 import java.lang.reflect.TypeVariable;
 34 import java.util.ArrayList;
 35 import java.util.Arrays;
 36 import java.util.HashMap;
 37 import java.util.List;
 38 import java.util.Map;
 39 import java.util.Optional;
 40 import java.util.Set;
 41 import jdk.incubator.code.Op;
 42 import jdk.incubator.code.TypeElement;
 43 import jdk.incubator.code.Value;
 44 import jdk.incubator.code.dialect.core.CoreOp;
 45 import jdk.incubator.code.dialect.core.CoreType;
 46 import jdk.incubator.code.dialect.core.FunctionType;
 47 import jdk.incubator.code.dialect.core.TupleType;
 48 import jdk.incubator.code.dialect.java.ArrayType;
 49 import jdk.incubator.code.dialect.java.ClassType;
 50 import jdk.incubator.code.dialect.java.JavaOp;
 51 import jdk.incubator.code.dialect.java.JavaType;
 52 import jdk.incubator.code.dialect.java.MethodRef;
 53 import oracle.code.onnx.Tensor;
 54 import oracle.code.onnx.ir.OnnxType;
 55 
 56 public class TypeConvertor {
 57 
 58     static final JavaType TENSOR_CLASS = JavaType.type(Tensor.class);
 59 
 60     final MethodHandles.Lookup l;
 61     final Map<String, Integer> constantArraySizeMap; // RecordComponent cannot be used as a key!
 62 
 63     TypeConvertor(MethodHandles.Lookup l) {
 64         this.l = l;
 65         this.constantArraySizeMap = new HashMap<>(); // @@@ initialize
 66     }
 67 
 68     void detectConstantArrays(CoreOp.FuncOp f) {
 69         f.traverse(null, (_, ce) -> {
 70             if (ce instanceof JavaOp.NewOp no && no.resultType() instanceof ClassType recordType && isRecord(recordType)) {
 71                 Class<?> recordClass;
 72                 try {
 73                     recordClass = (Class<?>) recordType.rawType().resolve(l);
 74                 } catch (ReflectiveOperationException e) {
 75                     throw new RuntimeException(e);
 76                 }
 77                 var rcs = recordClass.getRecordComponents();
 78                 var ops = no.operands();
 79                 for (int i = 0; i < rcs.length; i++) {
 80                     RecordComponent rc  = rcs[i];
 81                     Type type = rc.getGenericType();
 82                     if (type instanceof ParameterizedType pt && pt.getRawType().equals(Optional.class)) {
 83                         type = pt.getActualTypeArguments()[0];
 84                     }
 85                     if (type instanceof GenericArrayType) {
 86                         Value arr = OnnxTransformer.skipVars(ops.get(i));
 87                         if (arr instanceof Op.Result newArrayResult
 88                                 && newArrayResult.op() instanceof JavaOp.NewOp newArrayOp
 89                                 && newArrayOp.operands().getFirst() instanceof Op.Result constantResult
 90                                 && constantResult.op() instanceof CoreOp.ConstantOp cop) {
 91 
 92                             // explicit constant array construction
 93                             constantArraySizeMap.put(rc.toString(), (Integer)cop.value());
 94                         } else {
 95                             // search for the highest array access index
 96                             scanUse(arr, rc.toString());
 97                         }
 98                     }
 99                 }
100             }
101             return null;
102         });
103     }
104 
105     void scanUse(Value array, String rcKey) {
106         for (var use : array.uses()) {
107             if (use instanceof Op.Result or) {
108                 switch (or.op()) {
109                     case CoreOp.VarOp vo ->
110                         scanUse(vo.result(), rcKey);
111                     case CoreOp.VarAccessOp.VarLoadOp vlo ->
112                         scanUse(vlo.result(), rcKey);
113                     case JavaOp.ArrayAccessOp aao when aao.operands().get(1) instanceof Op.Result constR
114                                                     && constR.op() instanceof CoreOp.ConstantOp cop ->
115                         constantArraySizeMap.compute(rcKey, (_, i) -> Math.max((Integer)cop.value(), i == null ? 0 : i));
116                     default -> {}
117                 }
118             }
119         }
120     }
121 
122     TupleType recordTypeToTupleType(ClassType recordType) {
123         Class<?> recordClass;
124         try {
125             recordClass = (Class<?>) recordType.rawType().resolve(l);
126         } catch (ReflectiveOperationException e) {
127             throw new RuntimeException(e);
128         }
129         assert recordClass.isRecord();
130 
131         List<TypeElement> tupleComponentTypes = new ArrayList<>();
132         for (RecordComponent rc : recordClass.getRecordComponents()) {
133             Type type = rc.getGenericType();
134             if (type instanceof ParameterizedType pt && pt.getRawType().equals(Optional.class)) {
135                 type = pt.getActualTypeArguments()[0];
136             }
137             switch (type) {
138                 case ParameterizedType pt -> {
139                     Type elementType = pt.getActualTypeArguments()[0];
140                     switch (elementType) {
141                         case Class<?> _ -> {
142                             tupleComponentTypes.add(convertType(JavaType.type(pt)));
143                         }
144                         case TypeVariable<?> tv -> {
145                             // Resolve type variable
146                             JavaType e = null;
147                             for (int j = 0; j < recordClass.getTypeParameters().length; j++) {
148                                 if (recordClass.getTypeParameters()[j].getName().equals(tv.getName())) {
149                                     e = recordType.typeArguments().get(j);
150                                     break;
151                                 }
152                             }
153                             tupleComponentTypes.add(convertType(JavaType.parameterized(JavaType.type(Tensor.class), e)));
154                         }
155                         default -> throw new IllegalStateException("Unexpected value: " + elementType);
156                     }
157                 }
158                 case TypeVariable tv -> {
159                     // Resolve type variable
160                     JavaType e = null;
161                     for (int j = 0; j < recordClass.getTypeParameters().length; j++) {
162                         if (recordClass.getTypeParameters()[j].getName().equals(tv.getName())) {
163                             e = recordType.typeArguments().get(j);
164                             break;
165                         }
166                     }
167                     tupleComponentTypes.add(convertType(e));
168                 }
169                 case GenericArrayType gat -> {
170                     var cType = convertType(JavaType.type(gat.getGenericComponentType()));
171                     Integer size = constantArraySizeMap.get(rc.toString());
172                     var tContent = new TypeElement[size];
173                     Arrays.fill(tContent, cType);
174                     tupleComponentTypes.add(CoreType.tupleType(tContent));
175                 }
176                 default -> throw new IllegalStateException("Unexpected value: " + rc.getGenericType());
177             }
178         }
179 
180         return CoreType.tupleType(tupleComponentTypes);
181     }
182 
183     boolean isRecord(TypeElement type) {
184         try {
185             return type instanceof ClassType ct &&
186                     ct.erasure().resolve(l) instanceof Class c &&
187                     c.isRecord();
188         } catch (ReflectiveOperationException e) {
189             throw new RuntimeException(e);
190         }
191     }
192 
193 
194     Integer recordComponentAccessToTupleIndex(MethodRef ref) {
195         if (ref.refType() instanceof ClassType ct) {
196             Class<?> refClass;
197             try {
198                 refClass = (Class<?>) ct.resolve(l);
199             } catch (ReflectiveOperationException e) {
200                 throw new RuntimeException(e);
201             }
202 
203             if (refClass.isRecord()) {
204                 RecordComponent[] recordComponents = refClass.getRecordComponents();
205                 for (int i = 0; i < recordComponents.length; i++) {
206                     if (recordComponents[i].getName().equals(ref.name())) {
207                         return i;
208                     }
209                 }
210                 throw new InternalError();
211             }
212         }
213         return null;
214     }
215 
216     FunctionType convertType(FunctionType t) {
217         return CoreType.functionType(convertType(t.returnType()), t.parameterTypes().stream().map(this::convertType).toList());
218     }
219 
220     FunctionType convertType(CoreOp.FuncOp fo) {
221         return CoreType.functionType(convertType(fo.body().entryBlock().terminatingOp().operands().getFirst()), fo.parameters().stream().map(this::convertType).toList());
222     }
223 
224     TypeElement convertType(Value value) {
225         // convert 1-dimensional constantly accessed constant arrays into tuples
226         if (value.type() instanceof ArrayType at && at.dimensions() == 1) {
227             int size = countConstantArraySize(value.uses());
228             if (size >= 0) {
229                 var targs = new TypeElement[size];
230                 Arrays.fill(targs, convertType(at.componentType()));
231                 return CoreType.tupleType(targs);
232             }
233         }
234         return convertType(value.type());
235     }
236 
237     static int countConstantArraySize(Set<Op.Result> uses) {
238         int size = 0;
239         for (var use : uses) {
240             int s = switch (use.op()) {
241                 case JavaOp.ArrayAccessOp aao when aao.operands().get(1) instanceof Op.Result or && or.op() instanceof CoreOp.ConstantOp co ->
242                     (Integer)co.value() + 1;
243                 case CoreOp.VarOp _, CoreOp.VarAccessOp.VarLoadOp _ ->
244                     countConstantArraySize(use.op().result().uses());
245                 default -> -1;
246             };
247             if (s < 0) return -1;
248             size = Integer.max(size, s);
249         }
250         return size;
251     }
252 
253     // @@@ Map of Java tensor types to ONNX tensor types
254     // @@@ Shape??
255     TypeElement convertType(TypeElement type) {
256         if (type instanceof ClassType ct) {
257             if (ct.rawType().equals(TENSOR_CLASS)) {
258                 JavaType elementType = ct.typeArguments().getFirst();
259                 if (elementType.equals(JavaType.J_L_INTEGER)) {
260                     return OnnxType.TENSOR_INT32;
261                 } else if (elementType.equals(JavaType.J_L_FLOAT)) {
262                     return OnnxType.TENSOR_FLOAT32;
263                 } else if (elementType.equals(JavaType.J_L_LONG)) {
264                     return OnnxType.TENSOR_INT64;
265                 } else if (elementType.equals(JavaType.J_L_BYTE)) {
266                     return OnnxType.TENSOR_UINT8;
267                 } else if (elementType.equals(JavaType.J_L_BOOLEAN)) {
268                     return OnnxType.TENSOR_BOOL;
269                 }
270             } else if (isRecord(type)) {
271                 return recordTypeToTupleType(ct);
272             }
273         }
274         return type;
275     }
276 }