1 /*
  2  * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 import jdk.incubator.code.*;
 25 import jdk.incubator.code.dialect.core.CoreOp;
 26 import jdk.incubator.code.dialect.core.FunctionType;
 27 import jdk.incubator.code.dialect.java.*;
 28 
 29 import java.lang.invoke.MethodHandle;
 30 import java.lang.invoke.MethodHandles;
 31 import java.lang.invoke.MethodType;
 32 import java.lang.invoke.VarHandle;
 33 import java.lang.reflect.Array;
 34 import java.util.*;
 35 import java.util.function.Predicate;
 36 import java.util.stream.Collectors;
 37 import java.util.stream.Stream;
 38 
 39 final class PartialEvaluator {
 40     final Set<Value> constants;
 41     final Predicate<Op> opConstant;
 42 
 43     PartialEvaluator(Set<Value> constants, Predicate<Op> opConstant) {
 44         this.constants = new LinkedHashSet<>(constants);
 45         this.opConstant = opConstant;
 46     }
 47 
 48     public static
 49     CoreOp.FuncOp evaluate(MethodHandles.Lookup l,
 50                            Predicate<Op> opConstant, Set<Value> constants,
 51                            CoreOp.FuncOp op) {
 52         PartialEvaluator pe = new PartialEvaluator(constants, opConstant);
 53         Body.Builder outBody = pe.evaluateBody(l, op.body());
 54         return CoreOp.func(op.funcName(), outBody);
 55     }
 56 
 57 
 58     @SuppressWarnings("serial")
 59     public static final class EvaluationException extends RuntimeException {
 60         private EvaluationException(Throwable cause) {
 61             super(cause);
 62         }
 63     }
 64 
 65     static EvaluationException evaluationException(Throwable cause) {
 66         return new EvaluationException(cause);
 67     }
 68 
 69     static final class BodyContext {
 70         final BodyContext parent;
 71 
 72         final Map<Block, List<Block>> evaluatedPredecessors;
 73         final Map<Value, Object> evaluatedValues;
 74 
 75         final Queue<Block> blockStack;
 76         final BitSet visited;
 77 
 78         BodyContext(Block entryBlock) {
 79             this.parent = null;
 80 
 81             this.evaluatedPredecessors = new HashMap<>();
 82             this.evaluatedValues = new HashMap<>();
 83             this.blockStack = new PriorityQueue<>(Comparator.comparingInt(Block::index));
 84 
 85             this.visited = new BitSet();
 86         }
 87 
 88         Object getValue(Value v) {
 89             Object rv = evaluatedValues.get(v);
 90             if (rv != null) {
 91                 return rv;
 92             }
 93 
 94             throw evaluationException(new IllegalArgumentException("Undefined value: " + v));
 95         }
 96 
 97         void setValue(Value v, Object o) {
 98             evaluatedValues.put(v, o);
 99         }
100     }
101 
102     Body.Builder evaluateBody(MethodHandles.Lookup l,
103                               Body inBody) {
104         Block inEntryBlock = inBody.entryBlock();
105 
106         Body.Builder outBody = Body.Builder.of(null, inBody.bodyType());
107         Block.Builder outEntryBlock = outBody.entryBlock();
108 
109         CopyContext cc = outEntryBlock.context();
110         cc.mapBlock(inEntryBlock, outEntryBlock);
111         cc.mapValues(inEntryBlock.parameters(), outEntryBlock.parameters());
112 
113         evaluateEntryBlock(l, inEntryBlock, outEntryBlock, new BodyContext(inEntryBlock));
114 
115         return outBody;
116     }
117 
        /**
         * Drives partial evaluation of a body, starting from its entry block.
         *
         * <p>Blocks are taken from {@code bc.blockStack} (a priority queue ordered by block index)
         * until it is empty, skipping blocks already marked in {@code bc.visited}. For every block the
         * non-terminating operations are either evaluated and replaced by constant operations (when
         * {@link #isConstant(Op)} holds) or copied unchanged. The terminating operation is then
         * processed: a constant conditional branch is replaced by an unconditional branch to the
         * selected successor, while other branches and returns are copied.
         *
         * <p>When a block with several predecessors has been reached through exactly one evaluated
         * predecessor, constant branch arguments are propagated to the block's parameters. When such a
         * block is a loop header entered from a latch, the state recorded for the loop body is reset so
         * another iteration can be peeled; peeling stops once an exit condition or the latch arguments
         * are no longer constant.
         *
         * @param l lookup used to resolve classes, methods and fields while interpreting operations
         * @param inEntryBlock entry block of the body being evaluated
         * @param outEntryBlock builder for the corresponding entry block of the output body
         * @param bc evaluation state for this body
         * @throws EvaluationException if an operation or terminating operation is not supported
         */
