1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 
 26 package jdk.incubator.code.interpreter;
 27 
 28 import java.lang.invoke.MethodHandles;
 29 import jdk.incubator.code.*;
 30 import jdk.incubator.code.op.CoreOp;
 31 import jdk.incubator.code.type.JavaType;
 32 import jdk.incubator.code.writer.OpWriter;
 33 import java.util.ArrayList;
 34 import java.util.Collections;
 35 import java.util.HashMap;
 36 import java.util.List;
 37 import java.util.Map;
 38 
 39 public final class Verifier {
 40 
 41     public final class VerifyError {
 42 
 43         private final String message;
 44 
 45         public VerifyError(String message) {
 46             this.message = message;
 47         }
 48 
 49         public String getMessage() {
 50             return message;
 51         }
 52 
 53         public String getPrintedContext() {
 54             return toText(rootOp);
 55         }
 56 
 57         @Override
 58         public String toString() {
 59             return getMessage() + " in " + getPrintedContext();
 60         }
 61     }
 62 
 63     public static List<Verifier.VerifyError> verify(Op op) {
 64         return verify(MethodHandles.publicLookup(), op);
 65     }
 66 
 67     public static List<Verifier.VerifyError> verify(MethodHandles.Lookup l, Op op) {
 68         var verifier = new Verifier(l, op);
 69         verifier.verifyOps();
 70         verifier.verifyExceptionRegions();
 71         return verifier.errors == null ? List.of() : Collections.unmodifiableList(verifier.errors);
 72     }
 73 
 74 
 75     private final MethodHandles.Lookup lookup;
 76     private final Op rootOp;
 77     private OpWriter.CodeItemNamerOption namerOption;
 78     private List<Verifier.VerifyError> errors;
 79 
 80     private Verifier(MethodHandles.Lookup lookup, Op rootOp) {
 81         this.lookup = lookup;
 82         this.rootOp = rootOp;
 83     }
 84 
 85     private OpWriter.CodeItemNamerOption getNamer() {
 86         if (namerOption == null) {
 87             namerOption = OpWriter.CodeItemNamerOption.of(OpWriter.computeGlobalNames(rootOp));
 88         }
 89         return namerOption;
 90     }
 91 
 92     private String toText(Op op) {
 93         return OpWriter.toText(op, getNamer());
 94     }
 95 
 96     private String getName(CodeItem codeItem) {
 97         return getNamer().namer().apply(codeItem);
 98     }
 99 
100     private void error(String message, Object... args) {
101         if (errors == null) {
102             errors = new ArrayList<>();
103         }
104         for (int i = 0; i < args.length; i++) {
105             args[i] = toText(args[i]);
106         }
107         errors.add(new VerifyError(message.formatted(args)));
108     }
109 
110     private String toText(Object arg) {
111         return switch (arg) {
112             case Op op -> toText(op);
113             case Block b -> getName(b);
114             case Value v -> getName(v);
115             case List<?> l -> l.stream().map(this::toText).toList().toString();
116             default -> arg.toString();
117         };
118     }
119 
120     private void verifyOps() {
121         rootOp.traverse(null, CodeElement.opVisitor((_, op) -> {
122             // Verify operands declaration dominannce
123             for (var v : op.operands()) {
124                 if (!op.result().isDominatedBy(v)) {
125                     error("%s %s operand %s is not dominated by its declaration in %s", op.parentBlock(), op, v, v.declaringBlock());
126                 }
127             }
128 
129             // Verify individual Ops
130             switch (op) {
131                 case CoreOp.BranchOp br ->
132                     verifyBlockReferences(op, br.successors());
133                 case CoreOp.ConditionalBranchOp cbr ->
134                     verifyBlockReferences(op, cbr.successors());
135                 case CoreOp.ArithmeticOperation _, CoreOp.TestOperation _ ->
136                     verifyOpHandleExists(op, op.opName());
137                 case CoreOp.ConvOp _ -> {
138                     verifyOpHandleExists(op, op.opName() + "_" + op.opType().returnType());
139                 }
140                 default -> {}
141 
142             }
143             return null;
144         }));
145     }
146 
147     private void verifyBlockReferences(Op op, List<Block.Reference> references) {
148         for (Block.Reference r : references) {
149             Block b = r.targetBlock();
150             List<Value> args = r.arguments();
151             List<Block.Parameter> params = r.targetBlock().parameters();
152             if (args.size() != params.size()) {
153                 error("%s %s block reference arguments size to target block parameters size mismatch", b, op);
154             } else {
155                 Block tb = r.targetBlock();
156                 for (int i = 0; i < args.size(); i++) {
157                     if (!isAssignable(params.get(i).type(), args.get(i), tb, b)) {
158                         error("%s %s %s is not assignable from %s", op.parentBlock(), op, params.get(i).type(), args.get(i).type());
159                     }
160                 }
161             }
162         }
163     }
164 
165     private boolean isAssignable(TypeElement toType, Value fromValue,  Object toContext, Object fromContext) {
166         if (toType.equals(fromValue.type())) return true;
167         var to = resolveToClass(toType, toContext);
168         var from = resolveToClass(fromValue.type(), fromContext);
169         if (from.isPrimitive()) {
170             // Primitive types assignability
171             return to == int.class && (from == byte.class || from == short.class || from == char.class);
172         } else {
173             // Objects assignability
174             return to.isAssignableFrom(from);
175         }
176     }
177 
178     public Class<?> resolveToClass(TypeElement d, Object context) {
179         try {
180             if (d instanceof JavaType jt) {
181                 return (Class<?>)jt.erasure().resolve(lookup);
182             } else {
183                 error("%s %s is not a Java type", context, d);
184             }
185         } catch (ReflectiveOperationException e) {
186             error("%s %s", context, e.getMessage());
187         }
188         return Object.class;
189     }
190 
191     private void verifyOpHandleExists(Op op, String opName) {
192         try {
193             var mt = Interpreter.resolveToMethodType(lookup, op.opType()).erase();
194             MethodHandles.lookup().findStatic(InvokableLeafOps.class, opName, mt);
195         } catch (NoSuchMethodException nsme) {
196             error("%s %s of type %s is not supported", op.parentBlock(), op, op.opType());
197         } catch (IllegalAccessException iae) {
198             error("%s %s %s",  op.parentBlock(), op, iae.getMessage());
199         }
200     }
201 
202     private void verifyExceptionRegions() {
203         rootOp.traverse(new HashMap<Block, List<Block>>(), CodeElement.blockVisitor((map, b) -> {
204             List<Block> catchBlocks = map.computeIfAbsent(b, _ -> List.of());
205             switch (b.terminatingOp()) {
206                 case CoreOp.BranchOp br ->
207                     verifyCatchStack(b, br, br.branch(), catchBlocks, map);
208                 case CoreOp.ConditionalBranchOp cbr -> {
209                     verifyCatchStack(b, cbr, cbr.trueBranch(), catchBlocks, map);
210                     verifyCatchStack(b, cbr, cbr.falseBranch(), catchBlocks, map);
211                 }
212                 case CoreOp.ExceptionRegionEnter ere -> {
213                     List<Block> newCatchBlocks = new ArrayList<>();
214                     newCatchBlocks.addAll(catchBlocks);
215                     for (Block.Reference cb : ere.catchBlocks()) {
216                         newCatchBlocks.add(cb.targetBlock());
217                         verifyCatchStack(b, ere, cb, catchBlocks, map);
218                     }
219                     verifyCatchStack(b, ere, ere.start(), newCatchBlocks, map);
220                 }
221                 case CoreOp.ExceptionRegionExit ere -> {
222                     List<Block> exitedCatchBlocks = ere.catchBlocks().stream().map(Block.Reference::targetBlock).toList();
223                     if (exitedCatchBlocks.size() > catchBlocks.size() || !catchBlocks.reversed().subList(0, exitedCatchBlocks.size()).equals(exitedCatchBlocks)) {
224                         error("%s %s exited catch blocks %s does not match actual stack %s", b, ere, exitedCatchBlocks, catchBlocks);
225                     } else {
226                         verifyCatchStack(b, ere, ere.end(), catchBlocks.subList(0, catchBlocks.size() - exitedCatchBlocks.size()), map);
227                     }
228                 }
229                 default -> {}
230             }
231             return map;
232         }));
233     }
234 
235     private void verifyCatchStack(Block b, Op op, Block.Reference target, List<Block> catchBlocks, Map<Block, List<Block>> blockMap) {
236         blockMap.compute(target.targetBlock(), (tb, stored) -> {
237             if (stored != null && !stored.equals(catchBlocks)) {
238                 error("%s %s catch stack mismatch at target %s %s vs %s", b, op, tb, stored, catchBlocks);
239             }
240             return catchBlocks;
241         });
242     }
243 }