1 /*
2 * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation.
8 *
9 * This code is distributed in the hope that it will be useful, but WITHOUT
10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
12 * version 2 for more details (a copy is included in the LICENSE file that
13 * accompanied this code).
14 *
15 * You should have received a copy of the GNU General Public License version
16 * 2 along with this work; if not, write to the Free Software Foundation,
17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18 *
19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20 * or visit www.oracle.com if you need additional information or have any
21 * questions.
22 */
23
24 import jdk.incubator.code.*;
25 import jdk.incubator.code.dialect.core.CoreOp;
26 import jdk.incubator.code.dialect.core.FunctionType;
27 import jdk.incubator.code.dialect.java.*;
28
29 import java.lang.invoke.MethodHandle;
30 import java.lang.invoke.MethodHandles;
31 import java.lang.invoke.MethodType;
32 import java.lang.invoke.VarHandle;
33 import java.lang.reflect.Array;
34 import java.util.*;
35 import java.util.function.Predicate;
36 import java.util.stream.Collectors;
37 import java.util.stream.Stream;
38
39 final class PartialEvaluator {
40 final Set<Value> constants;
41 final Predicate<Op> opConstant;
42
43 PartialEvaluator(Set<Value> constants, Predicate<Op> opConstant) {
44 this.constants = new LinkedHashSet<>(constants);
45 this.opConstant = opConstant;
46 }
47
48 public static
49 CoreOp.FuncOp evaluate(MethodHandles.Lookup l,
50 Predicate<Op> opConstant, Set<Value> constants,
51 CoreOp.FuncOp op) {
52 PartialEvaluator pe = new PartialEvaluator(constants, opConstant);
53 Body.Builder outBody = pe.evaluateBody(l, op.body());
54 return CoreOp.func(op.funcName(), outBody);
55 }
56
57
58 @SuppressWarnings("serial")
59 public static final class EvaluationException extends RuntimeException {
60 private EvaluationException(Throwable cause) {
61 super(cause);
62 }
63 }
64
65 static EvaluationException evaluationException(Throwable cause) {
66 return new EvaluationException(cause);
67 }
68
69 static final class BodyContext {
70 final BodyContext parent;
71
72 final Map<Block, List<Block>> evaluatedPredecessors;
73 final Map<Value, Object> evaluatedValues;
74
75 final Queue<Block> blockStack;
76 final BitSet visited;
77
78 BodyContext(Block entryBlock) {
79 this.parent = null;
80
81 this.evaluatedPredecessors = new HashMap<>();
82 this.evaluatedValues = new HashMap<>();
83 this.blockStack = new PriorityQueue<>(Comparator.comparingInt(Block::index));
84
85 this.visited = new BitSet();
86 }
87
88 Object getValue(Value v) {
89 Object rv = evaluatedValues.get(v);
90 if (rv != null) {
91 return rv;
92 }
93
94 throw evaluationException(new IllegalArgumentException("Undefined value: " + v));
95 }
96
97 void setValue(Value v, Object o) {
98 evaluatedValues.put(v, o);
99 }
100 }
101
102 Body.Builder evaluateBody(MethodHandles.Lookup l,
103 Body inBody) {
104 Block inEntryBlock = inBody.entryBlock();
105
106 Body.Builder outBody = Body.Builder.of(null, inBody.bodyType());
107 Block.Builder outEntryBlock = outBody.entryBlock();
108
109 CopyContext cc = outEntryBlock.context();
110 cc.mapBlock(inEntryBlock, outEntryBlock);
111 cc.mapValues(inEntryBlock.parameters(), outEntryBlock.parameters());
112
113 evaluateEntryBlock(l, inEntryBlock, outEntryBlock, new BodyContext(inEntryBlock));
114
115 return outBody;
116 }
117
118 void evaluateEntryBlock(MethodHandles.Lookup l,
119 Block inEntryBlock,
120 Block.Builder outEntryBlock,
121 BodyContext bc) {
122 assert inEntryBlock.isEntryBlock();
123
124 Map<Block, LoopAnalyzer.Loop> loops = new HashMap<>();
125 Set<Block> loopNoPeeling = new HashSet<>();
126
127 // The first block cannot have any successors so the queue will have at least one entry
128 bc.blockStack.add(inEntryBlock);
129 while (!bc.blockStack.isEmpty()) {
130 final Block inBlock = bc.blockStack.poll();
131 if (bc.visited.get(inBlock.index())) {
132 continue;
133 }
134 bc.visited.set(inBlock.index());
135
136 final Block.Builder outBlock = outEntryBlock.context().getBlock(inBlock);
137
138 nopeel: if (inBlock.predecessors().size() > 1 && bc.evaluatedPredecessors.get(inBlock).size() == 1) {
139 // If we reached to this block through just one evaluated predecessor
140 Block inBlockPred = bc.evaluatedPredecessors.get(inBlock).getFirst();
141 Block.Reference inBlockRef = inBlockPred.terminatingOp().successors().stream()
142 .filter(r -> r.targetBlock() == inBlock)
143 .findFirst().get();
144 List<Value> args = inBlockRef.arguments();
145 List<Boolean> argConstant = args.stream().map(constants::contains).toList();
146
147 LoopAnalyzer.Loop loop = loops.computeIfAbsent(inBlock, b -> LoopAnalyzer.isLoop(inBlock).orElse(null));
148 if (loop != null && inBlockPred.isDominatedBy(loop.header())) {
149 // Entering loop header from latch
150 assert loop.latches().contains(inBlockPred);
151
152 // Linear constant path from each exiting block (or nearest evaluated present dominator) to loop header
153 boolean constantExits = true;
154 for (LoopAnalyzer.LoopExit loopExitPair : loop.exits()) {
155 Block loopExit = loopExitPair.exit();
156
157 // Find nearest evaluated dominator
158 List<Block> ePreds = bc.evaluatedPredecessors.get(loopExit);
159 while (ePreds == null) {
160 loopExit = loopExit.immediateDominator();
161 ePreds = bc.evaluatedPredecessors.get(loopExit);
162 }
163 assert loop.body().contains(loopExit);
164
165 if (ePreds.size() != 1 ||
166 !(loopExit.terminatingOp() instanceof CoreOp.ConditionalBranchOp cbr) ||
167 !constants.contains(cbr.result())) {
168 // If there are multiple encounters, or terminal op is not a constant conditional branch
169 constantExits = false;
170 break;
171 }
172 }
173
174 // Determine if constant args, before reset
175 boolean constantArgs = constants.containsAll(args);
176
177 // Reset state within loop body
178 for (Block block : loop.body()) {
179 // Reset visits, but not for loop header
180 if (block != loop.header()) {
181 bc.evaluatedPredecessors.remove(block);
182 bc.visited.set(block.index(), false);
183 }
184
185 // Reset constants
186 for (Op op : block.ops()) {
187 constants.remove(op.result());
188 }
189 constants.removeAll(block.parameters());
190
191 // Reset no peeling for any nested loops
192 loopNoPeeling.remove(block);
193 }
194
195 if (!constantExits || !constantArgs) {
196 // Finish peeling
197 // No constant exit and no constant args
198 loopNoPeeling.addAll(loop.latches());
199 break nopeel;
200 }
201 // Peel next iteration
202 }
203
204 // Propagate constant arguments
205 for (int i = 0; i < args.size(); i++) {
206 Value inArgument = args.get(i);
207 if (argConstant.get(i)) {
208 Block.Parameter inParameter = inBlock.parameters().get(i);
209
210 // Map input parameter to output argument
211 outBlock.context().mapValue(inParameter, outBlock.context().getValue(inArgument));
212 // Set parameter constant
213 constants.add(inParameter);
214 bc.setValue(inParameter, bc.getValue(inArgument));
215 }
216 }
217 }
218
219 // Process all but the terminating operation
220 int nops = inBlock.ops().size();
221 for (int i = 0; i < nops - 1; i++) {
222 Op op = inBlock.ops().get(i);
223
224 if (isConstant(op)) {
225 // Evaluate operation
226 // @@@ Handle exceptions
227 Object result = interpretOp(l, bc, op);
228 bc.setValue(op.result(), result);
229
230 if (op instanceof CoreOp.VarOp) {
231 // @@@ Do not turn into constant to avoid conflicts with the interpreter
232 // and its runtime representation of vars
233 outBlock.op(op);
234 } else {
235 // Result was evaluated, replace with constant operation
236 Op.Result constantResult = outBlock.op(CoreOp.constant(op.resultType(), result));
237 outBlock.context().mapValue(op.result(), constantResult);
238 }
239 } else {
240 // Copy unevaluated operation
241 Op.Result r = outBlock.op(op);
242 // Explicitly remap result, since the op can be copied more than once in pealed loops
243 // @@@ See comment Block.op code which implicitly limits this
244 outBlock.context().mapValue(op.result(), r);
245 }
246 }
247
248 // Process the terminating operation
249 Op to = inBlock.terminatingOp();
250 switch (to) {
251 case CoreOp.ConditionalBranchOp cb -> {
252 if (isConstant(to)) {
253 boolean p = switch (bc.getValue(cb.predicate())) {
254 case Boolean bp -> bp;
255 case Integer ip ->
256 // @@@ This is required when lifting up from bytecode, since boolean values
257 // are erased to int values, abd the bytecode lifting implementation is not currently
258 // sophisticated enough to recover the type information
259 ip != 0;
260 default -> throw evaluationException(
261 new UnsupportedOperationException("Unsupported type input to operation: " + cb));
262 };
263
264 Block.Reference nextInBlockRef = p ? cb.trueBranch() : cb.falseBranch();
265 Block nextInBlock = nextInBlockRef.targetBlock();
266
267 // @@@ might be latch to loop
268 assert !inBlock.isDominatedBy(nextInBlock);
269
270 processBlock(bc, inBlock, nextInBlock, outBlock);
271
272 outBlock.op(CoreOp.branch(outBlock.context().getSuccessorOrCreate(nextInBlockRef)));
273 } else {
274 // @@@ might be non-constant latch to loop
275 processBlock(bc, inBlock, cb.falseBranch().targetBlock(), outBlock);
276 processBlock(bc, inBlock, cb.trueBranch().targetBlock(), outBlock);
277
278 outBlock.op(to);
279 }
280 }
281 case CoreOp.BranchOp b -> {
282 Block.Reference nextInBlockRef = b.branch();
283 Block nextInBlock = nextInBlockRef.targetBlock();
284
285 if (inBlock.isDominatedBy(nextInBlock)) {
286 // latch to loop header
287 assert bc.visited.get(nextInBlock.index());
288 if (!loopNoPeeling.contains(inBlock) && constants.containsAll(nextInBlock.parameters())) {
289 // Reset loop body to peel off another iteration
290 bc.visited.set(nextInBlock.index(), false);
291 bc.evaluatedPredecessors.remove(nextInBlock);
292 }
293 }
294
295 processBlock(bc, inBlock, nextInBlock, outBlock);
296
297 outBlock.op(b);
298 }
299 case CoreOp.ReturnOp _ -> outBlock.op(to);
300 default -> throw evaluationException(
301 new UnsupportedOperationException("Unsupported terminating operation: " + to));
302 }
303 }
304 }
305
306 boolean isConstant(Op op) {
307 if (constants.contains(op.result())) {
308 return true;
309 } else if (constants.containsAll(op.operands()) && opConstant.test(op)) {
310 constants.add(op.result());
311 return true;
312 } else {
313 return false;
314 }
315 }
316
317 void processBlock(BodyContext bc, Block inBlock, Block nextInBlock, Block.Builder outBlock) {
318 bc.blockStack.add(nextInBlock);
319 if (!bc.evaluatedPredecessors.containsKey(nextInBlock)) {
320 // Copy block
321 Block.Builder nextOutBlock = outBlock.block(nextInBlock.parameterTypes());
322 outBlock.context().mapBlock(nextInBlock, nextOutBlock);
323 outBlock.context().mapValues(nextInBlock.parameters(), nextOutBlock.parameters());
324 }
325 bc.evaluatedPredecessors.computeIfAbsent(nextInBlock, _ -> new ArrayList<>()).add(inBlock);
326 }
327
328 @SuppressWarnings("unchecked")
329 public static <E extends Throwable> void eraseAndThrow(Throwable e) throws E {
330 throw (E) e;
331 }
332
333 // @@@ This could be shared with the interpreter if it was more extensible
334 Object interpretOp(MethodHandles.Lookup l, BodyContext bc, Op o) {
335 switch (o) {
336 case CoreOp.ConstantOp co -> {
337 if (co.resultType().equals(JavaType.J_L_CLASS)) {
338 return resolveToClass(l, (JavaType) co.value());
339 } else {
340 return co.value();
341 }
342 }
343 case JavaOp.InvokeOp co -> {
344 MethodType target = resolveToMethodType(l, o.opType());
345 MethodHandles.Lookup il = switch (co.invokeKind()) {
346 case STATIC, INSTANCE -> l;
347 case SUPER -> l.in(target.parameterType(0));
348 };
349 MethodHandle mh = resolveToMethodHandle(il, co.invokeDescriptor(), co.invokeKind());
350
351 mh = mh.asType(target).asFixedArity();
352 Object[] values = o.operands().stream().map(bc::getValue).toArray();
353 return invoke(mh, values);
354 }
355 case JavaOp.NewOp no -> {
356 Object[] values = o.operands().stream().map(bc::getValue).toArray();
357 JavaType nType = (JavaType) no.resultType();
358 if (nType instanceof ArrayType at) {
359 if (values.length > at.dimensions()) {
360 throw evaluationException(new IllegalArgumentException("Bad constructor NewOp: " + no));
361 }
362 int[] lengths = Stream.of(values).mapToInt(v -> (int) v).toArray();
363 for (int length : lengths) {
364 nType = ((ArrayType) nType).componentType();
365 }
366 return Array.newInstance(resolveToClass(l, nType), lengths);
367 } else {
368 MethodHandle mh = constructorHandle(l, no.constructorDescriptor().type());
369 return invoke(mh, values);
370 }
371 }
372 case CoreOp.VarOp vo -> {
373 Object[] vbox = vo.isUninitialized()
374 ? new Object[] { null, false }
375 : new Object[] { bc.getValue(o.operands().get(0)) };
376 return vbox;
377 }
378 case CoreOp.VarAccessOp.VarLoadOp vlo -> {
379 // Cast to CoreOp.Var, since the instance may have originated as an external instance
380 // via a captured value map
381 Object[] vbox = (Object[]) bc.getValue(o.operands().get(0));
382 if (vbox.length == 2 && !((Boolean) vbox[1])) {
383 throw evaluationException(new IllegalStateException("Loading from uninitialized variable"));
384 }
385 return vbox[0];
386 }
387 case CoreOp.VarAccessOp.VarStoreOp vso -> {
388 Object[] vbox = (Object[]) bc.getValue(o.operands().get(0));
389 if (vbox.length == 2) {
390 vbox[1] = true;
391 }
392 vbox[0] = bc.getValue(o.operands().get(1));
393 return null;
394 }
395 case CoreOp.TupleOp to -> {
396 return o.operands().stream().map(bc::getValue).toList();
397 }
398 case CoreOp.TupleLoadOp tlo -> {
399 @SuppressWarnings("unchecked")
400 List<Object> tb = (List<Object>) bc.getValue(o.operands().get(0));
401 return tb.get(tlo.index());
402 }
403 case CoreOp.TupleWithOp two -> {
404 @SuppressWarnings("unchecked")
405 List<Object> tb = (List<Object>) bc.getValue(o.operands().get(0));
406 List<Object> copy = new ArrayList<>(tb);
407 copy.set(two.index(), bc.getValue(o.operands().get(1)));
408 return Collections.unmodifiableList(copy);
409 }
410 case JavaOp.FieldAccessOp.FieldLoadOp fo -> {
411 if (fo.operands().isEmpty()) {
412 VarHandle vh = fieldStaticHandle(l, fo.fieldDescriptor());
413 return vh.get();
414 } else {
415 Object v = bc.getValue(o.operands().get(0));
416 VarHandle vh = fieldHandle(l, fo.fieldDescriptor());
417 return vh.get(v);
418 }
419 }
420 case JavaOp.FieldAccessOp.FieldStoreOp fo -> {
421 if (fo.operands().size() == 1) {
422 Object v = bc.getValue(o.operands().get(0));
423 VarHandle vh = fieldStaticHandle(l, fo.fieldDescriptor());
424 vh.set(v);
425 } else {
426 Object r = bc.getValue(o.operands().get(0));
427 Object v = bc.getValue(o.operands().get(1));
428 VarHandle vh = fieldHandle(l, fo.fieldDescriptor());
429 vh.set(r, v);
430 }
431 return null;
432 }
433 case JavaOp.InstanceOfOp io -> {
434 Object v = bc.getValue(o.operands().get(0));
435 return isInstance(l, io.type(), v);
436 }
437 case JavaOp.CastOp co -> {
438 Object v = bc.getValue(o.operands().get(0));
439 return cast(l, co.type(), v);
440 }
441 case JavaOp.ArrayLengthOp arrayLengthOp -> {
442 Object a = bc.getValue(o.operands().get(0));
443 return Array.getLength(a);
444 }
445 case JavaOp.ArrayAccessOp.ArrayLoadOp arrayLoadOp -> {
446 Object a = bc.getValue(o.operands().get(0));
447 Object index = bc.getValue(o.operands().get(1));
448 return Array.get(a, (int) index);
449 }
450 case JavaOp.ArrayAccessOp.ArrayStoreOp arrayStoreOp -> {
451 Object a = bc.getValue(o.operands().get(0));
452 Object index = bc.getValue(o.operands().get(1));
453 Object v = bc.getValue(o.operands().get(2));
454 Array.set(a, (int) index, v);
455 return null;
456 }
457 case JavaOp.ArithmeticOperation arithmeticOperation -> {
458 // @@@ TODO avoid use of opName
459 MethodHandle mh = opHandle(l, o.externalizeOpName(), o.opType());
460 Object[] values = o.operands().stream().map(bc::getValue).toArray();
461 return invoke(mh, values);
462 }
463 case JavaOp.TestOperation testOperation -> {
464 // @@@ TODO avoid use of opName
465 MethodHandle mh = opHandle(l, o.externalizeOpName(), o.opType());
466 Object[] values = o.operands().stream().map(bc::getValue).toArray();
467 return invoke(mh, values);
468 }
469 case JavaOp.ConvOp convOp -> {
470 // @@@ TODO avoid use of opName
471 MethodHandle mh = opHandle(l, o.externalizeOpName() + "_" + o.opType().returnType(), o.opType());
472 Object[] values = o.operands().stream().map(bc::getValue).toArray();
473 return invoke(mh, values);
474 }
475 case JavaOp.ConcatOp concatOp -> {
476 return o.operands().stream()
477 .map(bc::getValue)
478 .map(String::valueOf)
479 .collect(Collectors.joining());
480 }
481 // @@@
482 // case CoreOp.LambdaOp lambdaOp -> {
483 // interpretEntryBlock(l, lambdaOp.body().entryBlock(), oc, new HashMap<>());
484 // unevaluatedOperations.add(o);
485 // return null;
486 // }
487 // case CoreOp.FuncOp funcOp -> {
488 // interpretEntryBlock(l, funcOp.body().entryBlock(), oc, new HashMap<>());
489 // unevaluatedOperations.add(o);
490 // return null;
491 // }
492 case null, default -> throw evaluationException(
493 new UnsupportedOperationException("Unsupported operation: " + o));
494 }
495 }
496
497
498 static MethodHandle opHandle(MethodHandles.Lookup l, String opName, FunctionType ft) {
499 MethodType mt = resolveToMethodType(l, ft).erase();
500 try {
501 return MethodHandles.lookup().findStatic(InvokableLeafOps.class, opName, mt);
502 } catch (NoSuchMethodException | IllegalAccessException e) {
503 throw evaluationException(e);
504 }
505 }
506
507 static MethodHandle constructorHandle(MethodHandles.Lookup l, FunctionType ft) {
508 MethodType mt = resolveToMethodType(l, ft);
509
510 if (mt.returnType().isArray()) {
511 if (mt.parameterCount() != 1 || mt.parameterType(0) != int.class) {
512 throw evaluationException(new IllegalArgumentException("Bad constructor descriptor: " + ft));
513 }
514 return MethodHandles.arrayConstructor(mt.returnType());
515 } else {
516 try {
517 return l.findConstructor(mt.returnType(), mt.changeReturnType(void.class));
518 } catch (NoSuchMethodException | IllegalAccessException e) {
519 throw evaluationException(e);
520 }
521 }
522 }
523
524 static VarHandle fieldStaticHandle(MethodHandles.Lookup l, FieldRef d) {
525 return resolveToVarHandle(l, d);
526 }
527
528 static VarHandle fieldHandle(MethodHandles.Lookup l, FieldRef d) {
529 return resolveToVarHandle(l, d);
530 }
531
532 static Object isInstance(MethodHandles.Lookup l, TypeElement d, Object v) {
533 Class<?> c = resolveToClass(l, d);
534 return c.isInstance(v);
535 }
536
537 static Object cast(MethodHandles.Lookup l, TypeElement d, Object v) {
538 Class<?> c = resolveToClass(l, d);
539 return c.cast(v);
540 }
541
542 static MethodHandle resolveToMethodHandle(MethodHandles.Lookup l, MethodRef d, JavaOp.InvokeOp.InvokeKind kind) {
543 try {
544 return d.resolveToHandle(l, kind);
545 } catch (ReflectiveOperationException e) {
546 throw evaluationException(e);
547 }
548 }
549
550 static VarHandle resolveToVarHandle(MethodHandles.Lookup l, FieldRef d) {
551 try {
552 return d.resolveToHandle(l);
553 } catch (ReflectiveOperationException e) {
554 throw evaluationException(e);
555 }
556 }
557
558 public static MethodType resolveToMethodType(MethodHandles.Lookup l, FunctionType ft) {
559 try {
560 return MethodRef.toNominalDescriptor(ft).resolveConstantDesc(l);
561 } catch (ReflectiveOperationException e) {
562 throw evaluationException(e);
563 }
564 }
565
566 public static Class<?> resolveToClass(MethodHandles.Lookup l, TypeElement d) {
567 try {
568 if (d instanceof JavaType jt) {
569 return (Class<?>) jt.erasure().resolve(l);
570 } else {
571 throw new ReflectiveOperationException();
572 }
573 } catch (ReflectiveOperationException e) {
574 throw evaluationException(e);
575 }
576 }
577
578 static Object invoke(MethodHandle m, Object... args) {
579 try {
580 return m.invokeWithArguments(args);
581 } catch (RuntimeException | Error e) {
582 throw e;
583 } catch (Throwable e) {
584 eraseAndThrow(e);
585 throw new InternalError("should not reach here");
586 }
587 }
588 }