118     void evaluateEntryBlock(MethodHandles.Lookup l,
119                             Block inEntryBlock,
120                             Block.Builder outEntryBlock,
121                             BodyContext bc) {
122         assert inEntryBlock.isEntryBlock();
123 
            // Loop analysis results, keyed by the block that was analyzed
124         Map<Block, LoopAnalyzer.Loop> loops = new HashMap<>();
            // Latch blocks of loops for which peeling has been stopped
125         Set<Block> loopNoPeeling = new HashSet<>();
126 
127         // Seed the queue with the entry block so that the loop below runs at least once
128         bc.blockStack.add(inEntryBlock);
129         while (!bc.blockStack.isEmpty()) {
130             final Block inBlock = bc.blockStack.poll();
131             if (bc.visited.get(inBlock.index())) {
132                 continue;
133             }
134             bc.visited.set(inBlock.index());
135 
136             final Block.Builder outBlock = outEntryBlock.context().getBlock(inBlock);
137 
                // "break nopeel" leaves this statement without propagating constant arguments
138             nopeel: if (inBlock.predecessors().size() > 1 && bc.evaluatedPredecessors.get(inBlock).size() == 1) {
139                 // This block has several predecessors but was reached through exactly one evaluated predecessor
140                 Block inBlockPred = bc.evaluatedPredecessors.get(inBlock).getFirst();
141                 Block.Reference inBlockRef = inBlockPred.terminatingOp().successors().stream()
142                         .filter(r -> r.targetBlock() == inBlock)
143                         .findFirst().get();
144                 List<Value> args = inBlockRef.arguments();
145                 List<Boolean> argConstant = args.stream().map(constants::contains).toList();
146 
147                 LoopAnalyzer.Loop loop = loops.computeIfAbsent(inBlock, b -> LoopAnalyzer.isLoop(inBlock).orElse(null));
148                 if (loop != null && inBlockPred.isDominatedBy(loop.header())) {
149                     // Entering loop header from latch
150                     assert loop.latches().contains(inBlockPred);
151 
152                     // Linear constant path from each exiting block (or nearest evaluated present dominator) to loop header
153                     boolean constantExits = true;
154                     for (LoopAnalyzer.LoopExit loopExitPair : loop.exits()) {
155                         Block loopExit = loopExitPair.exit();
156 
157                         // Find nearest evaluated dominator
158                         List<Block> ePreds = bc.evaluatedPredecessors.get(loopExit);
159                         while (ePreds == null) {
160                             loopExit = loopExit.immediateDominator();
161                             ePreds = bc.evaluatedPredecessors.get(loopExit);
162                         }
163                         assert loop.body().contains(loopExit);
164 
165                         if (ePreds.size() != 1 ||
166                                 !(loopExit.terminatingOp() instanceof CoreOp.ConditionalBranchOp cbr) ||
167                                 !constants.contains(cbr.result())) {
168                             // If there are multiple encounters, or terminal op is not a constant conditional branch
169                             constantExits = false;
170                             break;
171                         }
172                     }
173 
174                     // Determine if constant args, before reset
175                     boolean constantArgs = constants.containsAll(args);
176 
177                     // Reset state within loop body
178                     for (Block block : loop.body()) {
179                         // Reset visits, but not for loop header
180                         if (block != loop.header()) {
181                             bc.evaluatedPredecessors.remove(block);
182                             bc.visited.set(block.index(), false);
183                         }
184 
185                         // Reset constants
186                         for (Op op : block.ops()) {
187                             constants.remove(op.result());
188                         }
189                         constants.removeAll(block.parameters());
190 
191                         // Reset no peeling for any nested loops
192                         loopNoPeeling.remove(block);
193                     }
194 
195                     if (!constantExits || !constantArgs) {
196                         // Finish peeling
197                         // Either an exit is not constant or the latch arguments are not all constant
198                         loopNoPeeling.addAll(loop.latches());
199                         break nopeel;
200                     }
201                     // Peel next iteration
202                 }
203 
204                 // Propagate constant arguments
205                 for (int i = 0; i < args.size(); i++) {
206                     Value inArgument = args.get(i);
207                     if (argConstant.get(i)) {
208                         Block.Parameter inParameter = inBlock.parameters().get(i);
209 
210                         // Map input parameter to output argument
211                         outBlock.context().mapValue(inParameter, outBlock.context().getValue(inArgument));
212                         // Set parameter constant
213                         constants.add(inParameter);
214                         bc.setValue(inParameter, bc.getValue(inArgument));
215                     }
216                 }
217             }
218 
219             // Process all but the terminating operation
220             int nops = inBlock.ops().size();
221             for (int i = 0; i < nops - 1; i++) {
222                 Op op = inBlock.ops().get(i);
223 
224                 if (isConstant(op)) {
225                     // Evaluate operation
226                     // @@@ Handle exceptions
227                     Object result = interpretOp(l, bc, op);
228                     bc.setValue(op.result(), result);
229 
230                     if (op instanceof CoreOp.VarOp) {
231                         // @@@ Do not turn into constant to avoid conflicts with the interpreter
232                         // and its runtime representation of vars
233                         outBlock.op(op);
234                     } else {
235                         // Result was evaluated, replace with constant operation
236                         Op.Result constantResult = outBlock.op(CoreOp.constant(op.resultType(), result));
237                         outBlock.context().mapValue(op.result(), constantResult);
238                     }
239                 } else {
240                     // Copy unevaluated operation
241                     Op.Result r = outBlock.op(op);
242                     // Explicitly remap result, since the op can be copied more than once in peeled loops
243                     // @@@ See comment Block.op code which implicitly limits this
244                     outBlock.context().mapValue(op.result(), r);
245                 }
246             }
247 
248             // Process the terminating operation
249             Op to = inBlock.terminatingOp();
250             switch (to) {
251                 case CoreOp.ConditionalBranchOp cb -> {
252                     if (isConstant(to)) {
253                         boolean p = switch (bc.getValue(cb.predicate())) {
254                             case Boolean bp -> bp;
255                             case Integer ip ->
256                                 // @@@ This is required when lifting up from bytecode, since boolean values
257                                 // are erased to int values, and the bytecode lifting implementation is not currently
258                                 // sophisticated enough to recover the type information
259                                     ip != 0;
260                             default -> throw evaluationException(
261                                     new UnsupportedOperationException("Unsupported type input to operation: " + cb));
262                         };
263 
264                         Block.Reference nextInBlockRef = p ? cb.trueBranch() : cb.falseBranch();
265                         Block nextInBlock = nextInBlockRef.targetBlock();
266 
267                         // @@@ might be latch to loop
268                         assert !inBlock.isDominatedBy(nextInBlock);
269 
270                         processBlock(bc, inBlock, nextInBlock, outBlock);
271 
                            // Only the selected successor remains, so emit an unconditional branch to it
272                         outBlock.op(CoreOp.branch(outBlock.context().getSuccessorOrCreate(nextInBlockRef)));
273                     } else {
274                         // @@@ might be non-constant latch to loop
275                         processBlock(bc, inBlock, cb.falseBranch().targetBlock(), outBlock);
276                         processBlock(bc, inBlock, cb.trueBranch().targetBlock(), outBlock);
277 
278                         outBlock.op(to);
279                     }
280                 }
281                 case CoreOp.BranchOp b -> {
282                     Block.Reference nextInBlockRef = b.branch();
283                     Block nextInBlock = nextInBlockRef.targetBlock();
284 
285                     if (inBlock.isDominatedBy(nextInBlock)) {
286                         // latch to loop header
287                         assert bc.visited.get(nextInBlock.index());
288                         if (!loopNoPeeling.contains(inBlock) && constants.containsAll(nextInBlock.parameters())) {
289                             // Reset loop body to peel off another iteration
290                             bc.visited.set(nextInBlock.index(), false);
291                             bc.evaluatedPredecessors.remove(nextInBlock);
292                         }
293                     }
294 
295                     processBlock(bc, inBlock, nextInBlock, outBlock);
296 
297                     outBlock.op(b);
298                 }
299                 case CoreOp.ReturnOp _ -> outBlock.op(to);
300                 default -> throw evaluationException(
301                         new UnsupportedOperationException("Unsupported terminating operation: " + to.opName()));
302             }
303         }
304     }
305 
306     boolean isConstant(Op op) {
307         if (constants.contains(op.result())) {
308             return true;
309         } else if (constants.containsAll(op.operands()) && opConstant.test(op)) {
310             constants.add(op.result());
311             return true;
312         } else {
313             return false;
314         }
315     }
316 
317     void processBlock(BodyContext bc, Block inBlock, Block nextInBlock, Block.Builder outBlock) {
318         bc.blockStack.add(nextInBlock);
319         if (!bc.evaluatedPredecessors.containsKey(nextInBlock)) {
320             // Copy block
321             Block.Builder nextOutBlock = outBlock.block(nextInBlock.parameterTypes());
322             outBlock.context().mapBlock(nextInBlock, nextOutBlock);
323             outBlock.context().mapValues(nextInBlock.parameters(), nextOutBlock.parameters());
324         }
325         bc.evaluatedPredecessors.computeIfAbsent(nextInBlock, _ -> new ArrayList<>()).add(inBlock);
326     }
327 
328     @SuppressWarnings("unchecked")
329     public static <E extends Throwable> void eraseAndThrow(Throwable e) throws E {
330         throw (E) e;
331     }
332 
333     // @@@ This could be shared with the interpreter if it was more extensible
334     Object interpretOp(MethodHandles.Lookup l, BodyContext bc, Op o) {
335         switch (o) {
336             case CoreOp.ConstantOp co -> {
337                 if (co.resultType().equals(JavaType.J_L_CLASS)) {
338                     return resolveToClass(l, (JavaType) co.value());
339                 } else {
340                     return co.value();
341                 }
342             }
343             case JavaOp.InvokeOp co -> {
344                 MethodType target = resolveToMethodType(l, o.opType());
345                 MethodHandles.Lookup il = switch (co.invokeKind()) {
346                     case STATIC, INSTANCE -> l;
347                     case SUPER -> l.in(target.parameterType(0));
348                 };
349                 MethodHandle mh = resolveToMethodHandle(il, co.invokeDescriptor(), co.invokeKind());
350 
351                 mh = mh.asType(target).asFixedArity();
352                 Object[] values = o.operands().stream().map(bc::getValue).toArray();
353                 return invoke(mh, values);
354             }
355             case JavaOp.NewOp no -> {
356                 Object[] values = o.operands().stream().map(bc::getValue).toArray();
357                 JavaType nType = (JavaType) no.resultType();
358                 if (nType instanceof ArrayType at) {
359                     if (values.length > at.dimensions()) {
360                         throw evaluationException(new IllegalArgumentException("Bad constructor NewOp: " + no));
361                     }
362                     int[] lengths = Stream.of(values).mapToInt(v -> (int) v).toArray();
363                     for (int length : lengths) {
364                         nType = ((ArrayType) nType).componentType();
365                     }
366                     return Array.newInstance(resolveToClass(l, nType), lengths);
367                 } else {
368                     MethodHandle mh = constructorHandle(l, no.constructorDescriptor().type());
369                     return invoke(mh, values);
370                 }
371             }
372             case CoreOp.VarOp vo -> {
373                 Object[] vbox = vo.isUninitialized()
374                         ? new Object[] { null, false }
375                         : new Object[] { bc.getValue(o.operands().get(0)) };
376                 return vbox;
377             }
378             case CoreOp.VarAccessOp.VarLoadOp vlo -> {
379                 // Cast to CoreOp.Var, since the instance may have originated as an external instance
380                 // via a captured value map
381                 Object[] vbox = (Object[]) bc.getValue(o.operands().get(0));
382                 if (vbox.length == 2 && !((Boolean) vbox[1])) {
383                     throw evaluationException(new IllegalStateException("Loading from uninitialized variable"));
384                 }
385                 return vbox[0];
386             }
387             case CoreOp.VarAccessOp.VarStoreOp vso -> {
388                 Object[] vbox = (Object[]) bc.getValue(o.operands().get(0));
389                 if (vbox.length == 2) {
390                     vbox[1] = true;
391                 }
392                 vbox[0] = bc.getValue(o.operands().get(1));
393                 return null;
394             }
395             case CoreOp.TupleOp to -> {
396                 return o.operands().stream().map(bc::getValue).toList();
397             }
398             case CoreOp.TupleLoadOp tlo -> {
399                 @SuppressWarnings("unchecked")
400                 List<Object> tb = (List<Object>) bc.getValue(o.operands().get(0));
401                 return tb.get(tlo.index());
402             }
403             case CoreOp.TupleWithOp two -> {
404                 @SuppressWarnings("unchecked")
405                 List<Object> tb = (List<Object>) bc.getValue(o.operands().get(0));
406                 List<Object> copy = new ArrayList<>(tb);
407                 copy.set(two.index(), bc.getValue(o.operands().get(1)));
408                 return Collections.unmodifiableList(copy);
409             }
410             case JavaOp.FieldAccessOp.FieldLoadOp fo -> {
411                 if (fo.operands().isEmpty()) {
412                     VarHandle vh = fieldStaticHandle(l, fo.fieldDescriptor());
413                     return vh.get();
414                 } else {
415                     Object v = bc.getValue(o.operands().get(0));
416                     VarHandle vh = fieldHandle(l, fo.fieldDescriptor());
417                     return vh.get(v);
418                 }
419             }
420             case JavaOp.FieldAccessOp.FieldStoreOp fo -> {
421                 if (fo.operands().size() == 1) {
422                     Object v = bc.getValue(o.operands().get(0));
423                     VarHandle vh = fieldStaticHandle(l, fo.fieldDescriptor());
424                     vh.set(v);
425                 } else {
426                     Object r = bc.getValue(o.operands().get(0));
427                     Object v = bc.getValue(o.operands().get(1));
428                     VarHandle vh = fieldHandle(l, fo.fieldDescriptor());
429                     vh.set(r, v);
430                 }
431                 return null;
432             }
433             case JavaOp.InstanceOfOp io -> {
434                 Object v = bc.getValue(o.operands().get(0));
435                 return isInstance(l, io.type(), v);
436             }
437             case JavaOp.CastOp co -> {
438                 Object v = bc.getValue(o.operands().get(0));
439                 return cast(l, co.type(), v);
440             }
441             case JavaOp.ArrayLengthOp arrayLengthOp -> {
442                 Object a = bc.getValue(o.operands().get(0));
443                 return Array.getLength(a);
444             }
445             case JavaOp.ArrayAccessOp.ArrayLoadOp arrayLoadOp -> {
446                 Object a = bc.getValue(o.operands().get(0));
447                 Object index = bc.getValue(o.operands().get(1));
448                 return Array.get(a, (int) index);
449             }
450             case JavaOp.ArrayAccessOp.ArrayStoreOp arrayStoreOp -> {
451                 Object a = bc.getValue(o.operands().get(0));
452                 Object index = bc.getValue(o.operands().get(1));
453                 Object v = bc.getValue(o.operands().get(2));
454                 Array.set(a, (int) index, v);
455                 return null;
456             }
457             case JavaOp.ArithmeticOperation arithmeticOperation -> {
458                 MethodHandle mh = opHandle(l, o.opName(), o.opType());
459                 Object[] values = o.operands().stream().map(bc::getValue).toArray();
460                 return invoke(mh, values);
461             }
462             case JavaOp.TestOperation testOperation -> {
463                 MethodHandle mh = opHandle(l, o.opName(), o.opType());
464                 Object[] values = o.operands().stream().map(bc::getValue).toArray();
465                 return invoke(mh, values);
466             }
467             case JavaOp.ConvOp convOp -> {
468                 MethodHandle mh = opHandle(l, o.opName() + "_" + o.opType().returnType(), o.opType());
469                 Object[] values = o.operands().stream().map(bc::getValue).toArray();
470                 return invoke(mh, values);
471             }
472             case JavaOp.ConcatOp concatOp -> {
473                 return o.operands().stream()
474                         .map(bc::getValue)
475                         .map(String::valueOf)
476                         .collect(Collectors.joining());
477             }
478             // @@@
479 //            case CoreOp.LambdaOp lambdaOp -> {
480 //                interpretEntryBlock(l, lambdaOp.body().entryBlock(), oc, new HashMap<>());
481 //                unevaluatedOperations.add(o);
482 //                return null;
483 //            }
484 //            case CoreOp.FuncOp funcOp -> {
485 //                interpretEntryBlock(l, funcOp.body().entryBlock(), oc, new HashMap<>());
486 //                unevaluatedOperations.add(o);
487 //                return null;
488 //            }
489             case null, default -> throw evaluationException(
490                     new UnsupportedOperationException("Unsupported operation: " + o.opName()));
491         }
492     }
493 
494 
495     static MethodHandle opHandle(MethodHandles.Lookup l, String opName, FunctionType ft) {
496         MethodType mt = resolveToMethodType(l, ft).erase();
497         try {
498             return MethodHandles.lookup().findStatic(InvokableLeafOps.class, opName, mt);
499         } catch (NoSuchMethodException | IllegalAccessException e) {
500             throw evaluationException(e);
501         }
502     }
503 
504     static MethodHandle constructorHandle(MethodHandles.Lookup l, FunctionType ft) {
505         MethodType mt = resolveToMethodType(l, ft);
506 
507         if (mt.returnType().isArray()) {
508             if (mt.parameterCount() != 1 || mt.parameterType(0) != int.class) {
509                 throw evaluationException(new IllegalArgumentException("Bad constructor descriptor: " + ft));
510             }
511             return MethodHandles.arrayConstructor(mt.returnType());
512         } else {
513             try {
514                 return l.findConstructor(mt.returnType(), mt.changeReturnType(void.class));
515             } catch (NoSuchMethodException | IllegalAccessException e) {
516                 throw evaluationException(e);
517             }
518         }
519     }
520 
521     static VarHandle fieldStaticHandle(MethodHandles.Lookup l, FieldRef d) {
522         return resolveToVarHandle(l, d);
523     }
524 
525     static VarHandle fieldHandle(MethodHandles.Lookup l, FieldRef d) {
526         return resolveToVarHandle(l, d);
527     }
528 
529     static Object isInstance(MethodHandles.Lookup l, TypeElement d, Object v) {
530         Class<?> c = resolveToClass(l, d);
531         return c.isInstance(v);
532     }
533 
534     static Object cast(MethodHandles.Lookup l, TypeElement d, Object v) {
535         Class<?> c = resolveToClass(l, d);
536         return c.cast(v);
537     }
538 
539     static MethodHandle resolveToMethodHandle(MethodHandles.Lookup l, MethodRef d, JavaOp.InvokeOp.InvokeKind kind) {
540         try {
541             return d.resolveToHandle(l, kind);
542         } catch (ReflectiveOperationException e) {
543             throw evaluationException(e);
544         }
545     }
546 
547     static VarHandle resolveToVarHandle(MethodHandles.Lookup l, FieldRef d) {
548         try {
549             return d.resolveToHandle(l);
550         } catch (ReflectiveOperationException e) {
551             throw evaluationException(e);
552         }
553     }
554 
555     public static MethodType resolveToMethodType(MethodHandles.Lookup l, FunctionType ft) {
556         try {
557             return MethodRef.toNominalDescriptor(ft).resolveConstantDesc(l);
558         } catch (ReflectiveOperationException e) {
559             throw evaluationException(e);
560         }
561     }
562 
563     public static Class<?> resolveToClass(MethodHandles.Lookup l, TypeElement d) {
564         try {
565             if (d instanceof JavaType jt) {
566                 return (Class<?>) jt.erasure().resolve(l);
567             } else {
568                 throw new ReflectiveOperationException();
569             }
570         } catch (ReflectiveOperationException e) {
571             throw evaluationException(e);
572         }
573     }
574 
575     static Object invoke(MethodHandle m, Object... args) {
576         try {
577             return m.invokeWithArguments(args);
578         } catch (RuntimeException | Error e) {
579             throw e;
580         } catch (Throwable e) {
581             eraseAndThrow(e);
582             throw new InternalError("should not reach here");
583         }
584     }
585 }