1 /*
2 * Copyright (c) 2024, 2026, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation.
8 *
9 * This code is distributed in the hope that it will be useful, but WITHOUT
10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
12 * version 2 for more details (a copy is included in the LICENSE file that
13 * accompanied this code).
14 *
15 * You should have received a copy of the GNU General Public License version
16 * 2 along with this work; if not, write to the Free Software Foundation,
17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18 *
19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20 * or visit www.oracle.com if you need additional information or have any
21 * questions.
22 */
23
24 import jdk.incubator.code.*;
25 import jdk.incubator.code.CodeTransformer;
26 import jdk.incubator.code.dialect.core.CoreOp;
27 import jdk.incubator.code.dialect.java.*;
28 import jdk.incubator.code.dialect.core.VarType;
29
30 import java.util.IdentityHashMap;
31 import java.util.List;
32 import java.util.Map;
33
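/**
 * Resolves the {@link UnresolvedType}s of block parameters and operation results
 * in a function to concrete Java types. Types are inferred from how each value is
 * used and defined; anything still unresolved afterwards defaults to {@code int}
 * (for {@link UnresolvedType.Int}) or {@code java.lang.Object}
 * (for {@link UnresolvedType.Ref}).
 */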
37 final class UnresolvedTypesTransformer {
38
39 static CoreOp.FuncOp transform(CoreOp.FuncOp func) {
40 try {
41 return new UnresolvedTypesTransformer().resolve(func);
42 } catch (Throwable t) {
43 System.out.println(func.toText());
44 throw t;
45 }
46 }
47
48 private final Map<UnresolvedType, JavaType> resolvedMap;
49
50 private UnresolvedTypesTransformer() {
51 resolvedMap = new IdentityHashMap<>();
52 }
53
54 private CoreOp.FuncOp resolve(CoreOp.FuncOp func) {
55 List<Value> unresolved = func.elements().<Value>mapMulti((e, c) -> {
56 switch (e) {
57 case Block b -> b.parameters().forEach(v -> {
58 if (toResolve(v) != null) c.accept(v);
59 });
60 case Op op when toResolve(op.result()) != null ->
61 c.accept(op.result());
62 default -> {}
63 }
64
65 }).toList();
66
67 boolean changed = true;
68 while (changed) {
69 changed = false;
70 for (Value v : unresolved) {
71 changed |= resolve(v);
72 }
73 }
74
75 // Remaining types are resolved to defaults
76 for (Value v : unresolved) {
77 resolvedMap.computeIfAbsent(toResolve(v), ut ->
78 switch (ut) {
79 case UnresolvedType.Int _ -> JavaType.INT;
80 case UnresolvedType.Ref _ -> JavaType.J_L_OBJECT;
81 });
82 }
83
84 return func.transform(blockParamTypesTransformer())
85 .transform(opTypesTransformer())
86 .transform(unifyOperandsTransformer());
87 }
88
89 private static UnresolvedType toResolve(Value v) {
90 return v == null ? null : switch (v.type()) {
91 case UnresolvedType ut -> ut;
92 case VarType vt when vt.valueType() instanceof UnresolvedType ut -> ut;
93 default -> null;
94 };
95 }
96
97 private CodeType toComponent(CodeType te) {
98 if (te instanceof UnresolvedType ut) {
99 te = resolvedMap.get(ut);
100 }
101 return te instanceof ArrayType at ? at.componentType() : null;
102 }
103
104 private CodeType toArray(CodeType te) {
105 if (te instanceof UnresolvedType ut) {
106 te = resolvedMap.get(ut);
107 }
108 return te instanceof JavaType jt ? JavaType.array(jt) : null;
109 }
110
111 private boolean resolve(Value v) {
112 UnresolvedType ut = toResolve(v);
113 if (ut == null) return false;
114 boolean changed = false;
115 for (Op.Result useRes : v.uses()) {
116 Op op = useRes.op();
117 int i = op.operands().indexOf(v);
118 if (i >= 0) {
119 changed |= switch (op) {
120 case JavaOp.LshlOp _, JavaOp.LshrOp _, JavaOp.AshrOp _ ->
121 i == 0 && resolveTo(ut, op.resultType());
122 case JavaOp.BinaryOp bo ->
123 resolveTo(ut, bo.resultType());
124 case JavaOp.InvokeOp io -> {
125 MethodRef id = io.invokeReference();
126 if (io.hasReceiver()) {
127 if (i == 0) yield resolveTo(ut, id.refType());
128 i--;
129 }
130 yield resolveTo(ut, id.signature().parameterTypes().get(i));
131 }
132 case JavaOp.FieldAccessOp fao ->
133 resolveTo(ut, fao.fieldReference().refType());
134 case CoreOp.ReturnOp ro ->
135 resolveTo(ut, ro.ancestorBody().bodySignature().returnType());
136 case CoreOp.VarOp vo ->
137 resolveTo(ut, vo.varValueType());
138 case CoreOp.VarAccessOp.VarStoreOp vso ->
139 resolveTo(ut, vso.varType().valueType());
140 case JavaOp.NewOp no ->
141 resolveTo(ut, no.constructorReference().signature().parameterTypes().get(i));
142 case JavaOp.ArrayAccessOp.ArrayLoadOp alo ->
143 resolveTo(ut, toArray(alo.resultType()));
144 case JavaOp.ArrayAccessOp.ArrayStoreOp aso ->
145 switch (i) {
146 case 0 -> resolveFrom(ut, toArray(aso.operands().get(2).type()));
147 case 2 -> resolveTo(ut, toComponent(aso.operands().get(0).type()));
148 default -> false;
149 };
150 default -> false;
151 };
152 }
153 // Pull block parameter type when used as block argument
154 for (Block.Reference sucRef : useRes.op().successors()) {
155 i = sucRef.arguments().indexOf(v);
156 if (i >= 0) {
157 changed |= resolveTo(ut, sucRef.targetBlock().parameters().get(i).type());
158 }
159 }
160 }
161 if (v instanceof Block.Parameter bp) {
162 int bi = bp.index();
163 Block b = bp.declaringBlock();
164 for (Block pb : b.predecessors()) {
165 for (Block.Reference r : pb.successors()) {
166 if (r.targetBlock() == b) {
167 var args = r.arguments();
168 if (args.size() > bi && resolveFrom(ut, args.get(bi).type())) {
169 return true;
170 }
171 }
172 }
173 }
174 } else if (v instanceof Op.Result or) {
175 changed |= switch (or.op()) {
176 case JavaOp.UnaryOp uo ->
177 resolveFrom(ut, uo.operands().getFirst().type());
178 case JavaOp.BinaryOp bo ->
179 resolveFrom(ut, bo.operands().getFirst().type())
180 || resolveFrom(ut, bo.operands().get(1).type());
181 case CoreOp.VarAccessOp.VarLoadOp vlo ->
182 resolveFrom(ut, vlo.varType().valueType());
183 case CoreOp.VarOp vo ->
184 resolveVarOpType(ut, vo);
185 case JavaOp.ArrayAccessOp.ArrayLoadOp alo ->
186 resolveFrom(ut, toComponent(alo.operands().getFirst().type()));
187 default -> false;
188 };
189 }
190 return changed;
191 }
192
193 private boolean resolveFrom(UnresolvedType unresolved, CodeType from) {
194 CodeType type = from instanceof UnresolvedType utt ? resolvedMap.get(utt) : from;
195 JavaType resolved = resolvedMap.get(unresolved);
196 return switch (unresolved) {
197 // Only care about arrays
198 case UnresolvedType.Ref _ when (resolved == null || resolved.equals(JavaType.J_L_OBJECT)) && type instanceof ArrayType at -> {
199 resolvedMap.put(unresolved, at);
200 yield true;
201 }
202 // Only care about booleans
203 case UnresolvedType.Int _ when JavaType.BOOLEAN.equals(type) && !JavaType.BOOLEAN.equals(resolved) -> {
204 resolvedMap.put(unresolved, JavaType.BOOLEAN);
205 yield true;
206 }
207 default -> false;
208 };
209 }
210
211 private static final List<PrimitiveType> INT_TYPES = List.of(JavaType.INT, JavaType.CHAR, JavaType.SHORT, JavaType.BYTE, JavaType.BOOLEAN);
212
213 private boolean resolveTo(UnresolvedType unresolved, CodeType to) {
214 CodeType type = to instanceof UnresolvedType utt ? resolvedMap.get(utt) : to;
215 JavaType resolved = resolvedMap.get(unresolved);
216 return switch (unresolved) {
217 case UnresolvedType.Ref _ when (resolved == null || resolved.equals(JavaType.J_L_OBJECT)) && type instanceof JavaType jt && !jt.equals(resolved) -> {
218 resolvedMap.put(unresolved, jt);
219 yield true;
220 }
221 case UnresolvedType.Int _ when type instanceof PrimitiveType pt && (INT_TYPES.indexOf(pt) > (resolved == null ? -1 : INT_TYPES.indexOf(resolved))) -> {
222 resolvedMap.put(unresolved, pt);
223 yield true;
224 }
225 default -> false;
226 };
227 }
228
229 private boolean resolveVarOpType(UnresolvedType ut, CoreOp.VarOp vo) {
230 boolean changed = vo.isUninitialized() ? false : resolveFrom(ut, vo.initOperand().type());
231 for (Op.Result varUses : vo.result().uses()) {
232 changed |= switch (varUses.op()) {
233 case CoreOp.VarAccessOp.VarLoadOp vlo ->
234 resolveTo(ut, vlo.resultType());
235 case CoreOp.VarAccessOp.VarStoreOp vso ->
236 resolveFrom(ut, vso.storeOperand().type());
237 default -> false;
238 };
239 }
240 return changed;
241 }
242
243 private Object convertValue(UnresolvedType ut, Object value) {
244 return switch (INT_TYPES.indexOf(resolvedMap.get(ut))) {
245 case 0 -> toNumber(value).intValue();
246 case 1 -> (char)toNumber(value).intValue();
247 case 2 -> toNumber(value).shortValue();
248 case 3 -> toNumber(value).byteValue();
249 case 4 -> value instanceof Number n ? n.intValue() != 0 : (Boolean)value;
250 default -> value;
251 };
252 }
253
254 private static Number toNumber(Object value) {
255 return switch (value) {
256 case Boolean b -> b ? 1 : 0;
257 case Character c -> (int)c;
258 case Number n -> n;
259 default -> throw new IllegalStateException("Unexpected " + value);
260 };
261 }
262
263 private CodeTransformer blockParamTypesTransformer() {
264 return new CodeTransformer() {
265 @Override
266 public void acceptBlock(Block.Builder block, Block b) {
267 if (block.isEntryBlock()) {
268 CodeContext cc = block.context();
269 List<Block> sourceBlocks = b.ancestorBody().blocks();
270
271 // Override blocks with changed parameter types
272 for (int i = 1; i < sourceBlocks.size(); i++) {
273 Block sourceBlock = sourceBlocks.get(i);
274 List<CodeType> paramTypes = sourceBlock.parameterTypes();
275 if (paramTypes.stream().anyMatch(UnresolvedType.class::isInstance)) {
276 Block.Builder newBlock = block.block(paramTypes.stream()
277 .map(pt -> pt instanceof UnresolvedType ut ? resolvedMap.get(ut) : pt)
278 .toList());
279 cc.mapBlock(sourceBlock, newBlock);
280 cc.mapValues(sourceBlock.parameters(), newBlock.parameters());
281 }
282 }
283
284 }
285 CodeTransformer.super.acceptBlock(block, b);
286 }
287
288 @Override
289 public Block.Builder acceptOp(Block.Builder block, Op op) {
290 block.add(op);
291 return block;
292 }
293 };
294 }
295
296 private CodeTransformer opTypesTransformer() {
297 return (block, op) -> {
298 CodeContext cc = block.context();
299 switch (op) {
300 case CoreOp.ConstantOp cop when op.resultType() instanceof UnresolvedType ut ->
301 cc.mapValue(op.result(), block.add(CoreOp.constant(resolvedMap.get(ut), convertValue(ut, cop.value()))));
302 case CoreOp.VarOp vop when vop.varValueType() instanceof UnresolvedType ut ->
303 cc.mapValue(op.result(), block.add(vop.isUninitialized()
304 ? CoreOp.var(vop.varName(), resolvedMap.get(ut))
305 : CoreOp.var(vop.varName(), resolvedMap.get(ut), cc.queryValue(vop.initOperand()).orElse(vop.initOperand()))));
306 case JavaOp.ArrayAccessOp.ArrayLoadOp alop when op.resultType() instanceof UnresolvedType -> {
307 List<Value> opers = alop.operands();
308 Value array = opers.getFirst();
309 Value index = opers.getLast();
310 cc.mapValue(op.result(), block.add(JavaOp.arrayLoadOp(
311 cc.queryValue(array).orElse(array),
312 cc.queryValue(index).orElse(index))));
313 }
314 default ->
315 block.add(op);
316 }
317 return block;
318 };
319 }
320
321 private static CodeTransformer unifyOperandsTransformer() {
322 return (block, op) -> {
323 switch (op) {
324 case JavaOp.CompareOp _ ->
325 unify(block, op, JavaType.INT, JavaType.INT);
326 case JavaOp.LshlOp _, JavaOp.LshrOp _, JavaOp.AshrOp _ ->
327 unify(block, op, op.resultType(), JavaType.INT);
328 case JavaOp.BinaryOp _ ->
329 unify(block, op, op.resultType(), op.resultType());
330 default ->
331 block.add(op);
332 }
333 return block;
334 };
335 }
336
337 private static void unify(Block.Builder block, Op op, CodeType firstType, CodeType secondType) {
338 List<Value> operands = op.operands();
339 CodeContext cc = CodeContext.create(block.context());
340 Value first = operands.getFirst();
341 boolean changed = false;
342 if (first.type() instanceof PrimitiveType && !first.type().equals(firstType)) {
343 cc.mapValue(first, block.add(JavaOp.conv(firstType, cc.queryValue(first).orElse(first))));
344 changed = true;
345 }
346 Value second = operands.get(1);
347 if (second.type() instanceof PrimitiveType && !second.type().equals(secondType)) {
348 cc.mapValue(second, block.add(JavaOp.conv(secondType, cc.queryValue(second).orElse(second))));
349 changed = true;
350 }
351 if (changed) {
352 block.context().mapValue(op.result(), block.add(op.transform(cc, CodeTransformer.COPYING_TRANSFORMER)));
353 } else {
354 block.add(op);
355 }
356 }
357 }