/*
 * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
22 */ 23 24 import jdk.incubator.code.*; 25 import jdk.incubator.code.dialect.core.CoreOp; 26 import jdk.incubator.code.dialect.core.FunctionType; 27 import jdk.incubator.code.dialect.java.*; 28 29 import java.lang.invoke.MethodHandle; 30 import java.lang.invoke.MethodHandles; 31 import java.lang.invoke.MethodType; 32 import java.lang.invoke.VarHandle; 33 import java.lang.reflect.Array; 34 import java.util.*; 35 import java.util.function.Predicate; 36 import java.util.stream.Collectors; 37 import java.util.stream.Stream; 38 39 final class PartialEvaluator { 40 final Set<Value> constants; 41 final Predicate<Op> opConstant; 42 43 PartialEvaluator(Set<Value> constants, Predicate<Op> opConstant) { 44 this.constants = new LinkedHashSet<>(constants); 45 this.opConstant = opConstant; 46 } 47 48 public static 49 CoreOp.FuncOp evaluate(MethodHandles.Lookup l, 50 Predicate<Op> opConstant, Set<Value> constants, 51 CoreOp.FuncOp op) { 52 PartialEvaluator pe = new PartialEvaluator(constants, opConstant); 53 Body.Builder outBody = pe.evaluateBody(l, op.body()); 54 return CoreOp.func(op.funcName(), outBody); 55 } 56 57 58 @SuppressWarnings("serial") 59 public static final class EvaluationException extends RuntimeException { 60 private EvaluationException(Throwable cause) { 61 super(cause); 62 } 63 } 64 65 static EvaluationException evaluationException(Throwable cause) { 66 return new EvaluationException(cause); 67 } 68 69 static final class BodyContext { 70 final BodyContext parent; 71 72 final Map<Block, List<Block>> evaluatedPredecessors; 73 final Map<Value, Object> evaluatedValues; 74 75 final Queue<Block> blockStack; 76 final BitSet visited; 77 78 BodyContext(Block entryBlock) { 79 this.parent = null; 80 81 this.evaluatedPredecessors = new HashMap<>(); 82 this.evaluatedValues = new HashMap<>(); 83 this.blockStack = new PriorityQueue<>(Comparator.comparingInt(Block::index)); 84 85 this.visited = new BitSet(); 86 } 87 88 Object getValue(Value v) { 89 Object rv = 
evaluatedValues.get(v); 90 if (rv != null) { 91 return rv; 92 } 93 94 throw evaluationException(new IllegalArgumentException("Undefined value: " + v)); 95 } 96 97 void setValue(Value v, Object o) { 98 evaluatedValues.put(v, o); 99 } 100 } 101 102 Body.Builder evaluateBody(MethodHandles.Lookup l, 103 Body inBody) { 104 Block inEntryBlock = inBody.entryBlock(); 105 106 Body.Builder outBody = Body.Builder.of(null, inBody.bodyType()); 107 Block.Builder outEntryBlock = outBody.entryBlock(); 108 109 CopyContext cc = outEntryBlock.context(); 110 cc.mapBlock(inEntryBlock, outEntryBlock); 111 cc.mapValues(inEntryBlock.parameters(), outEntryBlock.parameters()); 112 113 evaluateEntryBlock(l, inEntryBlock, outEntryBlock, new BodyContext(inEntryBlock)); 114 115 return outBody; 116 } 117 118 void evaluateEntryBlock(MethodHandles.Lookup l, 119 Block inEntryBlock, 120 Block.Builder outEntryBlock, 121 BodyContext bc) { 122 assert inEntryBlock.isEntryBlock(); 123 124 Map<Block, LoopAnalyzer.Loop> loops = new HashMap<>(); 125 Set<Block> loopNoPeeling = new HashSet<>(); 126 127 // The first block cannot have any successors so the queue will have at least one entry 128 bc.blockStack.add(inEntryBlock); 129 while (!bc.blockStack.isEmpty()) { 130 final Block inBlock = bc.blockStack.poll(); 131 if (bc.visited.get(inBlock.index())) { 132 continue; 133 } 134 bc.visited.set(inBlock.index()); 135 136 final Block.Builder outBlock = outEntryBlock.context().getBlock(inBlock); 137 138 nopeel: if (inBlock.predecessors().size() > 1 && bc.evaluatedPredecessors.get(inBlock).size() == 1) { 139 // If we reached to this block through just one evaluated predecessor 140 Block inBlockPred = bc.evaluatedPredecessors.get(inBlock).getFirst(); 141 Block.Reference inBlockRef = inBlockPred.terminatingOp().successors().stream() 142 .filter(r -> r.targetBlock() == inBlock) 143 .findFirst().get(); 144 List<Value> args = inBlockRef.arguments(); 145 List<Boolean> argConstant = 
args.stream().map(constants::contains).toList(); 146 147 LoopAnalyzer.Loop loop = loops.computeIfAbsent(inBlock, b -> LoopAnalyzer.isLoop(inBlock).orElse(null)); 148 if (loop != null && inBlockPred.isDominatedBy(loop.header())) { 149 // Entering loop header from latch 150 assert loop.latches().contains(inBlockPred); 151 152 // Linear constant path from each exiting block (or nearest evaluated present dominator) to loop header 153 boolean constantExits = true; 154 for (LoopAnalyzer.LoopExit loopExitPair : loop.exits()) { 155 Block loopExit = loopExitPair.exit(); 156 157 // Find nearest evaluated dominator 158 List<Block> ePreds = bc.evaluatedPredecessors.get(loopExit); 159 while (ePreds == null) { 160 loopExit = loopExit.immediateDominator(); 161 ePreds = bc.evaluatedPredecessors.get(loopExit); 162 } 163 assert loop.body().contains(loopExit); 164 165 if (ePreds.size() != 1 || 166 !(loopExit.terminatingOp() instanceof CoreOp.ConditionalBranchOp cbr) || 167 !constants.contains(cbr.result())) { 168 // If there are multiple encounters, or terminal op is not a constant conditional branch 169 constantExits = false; 170 break; 171 } 172 } 173 174 // Determine if constant args, before reset 175 boolean constantArgs = constants.containsAll(args); 176 177 // Reset state within loop body 178 for (Block block : loop.body()) { 179 // Reset visits, but not for loop header 180 if (block != loop.header()) { 181 bc.evaluatedPredecessors.remove(block); 182 bc.visited.set(block.index(), false); 183 } 184 185 // Reset constants 186 for (Op op : block.ops()) { 187 constants.remove(op.result()); 188 } 189 constants.removeAll(block.parameters()); 190 191 // Reset no peeling for any nested loops 192 loopNoPeeling.remove(block); 193 } 194 195 if (!constantExits || !constantArgs) { 196 // Finish peeling 197 // No constant exit and no constant args 198 loopNoPeeling.addAll(loop.latches()); 199 break nopeel; 200 } 201 // Peel next iteration 202 } 203 204 // Propagate constant arguments 205 for 
(int i = 0; i < args.size(); i++) { 206 Value inArgument = args.get(i); 207 if (argConstant.get(i)) { 208 Block.Parameter inParameter = inBlock.parameters().get(i); 209 210 // Map input parameter to output argument 211 outBlock.context().mapValue(inParameter, outBlock.context().getValue(inArgument)); 212 // Set parameter constant 213 constants.add(inParameter); 214 bc.setValue(inParameter, bc.getValue(inArgument)); 215 } 216 } 217 } 218 219 // Process all but the terminating operation 220 int nops = inBlock.ops().size(); 221 for (int i = 0; i < nops - 1; i++) { 222 Op op = inBlock.ops().get(i); 223 224 if (isConstant(op)) { 225 // Evaluate operation 226 // @@@ Handle exceptions 227 Object result = interpretOp(l, bc, op); 228 bc.setValue(op.result(), result); 229 230 if (op instanceof CoreOp.VarOp) { 231 // @@@ Do not turn into constant to avoid conflicts with the interpreter 232 // and its runtime representation of vars 233 outBlock.op(op); 234 } else { 235 // Result was evaluated, replace with constant operation 236 Op.Result constantResult = outBlock.op(CoreOp.constant(op.resultType(), result)); 237 outBlock.context().mapValue(op.result(), constantResult); 238 } 239 } else { 240 // Copy unevaluated operation 241 Op.Result r = outBlock.op(op); 242 // Explicitly remap result, since the op can be copied more than once in pealed loops 243 // @@@ See comment Block.op code which implicitly limits this 244 outBlock.context().mapValue(op.result(), r); 245 } 246 } 247 248 // Process the terminating operation 249 Op to = inBlock.terminatingOp(); 250 switch (to) { 251 case CoreOp.ConditionalBranchOp cb -> { 252 if (isConstant(to)) { 253 boolean p = switch (bc.getValue(cb.predicate())) { 254 case Boolean bp -> bp; 255 case Integer ip -> 256 // @@@ This is required when lifting up from bytecode, since boolean values 257 // are erased to int values, abd the bytecode lifting implementation is not currently 258 // sophisticated enough to recover the type information 259 ip != 0; 
260 default -> throw evaluationException( 261 new UnsupportedOperationException("Unsupported type input to operation: " + cb)); 262 }; 263 264 Block.Reference nextInBlockRef = p ? cb.trueBranch() : cb.falseBranch(); 265 Block nextInBlock = nextInBlockRef.targetBlock(); 266 267 // @@@ might be latch to loop 268 assert !inBlock.isDominatedBy(nextInBlock); 269 270 processBlock(bc, inBlock, nextInBlock, outBlock); 271 272 outBlock.op(CoreOp.branch(outBlock.context().getSuccessorOrCreate(nextInBlockRef))); 273 } else { 274 // @@@ might be non-constant latch to loop 275 processBlock(bc, inBlock, cb.falseBranch().targetBlock(), outBlock); 276 processBlock(bc, inBlock, cb.trueBranch().targetBlock(), outBlock); 277 278 outBlock.op(to); 279 } 280 } 281 case CoreOp.BranchOp b -> { 282 Block.Reference nextInBlockRef = b.branch(); 283 Block nextInBlock = nextInBlockRef.targetBlock(); 284 285 if (inBlock.isDominatedBy(nextInBlock)) { 286 // latch to loop header 287 assert bc.visited.get(nextInBlock.index()); 288 if (!loopNoPeeling.contains(inBlock) && constants.containsAll(nextInBlock.parameters())) { 289 // Reset loop body to peel off another iteration 290 bc.visited.set(nextInBlock.index(), false); 291 bc.evaluatedPredecessors.remove(nextInBlock); 292 } 293 } 294 295 processBlock(bc, inBlock, nextInBlock, outBlock); 296 297 outBlock.op(b); 298 } 299 case CoreOp.ReturnOp _ -> outBlock.op(to); 300 default -> throw evaluationException( 301 new UnsupportedOperationException("Unsupported terminating operation: " + to.opName())); 302 } 303 } 304 } 305 306 boolean isConstant(Op op) { 307 if (constants.contains(op.result())) { 308 return true; 309 } else if (constants.containsAll(op.operands()) && opConstant.test(op)) { 310 constants.add(op.result()); 311 return true; 312 } else { 313 return false; 314 } 315 } 316 317 void processBlock(BodyContext bc, Block inBlock, Block nextInBlock, Block.Builder outBlock) { 318 bc.blockStack.add(nextInBlock); 319 if 
(!bc.evaluatedPredecessors.containsKey(nextInBlock)) { 320 // Copy block 321 Block.Builder nextOutBlock = outBlock.block(nextInBlock.parameterTypes()); 322 outBlock.context().mapBlock(nextInBlock, nextOutBlock); 323 outBlock.context().mapValues(nextInBlock.parameters(), nextOutBlock.parameters()); 324 } 325 bc.evaluatedPredecessors.computeIfAbsent(nextInBlock, _ -> new ArrayList<>()).add(inBlock); 326 } 327 328 @SuppressWarnings("unchecked") 329 public static <E extends Throwable> void eraseAndThrow(Throwable e) throws E { 330 throw (E) e; 331 } 332 333 // @@@ This could be shared with the interpreter if it was more extensible 334 Object interpretOp(MethodHandles.Lookup l, BodyContext bc, Op o) { 335 switch (o) { 336 case CoreOp.ConstantOp co -> { 337 if (co.resultType().equals(JavaType.J_L_CLASS)) { 338 return resolveToClass(l, (JavaType) co.value()); 339 } else { 340 return co.value(); 341 } 342 } 343 case JavaOp.InvokeOp co -> { 344 MethodType target = resolveToMethodType(l, o.opType()); 345 MethodHandles.Lookup il = switch (co.invokeKind()) { 346 case STATIC, INSTANCE -> l; 347 case SUPER -> l.in(target.parameterType(0)); 348 }; 349 MethodHandle mh = resolveToMethodHandle(il, co.invokeDescriptor(), co.invokeKind()); 350 351 mh = mh.asType(target).asFixedArity(); 352 Object[] values = o.operands().stream().map(bc::getValue).toArray(); 353 return invoke(mh, values); 354 } 355 case JavaOp.NewOp no -> { 356 Object[] values = o.operands().stream().map(bc::getValue).toArray(); 357 JavaType nType = (JavaType) no.resultType(); 358 if (nType instanceof ArrayType at) { 359 if (values.length > at.dimensions()) { 360 throw evaluationException(new IllegalArgumentException("Bad constructor NewOp: " + no)); 361 } 362 int[] lengths = Stream.of(values).mapToInt(v -> (int) v).toArray(); 363 for (int length : lengths) { 364 nType = ((ArrayType) nType).componentType(); 365 } 366 return Array.newInstance(resolveToClass(l, nType), lengths); 367 } else { 368 MethodHandle mh = 
constructorHandle(l, no.constructorDescriptor().type()); 369 return invoke(mh, values); 370 } 371 } 372 case CoreOp.VarOp vo -> { 373 Object[] vbox = vo.isUninitialized() 374 ? new Object[] { null, false } 375 : new Object[] { bc.getValue(o.operands().get(0)) }; 376 return vbox; 377 } 378 case CoreOp.VarAccessOp.VarLoadOp vlo -> { 379 // Cast to CoreOp.Var, since the instance may have originated as an external instance 380 // via a captured value map 381 Object[] vbox = (Object[]) bc.getValue(o.operands().get(0)); 382 if (vbox.length == 2 && !((Boolean) vbox[1])) { 383 throw evaluationException(new IllegalStateException("Loading from uninitialized variable")); 384 } 385 return vbox[0]; 386 } 387 case CoreOp.VarAccessOp.VarStoreOp vso -> { 388 Object[] vbox = (Object[]) bc.getValue(o.operands().get(0)); 389 if (vbox.length == 2) { 390 vbox[1] = true; 391 } 392 vbox[0] = bc.getValue(o.operands().get(1)); 393 return null; 394 } 395 case CoreOp.TupleOp to -> { 396 return o.operands().stream().map(bc::getValue).toList(); 397 } 398 case CoreOp.TupleLoadOp tlo -> { 399 @SuppressWarnings("unchecked") 400 List<Object> tb = (List<Object>) bc.getValue(o.operands().get(0)); 401 return tb.get(tlo.index()); 402 } 403 case CoreOp.TupleWithOp two -> { 404 @SuppressWarnings("unchecked") 405 List<Object> tb = (List<Object>) bc.getValue(o.operands().get(0)); 406 List<Object> copy = new ArrayList<>(tb); 407 copy.set(two.index(), bc.getValue(o.operands().get(1))); 408 return Collections.unmodifiableList(copy); 409 } 410 case JavaOp.FieldAccessOp.FieldLoadOp fo -> { 411 if (fo.operands().isEmpty()) { 412 VarHandle vh = fieldStaticHandle(l, fo.fieldDescriptor()); 413 return vh.get(); 414 } else { 415 Object v = bc.getValue(o.operands().get(0)); 416 VarHandle vh = fieldHandle(l, fo.fieldDescriptor()); 417 return vh.get(v); 418 } 419 } 420 case JavaOp.FieldAccessOp.FieldStoreOp fo -> { 421 if (fo.operands().size() == 1) { 422 Object v = bc.getValue(o.operands().get(0)); 423 VarHandle vh = 
fieldStaticHandle(l, fo.fieldDescriptor()); 424 vh.set(v); 425 } else { 426 Object r = bc.getValue(o.operands().get(0)); 427 Object v = bc.getValue(o.operands().get(1)); 428 VarHandle vh = fieldHandle(l, fo.fieldDescriptor()); 429 vh.set(r, v); 430 } 431 return null; 432 } 433 case JavaOp.InstanceOfOp io -> { 434 Object v = bc.getValue(o.operands().get(0)); 435 return isInstance(l, io.type(), v); 436 } 437 case JavaOp.CastOp co -> { 438 Object v = bc.getValue(o.operands().get(0)); 439 return cast(l, co.type(), v); 440 } 441 case JavaOp.ArrayLengthOp arrayLengthOp -> { 442 Object a = bc.getValue(o.operands().get(0)); 443 return Array.getLength(a); 444 } 445 case JavaOp.ArrayAccessOp.ArrayLoadOp arrayLoadOp -> { 446 Object a = bc.getValue(o.operands().get(0)); 447 Object index = bc.getValue(o.operands().get(1)); 448 return Array.get(a, (int) index); 449 } 450 case JavaOp.ArrayAccessOp.ArrayStoreOp arrayStoreOp -> { 451 Object a = bc.getValue(o.operands().get(0)); 452 Object index = bc.getValue(o.operands().get(1)); 453 Object v = bc.getValue(o.operands().get(2)); 454 Array.set(a, (int) index, v); 455 return null; 456 } 457 case JavaOp.ArithmeticOperation arithmeticOperation -> { 458 MethodHandle mh = opHandle(l, o.opName(), o.opType()); 459 Object[] values = o.operands().stream().map(bc::getValue).toArray(); 460 return invoke(mh, values); 461 } 462 case JavaOp.TestOperation testOperation -> { 463 MethodHandle mh = opHandle(l, o.opName(), o.opType()); 464 Object[] values = o.operands().stream().map(bc::getValue).toArray(); 465 return invoke(mh, values); 466 } 467 case JavaOp.ConvOp convOp -> { 468 MethodHandle mh = opHandle(l, o.opName() + "_" + o.opType().returnType(), o.opType()); 469 Object[] values = o.operands().stream().map(bc::getValue).toArray(); 470 return invoke(mh, values); 471 } 472 case JavaOp.ConcatOp concatOp -> { 473 return o.operands().stream() 474 .map(bc::getValue) 475 .map(String::valueOf) 476 .collect(Collectors.joining()); 477 } 478 // @@@ 479 // 
case CoreOp.LambdaOp lambdaOp -> { 480 // interpretEntryBlock(l, lambdaOp.body().entryBlock(), oc, new HashMap<>()); 481 // unevaluatedOperations.add(o); 482 // return null; 483 // } 484 // case CoreOp.FuncOp funcOp -> { 485 // interpretEntryBlock(l, funcOp.body().entryBlock(), oc, new HashMap<>()); 486 // unevaluatedOperations.add(o); 487 // return null; 488 // } 489 case null, default -> throw evaluationException( 490 new UnsupportedOperationException("Unsupported operation: " + o.opName())); 491 } 492 } 493 494 495 static MethodHandle opHandle(MethodHandles.Lookup l, String opName, FunctionType ft) { 496 MethodType mt = resolveToMethodType(l, ft).erase(); 497 try { 498 return MethodHandles.lookup().findStatic(InvokableLeafOps.class, opName, mt); 499 } catch (NoSuchMethodException | IllegalAccessException e) { 500 throw evaluationException(e); 501 } 502 } 503 504 static MethodHandle constructorHandle(MethodHandles.Lookup l, FunctionType ft) { 505 MethodType mt = resolveToMethodType(l, ft); 506 507 if (mt.returnType().isArray()) { 508 if (mt.parameterCount() != 1 || mt.parameterType(0) != int.class) { 509 throw evaluationException(new IllegalArgumentException("Bad constructor descriptor: " + ft)); 510 } 511 return MethodHandles.arrayConstructor(mt.returnType()); 512 } else { 513 try { 514 return l.findConstructor(mt.returnType(), mt.changeReturnType(void.class)); 515 } catch (NoSuchMethodException | IllegalAccessException e) { 516 throw evaluationException(e); 517 } 518 } 519 } 520 521 static VarHandle fieldStaticHandle(MethodHandles.Lookup l, FieldRef d) { 522 return resolveToVarHandle(l, d); 523 } 524 525 static VarHandle fieldHandle(MethodHandles.Lookup l, FieldRef d) { 526 return resolveToVarHandle(l, d); 527 } 528 529 static Object isInstance(MethodHandles.Lookup l, TypeElement d, Object v) { 530 Class<?> c = resolveToClass(l, d); 531 return c.isInstance(v); 532 } 533 534 static Object cast(MethodHandles.Lookup l, TypeElement d, Object v) { 535 Class<?> c = 
resolveToClass(l, d); 536 return c.cast(v); 537 } 538 539 static MethodHandle resolveToMethodHandle(MethodHandles.Lookup l, MethodRef d, JavaOp.InvokeOp.InvokeKind kind) { 540 try { 541 return d.resolveToHandle(l, kind); 542 } catch (ReflectiveOperationException e) { 543 throw evaluationException(e); 544 } 545 } 546 547 static VarHandle resolveToVarHandle(MethodHandles.Lookup l, FieldRef d) { 548 try { 549 return d.resolveToHandle(l); 550 } catch (ReflectiveOperationException e) { 551 throw evaluationException(e); 552 } 553 } 554 555 public static MethodType resolveToMethodType(MethodHandles.Lookup l, FunctionType ft) { 556 try { 557 return MethodRef.toNominalDescriptor(ft).resolveConstantDesc(l); 558 } catch (ReflectiveOperationException e) { 559 throw evaluationException(e); 560 } 561 } 562 563 public static Class<?> resolveToClass(MethodHandles.Lookup l, TypeElement d) { 564 try { 565 if (d instanceof JavaType jt) { 566 return (Class<?>) jt.erasure().resolve(l); 567 } else { 568 throw new ReflectiveOperationException(); 569 } 570 } catch (ReflectiveOperationException e) { 571 throw evaluationException(e); 572 } 573 } 574 575 static Object invoke(MethodHandle m, Object... args) { 576 try { 577 return m.invokeWithArguments(args); 578 } catch (RuntimeException | Error e) { 579 throw e; 580 } catch (Throwable e) { 581 eraseAndThrow(e); 582 throw new InternalError("should not reach here"); 583 } 584 } 585 }