1 import jdk.incubator.code.*;
2 import jdk.incubator.code.bytecode.impl.LoweringTransform;
3 import jdk.incubator.code.dialect.core.CoreOp;
4 import jdk.incubator.code.dialect.core.CoreType;
5 import jdk.incubator.code.dialect.java.MethodRef;
6 import jdk.incubator.code.interpreter.Interpreter;
7 import org.junit.jupiter.api.Assertions;
8 import org.junit.jupiter.api.Test;
9 import org.junit.jupiter.params.ParameterizedTest;
10 import org.junit.jupiter.params.provider.Arguments;
11 import org.junit.jupiter.params.provider.MethodSource;
12
13 import java.lang.invoke.MethodHandles;
14 import java.lang.reflect.Method;
15 import java.util.*;
16 import java.util.stream.Stream;
17
18 import static jdk.incubator.code.dialect.java.JavaOp.*;
19
20 /*
21 * @test
22 * @modules jdk.incubator.code/jdk.incubator.code.bytecode.impl
23 * @enablePreview
24 * @compile TestIsCaseConstantSwitch.java
25 * @run junit TestIsCaseConstantSwitch
26 */
27 public class TestIsCaseConstantSwitch {
28
29 class C {
30 static final int x = 26;
31 }
32
33 @Reflect
34 private static void caseConstantSwitchExpressions() {
35 // switch label
36 // case label
37 // list of case constant
38 // every case constant must be either a constant expression or the name of an enum constant
39 // null literal
40 // list of case patterns
41 // default label
42 final int fv = 25;
43 int i = -1;
44 String r = switch (i) {
45 // literal of primitive type
46 case 1 -> "A";
47 // unary operators +, -, ~
48 case +2 -> "B";
49 case -2 -> "BB";
50 case ~2 -> "BBB"; // -3
51 // multiplicative operators *, /, %
52 case 3 * 4 -> "E";
53 case 3 / 4 -> "EE";
54 case 3 % 4 -> "EEE";
55 // shift operators <<, >>, >>>
56 case 4 << 5 -> "F"; // 128
57 case 10 >> 1 -> "FF"; // 5
58 case 8 >>> 1 -> "FFF"; // 4
59 // relational operators <, <=, >, >= (and conditional operator)
60 case 1 < 2 ? 9 : 10 -> "G"; // 9
61 case 1 <= 2 ? 11 : 12 -> "GG"; // 11
62 case 1 > 2 ? 13 : 14 -> "GGG"; // 14
63 case 1 >= 2 ? 15 : 16 -> "GGGG"; // 16
64 // equality operators ==, !=
65 case 1 == 2 ? 17 : 18 -> "H"; // 18
66 case 1 != 2 ? 19 : 20 -> "HH"; // 19
67 // bitwise and logical operators &, ^, |
68 case 6 & 6 -> "I"; // 6
69 case 7 ^ 8 -> "II"; // 15
70 case 8 | 10 -> "III"; // 10
71 // conditional-and operator &&
72 case 2 > 3 && 5 > 6 ? 21 : 22 -> "J"; // 22
73 case 2 > 3 || 5 > 6 ? 23 : 24 -> "JJ"; // 24
74 // parenthesized expressions whose contained expression is a constant expression
75 case (20) -> "K";
76 // simple names that refer to constant variables
77 case fv -> "L";
78 // qualified names of the form TypeName.Identifier that refer to constant variables
79 case C.x -> "M";
80 // list of case constants
81 case 21, 30 -> null;
82 // casts
83 case (int) 31L -> "N";
84 case (int) 34f -> "NN";
85 // default
86 default -> "X";
87 };
88
89 // we can have a target of type Byte, Short, Character, Integer
90 // as long as we don't introduce case null, javac will generate labels identical to what we have in source code
91 Integer ii = -2;
92 r = switch (ii) {
93 case 1 -> "A";
94 default -> "X";
95 };
96
97 char c = '2';
98 r = switch (c) {
99 case '1' -> "1";
100 default -> "";
101 };
102 }
103
104 enum E {
105 V;
106 }
107
108 @Reflect
109 static void nonCaseConstantSwitchExpressions() {
110 int r;
111
112 String s = "";
113 r = switch (s) {
114 case "A" -> 1;
115 default -> 0;
116 };
117
118 E e = E.V;
119 r = switch (e) {
120 case V -> 1;
121 };
122
123 boolean b = false;
124 r = switch (b) {
125 case true -> 1;
126 default -> 0;
127 };
128
129 long l = 5L;
130 r = switch (l) {
131 case 1L -> 1;
132 default -> 0;
133 };
134
135 float f = 5f;
136 r = switch (f) {
137 case 1f -> 1;
138 default -> 0;
139 };
140
141 double d = 5d;
142 r = switch (d) {
143 case 1d -> 1;
144 default -> 0;
145 };
146
147 Integer i = 4;
148 r = switch (i) {
149 case 1 -> 1;
150 case null -> -1;
151 default -> 0;
152 };
153 }
154
155 static Stream<Arguments> cases() {
156 return Stream.of(
157 Arguments.of("caseConstantSwitchExpressions", true),
158 Arguments.of("nonCaseConstantSwitchExpressions", false)
159 );
160 }
161
162 @ParameterizedTest
163 @MethodSource("cases")
164 void testIsConstantLabelSwitch(String methodName, boolean expected) throws NoSuchMethodException {
165 Method m = this.getClass().getDeclaredMethod(methodName);
166 CoreOp.FuncOp codeModel = Op.ofMethod(m).get();
167 List<SwitchExpressionOp> swExprOps = codeModel.body().entryBlock().ops().stream()
168 .filter(o -> o instanceof SwitchExpressionOp)
169 .map(o -> ((SwitchExpressionOp) o)).toList();
170 for (SwitchExpressionOp swExprOp : swExprOps) {
171 Assertions.assertEquals(
172 new LoweringTransform.ConstantLabelSwitchChecker(swExprOp, MethodHandles.lookup()).isCaseConstantSwitch(),
173 expected,
174 swExprOp.toText());
175 }
176 }
177
178 @Test
179 void testGettingLabels() throws NoSuchMethodException {
180 var expectedLabels = List.of(1, +2, -2, ~2, 12, 3 / 4, 3 % 4, 4 << 5, 10 >> 1,
181 8 >>> 1, 1 < 2 ? 9 : 10, 1 <= 2 ? 11 : 12, 1 > 2 ? 13 : 14, 1 >= 2 ? 15 : 16, 1 == 2 ? 17 : 18,
182 1 != 2 ? 19 : 20, 6 & 6, 7 ^ 8, 8 | 10, 2 > 3 && 5 > 6 ? 21 : 22, 2 > 3 || 5 > 6 ? 23 : 24, (20), 25,
183 C.x, 21, 30, (int) 31L, (int) 34f );
184 var funcOp = Op.ofMethod(this.getClass().getDeclaredMethod("caseConstantSwitchExpressions")).get();
185 System.out.println(funcOp.toText());
186 var swOp = (JavaSwitchOp) funcOp.body().entryBlock().ops().stream().filter(op -> op instanceof JavaSwitchOp).findFirst().get();
187 List<Integer> actualLabels = getLabelsAndTargets(MethodHandles.lookup(), swOp);
188 System.out.println(actualLabels);
189 Assertions.assertEquals(expectedLabels, actualLabels);
190 }
191
192 static ArrayList<Integer> getLabelsAndTargets(MethodHandles.Lookup lookup, JavaSwitchOp swOp) {
193 var labels = new ArrayList<Integer>();
194 for (int i = 0; i < swOp.bodies().size(); i += 2) {
195 Body labelBody = swOp.bodies().get(i);
196 labels.addAll(getLabels(lookup, labelBody));
197 }
198 return labels;
199 }
200
201 static List<Integer> getLabels(MethodHandles.Lookup lookup, Body body) {
202 if (body.blocks().size() != 1 || !(body.entryBlock().terminatingOp() instanceof CoreOp.YieldOp yop) ||
203 !(yop.yieldValue() instanceof Result opr)) {
204 throw new IllegalStateException("Body of a java switch fails the expected structure");
205 }
206 var labels = new ArrayList<Integer>();
207 if (opr.op() instanceof EqOp eqOp) {
208 labels.add(extractConstantLabel(lookup, body, eqOp));
209 } else if (opr.op() instanceof InvokeOp invokeOp &&
210 invokeOp.invokeDescriptor().equals(MethodRef.method(Objects.class, "equals", boolean.class, Object.class, Object.class))) {
211 labels.add(extractConstantLabel(lookup, body, invokeOp));
212 } else if (opr.op() instanceof ConditionalOrOp cor) {
213 for (Body corbody : cor.bodies()) {
214 labels.addAll(getLabels(lookup, corbody));
215 }
216 } else if (!(opr.op() instanceof CoreOp.ConstantOp)){ // not default label
217 throw new IllegalStateException();
218 }
219 return labels;
220 }
221
222 static Integer extractConstantLabel(MethodHandles.Lookup lookup, Body body, Op whenToStop) {
223 Op lastOp = body.entryBlock().ops().get(body.entryBlock().ops().indexOf(whenToStop) - 1);
224 CoreOp.FuncOp funcOp = CoreOp.func("f", CoreType.functionType(lastOp.result().type())).body(block -> {
225 // in case we refer to constant variables in the label
226 for (Value capturedValue : body.capturedValues()) {
227 if (!(capturedValue instanceof Result r) || !(r.op() instanceof CoreOp.VarOp vop)) {
228 continue;
229 }
230 block.op(((Result) vop.initOperand()).op());
231 block.op(vop);
232 }
233 Result last = null;
234 for (Op op : body.entryBlock().ops()) {
235 if (op.equals(whenToStop)) {
236 break;
237 }
238 last = block.op(op);
239 }
240 block.op(CoreOp.return_(last));
241 });
242 Object res = Interpreter.invoke(lookup, funcOp.transform(CodeTransformer.LOWERING_TRANSFORMER));
243 return switch (res) {
244 case Byte b -> Integer.valueOf(b);
245 case Short s -> Integer.valueOf(s);
246 case Character c -> Integer.valueOf(c);
247 case Integer i -> i;
248 default -> throw new IllegalStateException(); // @@@ not going to happen
249 };
250 }
251 }