1 import jdk.incubator.code.*;
  2 import jdk.incubator.code.bytecode.impl.LoweringTransform;
  3 import jdk.incubator.code.dialect.core.CoreOp;
  4 import jdk.incubator.code.dialect.core.CoreType;
  5 import jdk.incubator.code.dialect.java.MethodRef;
  6 import jdk.incubator.code.interpreter.Interpreter;
  7 import org.junit.jupiter.api.Assertions;
  8 import org.junit.jupiter.api.Test;
  9 import org.junit.jupiter.params.ParameterizedTest;
 10 import org.junit.jupiter.params.provider.Arguments;
 11 import org.junit.jupiter.params.provider.MethodSource;
 12 
 13 import java.lang.invoke.MethodHandles;
 14 import java.lang.reflect.Method;
 15 import java.util.*;
 16 import java.util.stream.Stream;
 17 
 18 import static jdk.incubator.code.dialect.java.JavaOp.*;
 19 
 20 /*
 21  * @test
 22  * @modules jdk.incubator.code/jdk.incubator.code.bytecode.impl
 23  * @enablePreview
 24  * @compile TestIsCaseConstantSwitch.java
 25  * @run junit TestIsCaseConstantSwitch
 26  */
 27 public class TestIsCaseConstantSwitch {
 28 
 29     class C {
 30         static final int x = 26;
 31     }
 32 
 33     @Reflect
 34     private static void caseConstantSwitchExpressions() {
 35         // switch label
 36         // case label
 37         // list of case constant
 38         // every case constant must be either a constant expression or the name of an enum constant
 39         // null literal
 40         // list of case patterns
 41         // default label
 42         final int fv = 25;
 43         int i = -1;
 44         String r = switch (i) {
 45             // literal of primitive type
 46             case 1 -> "A";
 47             // unary operators +, -, ~
 48             case +2 -> "B";
 49             case -2 -> "BB";
 50             case ~2 -> "BBB"; // -3
 51             // multiplicative operators *, /, %
 52             case 3 * 4 -> "E";
 53             case 3 / 4 -> "EE";
 54             case 3 % 4 -> "EEE";
 55             // shift operators <<, >>, >>>
 56             case 4 << 5 -> "F"; // 128
 57             case 10 >> 1 -> "FF"; // 5
 58             case 8 >>> 1 -> "FFF"; // 4
 59             // relational operators <, <=, >, >= (and conditional operator)
 60             case 1 < 2 ? 9 : 10 -> "G"; // 9
 61             case 1 <= 2 ? 11 : 12 -> "GG"; // 11
 62             case 1 > 2 ? 13 : 14 -> "GGG"; // 14
 63             case 1 >= 2 ? 15 : 16 -> "GGGG"; // 16
 64             // equality operators ==, !=
 65             case 1 == 2 ? 17 : 18 -> "H"; // 18
 66             case 1 != 2 ? 19 : 20 -> "HH"; // 19
 67             // bitwise and logical operators &, ^, |
 68             case 6 & 6 -> "I"; // 6
 69             case 7 ^ 8 -> "II"; // 15
 70             case 8 | 10 -> "III"; // 10
 71             // conditional-and operator &&
 72             case 2 > 3 && 5 > 6 ? 21 : 22 -> "J"; // 22
 73             case 2 > 3 || 5 > 6 ? 23 : 24 -> "JJ"; // 24
 74             // parenthesized expressions whose contained expression is a constant expression
 75             case (20) -> "K";
 76             // simple names that refer to constant variables
 77             case fv -> "L";
 78             // qualified names of the form TypeName.Identifier that refer to constant variables
 79             case C.x -> "M";
 80             // list of case constants
 81             case 21, 30 -> null;
 82             // casts
 83             case (int) 31L -> "N";
 84             case (int) 34f -> "NN";
 85             // default
 86             default -> "X";
 87         };
 88 
 89         // we can have a target of type Byte, Short, Character, Integer
 90         // as long as we don't introduce case null, javac will generate labels identical to what we have in source code
 91         Integer ii = -2;
 92         r = switch (ii) {
 93             case 1 -> "A";
 94             default -> "X";
 95         };
 96 
 97         char c = '2';
 98         r = switch (c) {
 99             case '1' -> "1";
100             default -> "";
101         };
102     }
103 
104     enum E {
105         V;
106     }
107 
108     @Reflect
109     static void nonCaseConstantSwitchExpressions() {
110         int r;
111 
112         String s = "";
113         r = switch (s) {
114             case "A" -> 1;
115             default -> 0;
116         };
117 
118         E e = E.V;
119         r = switch (e) {
120             case V -> 1;
121         };
122 
123         boolean b = false;
124         r = switch (b) {
125             case true -> 1;
126             default -> 0;
127         };
128 
129         long l = 5L;
130         r = switch (l) {
131             case 1L -> 1;
132             default -> 0;
133         };
134 
135         float f = 5f;
136         r = switch (f) {
137             case 1f -> 1;
138             default -> 0;
139         };
140 
141         double d = 5d;
142         r = switch (d) {
143             case 1d -> 1;
144             default -> 0;
145         };
146 
147         Integer i = 4;
148         r = switch (i) {
149             case 1 -> 1;
150             case null -> -1;
151             default -> 0;
152         };
153     }
154 
155     static Stream<Arguments> cases() {
156         return Stream.of(
157                 Arguments.of("caseConstantSwitchExpressions", true),
158                 Arguments.of("nonCaseConstantSwitchExpressions", false)
159         );
160     }
161 
162     @ParameterizedTest
163     @MethodSource("cases")
164     void testIsConstantLabelSwitch(String methodName, boolean expected) throws NoSuchMethodException {
165         Method m = this.getClass().getDeclaredMethod(methodName);
166         CoreOp.FuncOp codeModel = Op.ofMethod(m).get();
167         List<SwitchExpressionOp> swExprOps = codeModel.body().entryBlock().ops().stream()
168                 .filter(o -> o instanceof SwitchExpressionOp)
169                 .map(o -> ((SwitchExpressionOp) o)).toList();
170         for (SwitchExpressionOp swExprOp : swExprOps) {
171             Assertions.assertEquals(
172                     new LoweringTransform.ConstantLabelSwitchChecker(swExprOp, MethodHandles.lookup()).isCaseConstantSwitch(),
173                     expected,
174                     swExprOp.toText());
175         }
176     }
177 
178     @Test
179     void testGettingLabels() throws NoSuchMethodException {
180         var expectedLabels = List.of(1, +2, -2, ~2, 12, 3 / 4, 3 % 4, 4 << 5, 10 >> 1,
181                 8 >>> 1, 1 < 2 ? 9 : 10, 1 <= 2 ? 11 : 12, 1 > 2 ? 13 : 14, 1 >= 2 ? 15 : 16, 1 == 2 ? 17 : 18,
182                 1 != 2 ? 19 : 20, 6 & 6, 7 ^ 8, 8 | 10, 2 > 3 && 5 > 6 ? 21 : 22, 2 > 3 || 5 > 6 ? 23 : 24, (20), 25,
183                 C.x, 21, 30, (int) 31L, (int) 34f );
184         var funcOp = Op.ofMethod(this.getClass().getDeclaredMethod("caseConstantSwitchExpressions")).get();
185         System.out.println(funcOp.toText());
186         var swOp = (JavaSwitchOp) funcOp.body().entryBlock().ops().stream().filter(op -> op instanceof JavaSwitchOp).findFirst().get();
187         List<Integer> actualLabels = getLabelsAndTargets(MethodHandles.lookup(), swOp);
188         System.out.println(actualLabels);
189         Assertions.assertEquals(expectedLabels, actualLabels);
190     }
191 
192     static ArrayList<Integer> getLabelsAndTargets(MethodHandles.Lookup lookup, JavaSwitchOp swOp) {
193         var labels = new ArrayList<Integer>();
194         for (int i = 0; i < swOp.bodies().size(); i += 2) {
195             Body labelBody = swOp.bodies().get(i);
196             labels.addAll(getLabels(lookup, labelBody));
197         }
198         return labels;
199     }
200 
201     static List<Integer> getLabels(MethodHandles.Lookup lookup, Body body) {
202         if (body.blocks().size() != 1 || !(body.entryBlock().terminatingOp() instanceof CoreOp.YieldOp yop) ||
203                 !(yop.yieldValue() instanceof Result opr)) {
204             throw new IllegalStateException("Body of a java switch fails the expected structure");
205         }
206         var labels = new ArrayList<Integer>();
207         if (opr.op() instanceof EqOp eqOp) {
208             labels.add(extractConstantLabel(lookup, body, eqOp));
209         } else if (opr.op() instanceof InvokeOp invokeOp &&
210                 invokeOp.invokeDescriptor().equals(MethodRef.method(Objects.class, "equals", boolean.class, Object.class, Object.class))) {
211             labels.add(extractConstantLabel(lookup, body, invokeOp));
212         } else if (opr.op() instanceof ConditionalOrOp cor) {
213             for (Body corbody : cor.bodies()) {
214                 labels.addAll(getLabels(lookup, corbody));
215             }
216         } else if (!(opr.op() instanceof CoreOp.ConstantOp)){ // not default label
217             throw new IllegalStateException();
218         }
219         return labels;
220     }
221 
222     static Integer extractConstantLabel(MethodHandles.Lookup lookup, Body body, Op whenToStop) {
223         Op lastOp = body.entryBlock().ops().get(body.entryBlock().ops().indexOf(whenToStop) - 1);
224         CoreOp.FuncOp funcOp = CoreOp.func("f", CoreType.functionType(lastOp.result().type())).body(block -> {
225             // in case we refer to constant variables in the label
226             for (Value capturedValue : body.capturedValues()) {
227                 if (!(capturedValue instanceof Result r) || !(r.op() instanceof CoreOp.VarOp vop)) {
228                     continue;
229                 }
230                 block.op(((Result) vop.initOperand()).op());
231                 block.op(vop);
232             }
233             Result last = null;
234             for (Op op : body.entryBlock().ops()) {
235                 if (op.equals(whenToStop)) {
236                     break;
237                 }
238                 last = block.op(op);
239             }
240             block.op(CoreOp.return_(last));
241         });
242         Object res = Interpreter.invoke(lookup, funcOp.transform(CodeTransformer.LOWERING_TRANSFORMER));
243         return switch (res) {
244             case Byte b -> Integer.valueOf(b);
245             case Short s -> Integer.valueOf(s);
246             case Character c -> Integer.valueOf(c);
247             case Integer i -> i;
248             default -> throw new IllegalStateException(); // @@@ not going to happen
249         };
250     }
251 }