1 /*
  2  * Copyright (c) 2024, 2026, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 import jdk.incubator.code.*;
 25 import jdk.incubator.code.CodeTransformer;
 26 import jdk.incubator.code.dialect.core.CoreOp;
 27 import jdk.incubator.code.dialect.java.*;
 28 import jdk.incubator.code.dialect.core.VarType;
 29 
 30 import java.util.IdentityHashMap;
 31 import java.util.List;
 32 import java.util.Map;
 33 
 34 /**
 35  * Resolves unresolved types.
 36  */
 37 final class UnresolvedTypesTransformer {
 38 
 39     static CoreOp.FuncOp transform(CoreOp.FuncOp func) {
 40         try {
 41             return new UnresolvedTypesTransformer().resolve(func);
 42         } catch (Throwable t) {
 43             System.out.println(func.toText());
 44             throw t;
 45         }
 46     }
 47 
 48     private final Map<UnresolvedType, JavaType> resolvedMap;
 49 
 50     private UnresolvedTypesTransformer() {
 51         resolvedMap = new IdentityHashMap<>();
 52     }
 53 
 54     private CoreOp.FuncOp resolve(CoreOp.FuncOp func) {
 55         List<Value> unresolved = func.elements().<Value>mapMulti((e, c) -> {
 56             switch (e) {
 57                 case Block b -> b.parameters().forEach(v -> {
 58                     if (toResolve(v) != null) c.accept(v);
 59                 });
 60                 case Op op when toResolve(op.result()) != null ->
 61                         c.accept(op.result());
 62                 default -> {}
 63             }
 64 
 65         }).toList();
 66 
 67         boolean changed = true;
 68         while (changed) {
 69             changed = false;
 70             for (Value v : unresolved) {
 71                 changed |= resolve(v);
 72             }
 73         }
 74 
 75         // Remaining types are resolved to defaults
 76         for (Value v : unresolved) {
 77             resolvedMap.computeIfAbsent(toResolve(v), ut ->
 78                 switch (ut) {
 79                     case UnresolvedType.Int _ -> JavaType.INT;
 80                     case UnresolvedType.Ref _ -> JavaType.J_L_OBJECT;
 81                 });
 82         }
 83 
 84         return func.transform(blockParamTypesTransformer())
 85                    .transform(opTypesTransformer())
 86                    .transform(unifyOperandsTransformer());
 87     }
 88 
 89     private static UnresolvedType toResolve(Value v) {
 90         return v == null ? null : switch (v.type()) {
 91             case UnresolvedType ut -> ut;
 92             case VarType vt when vt.valueType() instanceof UnresolvedType ut  -> ut;
 93             default -> null;
 94         };
 95     }
 96 
 97     private CodeType toComponent(CodeType te) {
 98         if (te instanceof UnresolvedType ut) {
 99             te = resolvedMap.get(ut);
100         }
101         return te instanceof ArrayType at ? at.componentType() : null;
102     }
103 
104     private CodeType toArray(CodeType te) {
105         if (te instanceof UnresolvedType ut) {
106             te = resolvedMap.get(ut);
107         }
108         return te instanceof JavaType jt ? JavaType.array(jt) : null;
109     }
110 
111     private boolean resolve(Value v) {
112         UnresolvedType ut = toResolve(v);
113         if (ut == null) return false;
114         boolean changed = false;
115         for (Op.Result useRes : v.uses()) {
116             Op op = useRes.op();
117             int i = op.operands().indexOf(v);
118             if (i >= 0) {
119                 changed |= switch (op) {
120                     case JavaOp.LshlOp _, JavaOp.LshrOp _, JavaOp.AshrOp _ ->
121                         i == 0 && resolveTo(ut, op.resultType());
122                     case JavaOp.BinaryOp bo ->
123                         resolveTo(ut, bo.resultType());
124                     case JavaOp.InvokeOp io -> {
125                         MethodRef id = io.invokeReference();
126                         if (io.hasReceiver()) {
127                             if (i == 0) yield resolveTo(ut, id.refType());
128                             i--;
129                         }
130                         yield resolveTo(ut, id.signature().parameterTypes().get(i));
131                     }
132                     case JavaOp.FieldAccessOp fao ->
133                         resolveTo(ut, fao.fieldReference().refType());
134                     case CoreOp.ReturnOp ro ->
135                         resolveTo(ut, ro.ancestorBody().bodySignature().returnType());
136                     case CoreOp.VarOp vo ->
137                         resolveTo(ut, vo.varValueType());
138                     case CoreOp.VarAccessOp.VarStoreOp vso ->
139                         resolveTo(ut, vso.varType().valueType());
140                     case JavaOp.NewOp no ->
141                         resolveTo(ut, no.constructorReference().signature().parameterTypes().get(i));
142                     case JavaOp.ArrayAccessOp.ArrayLoadOp alo ->
143                         resolveTo(ut, toArray(alo.resultType()));
144                     case JavaOp.ArrayAccessOp.ArrayStoreOp aso ->
145                         switch (i) {
146                             case 0 -> resolveFrom(ut, toArray(aso.operands().get(2).type()));
147                             case 2 -> resolveTo(ut, toComponent(aso.operands().get(0).type()));
148                             default -> false;
149                         };
150                     default -> false;
151                 };
152             }
153             // Pull block parameter type when used as block argument
154             for (Block.Reference sucRef : useRes.op().successors()) {
155                 i = sucRef.arguments().indexOf(v);
156                 if (i >= 0) {
157                     changed |= resolveTo(ut, sucRef.targetBlock().parameters().get(i).type());
158                 }
159             }
160         }
161         if (v instanceof Block.Parameter bp) {
162             int bi = bp.index();
163             Block b = bp.declaringBlock();
164             for (Block pb : b.predecessors()) {
165                 for (Block.Reference r : pb.successors()) {
166                     if (r.targetBlock() == b) {
167                         var args = r.arguments();
168                         if (args.size() > bi && resolveFrom(ut, args.get(bi).type())) {
169                             return true;
170                         }
171                     }
172                 }
173             }
174         } else if (v instanceof Op.Result or) {
175             changed |= switch (or.op()) {
176                 case JavaOp.UnaryOp uo ->
177                     resolveFrom(ut, uo.operands().getFirst().type());
178                 case JavaOp.BinaryOp bo ->
179                     resolveFrom(ut, bo.operands().getFirst().type())
180                     || resolveFrom(ut, bo.operands().get(1).type());
181                 case CoreOp.VarAccessOp.VarLoadOp vlo ->
182                     resolveFrom(ut, vlo.varType().valueType());
183                 case CoreOp.VarOp vo ->
184                     resolveVarOpType(ut, vo);
185                 case JavaOp.ArrayAccessOp.ArrayLoadOp alo ->
186                     resolveFrom(ut, toComponent(alo.operands().getFirst().type()));
187                 default -> false;
188             };
189         }
190         return changed;
191     }
192 
193     private boolean resolveFrom(UnresolvedType unresolved, CodeType from) {
194         CodeType type = from instanceof UnresolvedType utt ? resolvedMap.get(utt) : from;
195         JavaType resolved = resolvedMap.get(unresolved);
196         return switch (unresolved) {
197             // Only care about arrays
198             case UnresolvedType.Ref _ when (resolved == null || resolved.equals(JavaType.J_L_OBJECT)) && type instanceof ArrayType at -> {
199                 resolvedMap.put(unresolved, at);
200                 yield true;
201             }
202             // Only care about booleans
203             case UnresolvedType.Int _ when JavaType.BOOLEAN.equals(type) && !JavaType.BOOLEAN.equals(resolved) -> {
204                 resolvedMap.put(unresolved, JavaType.BOOLEAN);
205                 yield true;
206             }
207             default -> false;
208         };
209     }
210 
211     private static final List<PrimitiveType> INT_TYPES = List.of(JavaType.INT, JavaType.CHAR, JavaType.SHORT, JavaType.BYTE, JavaType.BOOLEAN);
212 
213     private boolean resolveTo(UnresolvedType unresolved, CodeType to) {
214         CodeType type = to instanceof UnresolvedType utt ? resolvedMap.get(utt) : to;
215         JavaType resolved = resolvedMap.get(unresolved);
216         return switch (unresolved) {
217             case UnresolvedType.Ref _ when (resolved == null || resolved.equals(JavaType.J_L_OBJECT)) && type instanceof JavaType jt && !jt.equals(resolved) -> {
218                 resolvedMap.put(unresolved, jt);
219                 yield true;
220             }
221             case UnresolvedType.Int _ when type instanceof PrimitiveType pt && (INT_TYPES.indexOf(pt) > (resolved == null ? -1 : INT_TYPES.indexOf(resolved))) -> {
222                 resolvedMap.put(unresolved, pt);
223                 yield true;
224             }
225             default -> false;
226         };
227     }
228 
229     private boolean resolveVarOpType(UnresolvedType ut, CoreOp.VarOp vo) {
230         boolean changed = vo.isUninitialized() ? false : resolveFrom(ut, vo.initOperand().type());
231         for (Op.Result varUses : vo.result().uses()) {
232             changed |= switch (varUses.op()) {
233                 case CoreOp.VarAccessOp.VarLoadOp vlo ->
234                     resolveTo(ut, vlo.resultType());
235                 case CoreOp.VarAccessOp.VarStoreOp vso ->
236                     resolveFrom(ut, vso.storeOperand().type());
237                 default -> false;
238             };
239         }
240         return changed;
241     }
242 
243     private Object convertValue(UnresolvedType ut, Object value) {
244         return switch (INT_TYPES.indexOf(resolvedMap.get(ut))) {
245             case 0 -> toNumber(value).intValue();
246             case 1 -> (char)toNumber(value).intValue();
247             case 2 -> toNumber(value).shortValue();
248             case 3 -> toNumber(value).byteValue();
249             case 4 -> value instanceof Number n ? n.intValue() != 0 : (Boolean)value;
250             default -> value;
251         };
252     }
253 
254     private static Number toNumber(Object value) {
255         return switch (value) {
256             case Boolean b -> b ? 1 : 0;
257             case Character c -> (int)c;
258             case Number n -> n;
259             default -> throw new IllegalStateException("Unexpected " + value);
260         };
261     }
262 
263     private CodeTransformer blockParamTypesTransformer() {
264         return new CodeTransformer() {
265             @Override
266             public void acceptBlock(Block.Builder block, Block b) {
267                 if (block.isEntryBlock()) {
268                     CodeContext cc = block.context();
269                     List<Block> sourceBlocks = b.ancestorBody().blocks();
270 
271                     // Override blocks with changed parameter types
272                     for (int i = 1; i < sourceBlocks.size(); i++) {
273                         Block sourceBlock = sourceBlocks.get(i);
274                         List<CodeType> paramTypes = sourceBlock.parameterTypes();
275                         if (paramTypes.stream().anyMatch(UnresolvedType.class::isInstance)) {
276                             Block.Builder newBlock = block.block(paramTypes.stream()
277                                     .map(pt -> pt instanceof UnresolvedType ut  ? resolvedMap.get(ut) : pt)
278                                     .toList());
279                             cc.mapBlock(sourceBlock, newBlock);
280                             cc.mapValues(sourceBlock.parameters(), newBlock.parameters());
281                         }
282                     }
283 
284                 }
285                 CodeTransformer.super.acceptBlock(block, b);
286             }
287 
288             @Override
289             public Block.Builder acceptOp(Block.Builder block, Op op) {
290                 block.add(op);
291                 return block;
292             }
293         };
294     }
295 
296     private CodeTransformer opTypesTransformer() {
297         return (block, op) -> {
298             CodeContext cc = block.context();
299             switch (op) {
300                 case CoreOp.ConstantOp cop when op.resultType() instanceof UnresolvedType ut ->
301                     cc.mapValue(op.result(), block.add(CoreOp.constant(resolvedMap.get(ut), convertValue(ut, cop.value()))));
302                 case CoreOp.VarOp vop when vop.varValueType() instanceof UnresolvedType ut ->
303                     cc.mapValue(op.result(), block.add(vop.isUninitialized()
304                             ? CoreOp.var(vop.varName(), resolvedMap.get(ut))
305                             : CoreOp.var(vop.varName(), resolvedMap.get(ut), cc.queryValue(vop.initOperand()).orElse(vop.initOperand()))));
306                 case JavaOp.ArrayAccessOp.ArrayLoadOp alop when op.resultType() instanceof UnresolvedType -> {
307                     List<Value> opers = alop.operands();
308                     Value array = opers.getFirst();
309                     Value index = opers.getLast();
310                     cc.mapValue(op.result(), block.add(JavaOp.arrayLoadOp(
311                             cc.queryValue(array).orElse(array),
312                             cc.queryValue(index).orElse(index))));
313                 }
314                 default ->
315                     block.add(op);
316             }
317             return block;
318         };
319     }
320 
321     private static CodeTransformer unifyOperandsTransformer() {
322         return (block, op) -> {
323             switch (op) {
324                 case JavaOp.CompareOp _ ->
325                     unify(block, op, JavaType.INT, JavaType.INT);
326                 case JavaOp.LshlOp _, JavaOp.LshrOp _, JavaOp.AshrOp _ ->
327                     unify(block, op, op.resultType(), JavaType.INT);
328                 case JavaOp.BinaryOp _ ->
329                     unify(block, op, op.resultType(), op.resultType());
330                 default ->
331                     block.add(op);
332             }
333             return block;
334         };
335     }
336 
337     private static void unify(Block.Builder block, Op op, CodeType firstType, CodeType secondType) {
338         List<Value> operands = op.operands();
339         CodeContext cc = CodeContext.create(block.context());
340         Value first = operands.getFirst();
341         boolean changed = false;
342         if (first.type() instanceof PrimitiveType && !first.type().equals(firstType)) {
343             cc.mapValue(first, block.add(JavaOp.conv(firstType, cc.queryValue(first).orElse(first))));
344             changed = true;
345         }
346         Value second = operands.get(1);
347         if (second.type() instanceof PrimitiveType && !second.type().equals(secondType)) {
348             cc.mapValue(second, block.add(JavaOp.conv(secondType, cc.queryValue(second).orElse(second))));
349             changed = true;
350         }
351         if (changed) {
352             block.context().mapValue(op.result(), block.add(op.transform(cc, CodeTransformer.COPYING_TRANSFORMER)));
353         } else {
354             block.add(op);
355         }
356     }
357 }