1 import jdk.incubator.code.Reflect;
  2 import jdk.incubator.code.CodeTransformer;
  3 import jdk.incubator.code.dialect.core.CoreOp;
  4 import jdk.incubator.code.extern.OpWriter;
  5 import jdk.incubator.code.interpreter.Interpreter;
  6 import org.junit.jupiter.api.Assertions;
  7 import org.junit.jupiter.api.Test;
  8 
  9 import java.io.IOException;
 10 import java.io.OutputStream;
 11 import java.io.StringWriter;
 12 import java.lang.invoke.MethodHandles;
 13 import java.lang.reflect.Method;
 14 import java.util.*;
 15 import java.util.stream.Stream;
 16 
 17 /*
 18  * @test
 19  * @modules jdk.incubator.code
 20  * @run junit TestSwitchExpressionOp
 21  * @run main Unreflect TestSwitchExpressionOp
 22  * @run junit TestSwitchExpressionOp
 23  */
 24 public class TestSwitchExpressionOp {
 25 
 26     @Test
 27     void testCasePatternGuard() {
 28         CoreOp.FuncOp lmodel = lower("casePatternGuard");
 29         Object[] args = {"c++", "java", new R(8), new R(2L), new R(3f), new R(4.0)};
 30         for (Object arg : args) {
 31             Assertions.assertEquals(casePatternGuard(arg), Interpreter.invoke(MethodHandles.lookup(), lmodel, arg));
 32         }
 33     }
 34     @Reflect
 35     static String casePatternGuard(Object obj) {
 36         return switch (obj) {
 37             case String s when s.length() > 3 -> "str with length > %d".formatted(s.length());
 38             case R(Number n) when n.getClass().equals(Double.class) -> "R(Double)";
 39             default -> "else";
 40         };
 41     }
 42 
 43     @Test
 44     void testCaseRecordPattern() {
 45         // @@@ new R(null) must match the pattern R(Number c), but it doesn't
 46         // @@@ test with generic record
 47         CoreOp.FuncOp lmodel = lower("caseRecordPattern");
 48         Object[] args = {new R(8), new R(1.0), new R(2L), "abc"};
 49         for (Object arg : args) {
 50             Assertions.assertEquals(caseRecordPattern(arg), Interpreter.invoke(MethodHandles.lookup(), lmodel, arg));
 51         }
 52     }
 53     record R(Number n) {}
 54     @Reflect
 55     static String caseRecordPattern(Object o) {
 56         return switch (o) {
 57             case R(Number _) -> "R(_)";
 58             default -> "else";
 59         };
 60     }
 61     @Test
 62     void testCaseTypePattern() {
 63         CoreOp.FuncOp lmodel = lower("caseTypePattern");
 64         Object[] args = {"str", new ArrayList<>(), new int[]{}, new Stack[][]{}, new Collection[][][]{}, 8, 'x'};
 65         for (Object arg : args) {
 66             Assertions.assertEquals(caseTypePattern(arg), Interpreter.invoke(MethodHandles.lookup(), lmodel, arg));
 67         }
 68     }
 69     @Reflect
 70     static String caseTypePattern(Object o) {
 71         return switch (o) {
 72             case String _ -> "String"; // class
 73             case RandomAccess _ -> "RandomAccess"; // interface
 74             case int[] _ -> "int[]"; // array primitive
 75             case Stack[][] _ -> "Stack[][]"; // array class
 76             case Collection[][][] _ -> "Collection[][][]"; // array interface
 77             case final Number n -> "Number"; // final modifier
 78             default -> "something else";
 79         };
 80     }
 81 
 82     @Test
 83     void testCasePatternWithCaseConstant() {
 84         CoreOp.FuncOp lmodel = lower("casePatternWithCaseConstant");
 85         int[] args = {42, 43, -44, 0};
 86         for (int arg : args) {
 87             Assertions.assertEquals(casePatternWithCaseConstant(arg), Interpreter.invoke(MethodHandles.lookup(), lmodel, arg));
 88         }
 89     }
 90 
 91     @Reflect
 92     static String casePatternWithCaseConstant(Integer a) {
 93         return switch (a) {
 94             case 42 -> "forty two";
 95             // @@@ case int will not match, because of the way InstanceOfOp is interpreted
 96             case Integer i when i > 0 -> "positive int";
 97             case Integer i when i < 0 -> "negative int";
 98             default -> "zero";
 99         };
100     }
101 
102     // @Test
103     void testCasePatternMultiLabel() {
104         CoreOp.FuncOp lmodel = lower("casePatternMultiLabel");
105         Object[] args = {(byte) 1, (short) 2, 'A', 3, 4L, 5f, 6d, true, "str"};
106         for (Object arg : args) {
107             Assertions.assertEquals(casePatternMultiLabel(arg), Interpreter.invoke(MethodHandles.lookup(), lmodel, arg));
108         }
109     }
110     // @Reflect
111     // code model for such as code is not supported
112     // @@@ support this case and uncomment its test
113     private static String casePatternMultiLabel(Object o) {
114         return switch (o) {
115             case Integer _, Long _, Character _, Byte _, Short _-> "integral type";
116             default -> "non integral type";
117         };
118     }
119 
120     @Test
121     void testCasePatternThrow() {
122         CoreOp.FuncOp lmodel = lower("casePatternThrow");
123 
124         Object[] args = {Byte.MAX_VALUE, Short.MIN_VALUE, 0, 1L, 11f, 22d};
125         for (Object arg : args) {
126             Assertions.assertThrows(IllegalArgumentException.class, () -> Interpreter.invoke(MethodHandles.lookup(), lmodel, arg));
127         }
128 
129         Object[] args2 = {"abc", List.of()};
130         for (Object arg : args2) {
131             Assertions.assertEquals(casePatternThrow(arg), Interpreter.invoke(MethodHandles.lookup(), lmodel, arg));
132         }
133     }
134 
135     @Reflect
136     private static String casePatternThrow(Object o) {
137         return switch (o) {
138             case Number n -> throw new IllegalArgumentException();
139             case String s -> "a string";
140             default -> o.getClass().getName();
141         };
142     }
143 
144     @Test
145     void testCasePatternBehaviorIsSyntaxIndependent() {
146         CoreOp.FuncOp ruleExpression = lower("casePatternRuleExpression");
147         CoreOp.FuncOp ruleBlock = lower("casePatternRuleBlock");
148         CoreOp.FuncOp statement = lower("casePatternStatement");
149 
150         Object[] args = {1, "2", 3L};
151 
152         for (Object arg : args) {
153             Assertions.assertEquals(Interpreter.invoke(MethodHandles.lookup(), ruleBlock, arg), Interpreter.invoke(MethodHandles.lookup(), ruleExpression, arg));
154             Assertions.assertEquals(Interpreter.invoke(MethodHandles.lookup(), statement, arg), Interpreter.invoke(MethodHandles.lookup(), ruleExpression, arg));
155         }
156     }
157 
158     @Reflect
159     private static String casePatternRuleExpression(Object o) {
160         return switch (o) {
161             case Integer i -> "integer";
162             case String s -> "string";
163             default -> "not integer nor string";
164         };
165     }
166 
167     @Reflect
168     private static String casePatternRuleBlock(Object o) {
169         return switch (o) {
170             case Integer i -> {
171                 yield "integer";
172             }
173             case String s -> {
174                 yield "string";
175             }
176             default -> {
177                 yield "not integer nor string";
178             }
179         };
180     }
181 
182     @Reflect
183     private static String casePatternStatement(Object o) {
184         return switch (o) {
185             case Integer i: yield "integer";
186             case String s: yield "string";
187             default: yield "not integer nor string";
188         };
189     }
190 
191     @Test
192     void testCaseConstantOtherKindsOfExpr() {
193         CoreOp.FuncOp lmodel = lower("caseConstantOtherKindsOfExpr");
194         for (int i = 0; i < 14; i++) {
195             Assertions.assertEquals(caseConstantOtherKindsOfExpr(i), Interpreter.invoke(MethodHandles.lookup(), lmodel, i));
196         }
197     }
198 
199     static class Constants {
200         static final int c1 = 12;
201     }
202 
203     @Reflect
204     private static String caseConstantOtherKindsOfExpr(int i) {
205         final int eleven = 11;
206         return switch (i) {
207             case 1 & 0xF -> "1";
208             case 4>>1 -> "2";
209             case (int) 3L -> "3";
210             case 2<<1 -> "4";
211             case 10 / 2 -> "5";
212             case 12 - 6 -> "6";
213             case 3 + 4 -> "7";
214             case 2 * 2 * 2 -> "8";
215             case 8 | 1 -> "9";
216             case (10) -> "10";
217             case eleven -> "11";
218             case Constants.c1 -> String.valueOf(Constants.c1);
219             case 1 > 0 ? 13 : 133 -> "13";
220             default -> "an int";
221         };
222     }
223 
224     @Test
225     void testCaseConstantEnum() {
226         CoreOp.FuncOp lmodel = lower("caseConstantEnum");
227         for (Day day : Day.values()) {
228             Assertions.assertEquals(caseConstantEnum(day), Interpreter.invoke(MethodHandles.lookup(), lmodel, day));
229         }
230     }
231 
232     enum Day {
233         MON, TUE, WED, THU, FRI, SAT, SUN
234     }
235 
236     @Reflect
237     private static int caseConstantEnum(Day d) {
238         return switch (d) {
239             case MON, FRI, SUN -> 6;
240             case TUE -> 7;
241             case THU, SAT -> 8;
242             case WED -> 9;
243         };
244     }
245 
246     @Test
247     void testCaseConstantFallThrough() {
248         CoreOp.FuncOp lmodel = lower("caseConstantFallThrough");
249         char[] args = {'A', 'B', 'C'};
250         for (char arg : args) {
251             Assertions.assertEquals(caseConstantFallThrough(arg), Interpreter.invoke(MethodHandles.lookup(), lmodel, arg));
252         }
253     }
254 
255     @Reflect
256     private static String caseConstantFallThrough(char c) {
257         return switch (c) {
258             case 'A':
259             case 'B':
260                 yield "A or B";
261             default:
262                 yield "Neither A nor B";
263         };
264     }
265 
266     // @Reflect
267     // compiler code doesn't support case null, default
268     // @@@ support such as case and test the switch expression lowering for this case
269     private static String caseConstantNullAndDefault(String s) {
270         return switch (s) {
271             case "abc" -> "alphabet";
272             case null, default -> "null or default";
273         };
274     }
275 
276     @Test
277     void testCaseConstantNullLabel() {
278         CoreOp.FuncOp lmodel = lower("caseConstantNullLabel");
279         String[] args = {null, "non null"};
280         for (String arg : args) {
281             Assertions.assertEquals(caseConstantNullLabel(arg), Interpreter.invoke(MethodHandles.lookup(), lmodel, arg));
282         }
283     }
284 
285     @Reflect
286     private static String caseConstantNullLabel(String s) {
287         return switch (s) {
288             case null -> "null";
289             default -> "non null";
290         };
291     }
292 
293     @Test
294     void testCaseConstantThrow() {
295         CoreOp.FuncOp lmodel = lower("caseConstantThrow");
296         Assertions.assertThrows(IllegalArgumentException.class, () -> Interpreter.invoke(MethodHandles.lookup(), lmodel, 8));
297         int[] args = {9, 10};
298         for (int arg : args) {
299             Assertions.assertEquals(caseConstantThrow(arg), Interpreter.invoke(MethodHandles.lookup(), lmodel, arg));
300         }
301     }
302 
303     @Reflect
304     private static String caseConstantThrow(Integer i) {
305         return switch (i) {
306             case 8 -> throw new IllegalArgumentException();
307             case 9 -> "NINE";
308             default -> "An integer";
309         };
310     }
311 
312     @Test
313     void testCaseConstantMultiLabels() {
314         CoreOp.FuncOp lmodel = lower("caseConstantMultiLabels");
315         char[] args = {'a', 'e', 'i', 'o', 'u', 'j', 'p', 'g'};
316         for (char arg : args) {
317             Assertions.assertEquals(caseConstantMultiLabels(arg), Interpreter.invoke(MethodHandles.lookup(), lmodel, arg));
318         }
319     }
320 
321     @Reflect
322     private static String caseConstantMultiLabels(char c) {
323         return switch (Character.toLowerCase(c)) {
324             case 'a', 'e', 'i', 'o', 'u': yield "vowel";
325             default: yield "consonant";
326         };
327     }
328 
329     @Test
330     void testCaseConstantBehaviorIsSyntaxIndependent() {
331         CoreOp.FuncOp ruleExpression = lower("caseConstantRuleExpression");
332         CoreOp.FuncOp ruleBlock = lower("caseConstantRuleBlock");
333         CoreOp.FuncOp statement = lower("caseConstantStatement");
334 
335         String[] args = {"FOO", "BAR", "BAZ", "OTHER"};
336 
337         for (String arg : args) {
338             Assertions.assertEquals(Interpreter.invoke(MethodHandles.lookup(), ruleBlock, arg), Interpreter.invoke(MethodHandles.lookup(), ruleExpression, arg));
339             Assertions.assertEquals(Interpreter.invoke(MethodHandles.lookup(), statement, arg), Interpreter.invoke(MethodHandles.lookup(), ruleExpression, arg));
340         }
341     }
342 
343     @Reflect
344     public static String caseConstantRuleExpression(String r) {
345         return switch (r) {
346             case "FOO" -> "BAR";
347             case "BAR" -> "BAZ";
348             case "BAZ" -> "FOO";
349             default -> "";
350         };
351     }
352 
353     @Reflect
354     public static String caseConstantRuleBlock(String r) {
355         return switch (r) {
356             case "FOO" -> {
357                 yield "BAR";
358             }
359             case "BAR" -> {
360                 yield "BAZ";
361             }
362             case "BAZ" -> {
363                 yield "FOO";
364             }
365             default -> {
366                 yield "";
367             }
368         };
369     }
370 
371     @Reflect
372     private static String caseConstantStatement(String s) {
373         return switch (s) {
374             case "FOO": yield "BAR";
375             case "BAR": yield "BAZ";
376             case "BAZ": yield "FOO";
377             default: yield "";
378         };
379     }
380 
381     @Test
382     void testCaseConstantConv() {
383         CoreOp.FuncOp lmodel = lower("caseConstantConv");
384         short[] args = {1, 2, 3, 4};
385         for (short arg : args) {
386             Assertions.assertEquals(caseConstantConv(arg), Interpreter.invoke(MethodHandles.lookup(), lmodel, arg));
387         }
388     }
389 
390     @Reflect
391     static String caseConstantConv(short a) {
392         final short s = 1;
393         final byte b = 2;
394         return switch (a) {
395             case s -> "one"; // identity
396             case b -> "three"; // widening primitive conversion
397             case 3 -> "two"; // narrowing primitive conversion
398             default -> "default";
399         };
400     }
401 
402     @Test
403     void testCaseConstantConv2() {
404         CoreOp.FuncOp lmodel = lower("caseConstantConv2");
405         Byte[] args = {1, 2, 3};
406         for (Byte arg : args) {
407             Assertions.assertEquals(caseConstantConv2(arg), Interpreter.invoke(MethodHandles.lookup(), lmodel, arg));
408         }
409     }
410 
411     @Reflect
412     static String caseConstantConv2(Byte a) {
413         final byte b = 2;
414         return switch (a) {
415             case 1 -> "one"; // narrowing primitive conversion followed by a boxing conversion
416             case b -> "two"; // boxing
417             default -> "default";
418         };
419     }
420 
421     @Test
422     void testUnconditionalPattern() {
423         CoreOp.FuncOp lmodel = lower("unconditionalPattern");
424         String[] args = {"A", "X"};
425         for (String arg : args) {
426             Assertions.assertEquals(unconditionalPattern(arg), Interpreter.invoke(MethodHandles.lookup(), lmodel, arg));
427         }
428     }
429 
430     @Reflect
431     static String unconditionalPattern(String s) {
432         return switch (s) {
433             case "A" -> "A";
434             case Object o -> "default";
435         };
436     }
437 
438 
439     @Test
440     void testDefaultCaseNotTheLast() {
441         CoreOp.FuncOp lmodel = lower("defaultCaseNotTheLast");
442         String[] args = {"something", "M", "A"};
443         for (String arg : args) {
444             Assertions.assertEquals(defaultCaseNotTheLast(arg), Interpreter.invoke(MethodHandles.lookup(), lmodel, arg));
445         }
446     }
447 
448     @Reflect
449     static String defaultCaseNotTheLast(String s) {
450         return switch (s) {
451             default -> "else";
452             case "M" -> "Mow";
453             case "A" -> "Aow";
454         };
455     }
456 
457     // we are not testing switch expr that has no default,
458     // because to test for MatchException we need to set up separate compilation
459     // in compiler tests we are checking that the code model contains a default case that throws MatchException
460     // that should be enough
461 
462     private static CoreOp.FuncOp lower(String methodName) {
463         return lower(getCodeModel(methodName));
464     }
465 
466     private static CoreOp.FuncOp lower(CoreOp.FuncOp f) {
467         writeModel(f, System.out, OpWriter.LocationOption.DROP_LOCATION);
468 
469         CoreOp.FuncOp lf = f.transform(CodeTransformer.LOWERING_TRANSFORMER);
470         writeModel(lf, System.out, OpWriter.LocationOption.DROP_LOCATION);
471 
472         return lf;
473     }
474 
475     private static void writeModel(CoreOp.FuncOp f, OutputStream os, OpWriter.Option... options) {
476         StringWriter sw = new StringWriter();
477         new OpWriter(sw, options).writeOp(f);
478         try {
479             os.write(sw.toString().getBytes());
480         } catch (IOException e) {
481             throw new RuntimeException(e);
482         }
483     }
484 
485     private static CoreOp.FuncOp getCodeModel(String methodName) {
486         Optional<Method> om = Stream.of(TestSwitchExpressionOp.class.getDeclaredMethods())
487                 .filter(m -> m.getName().equals(methodName))
488                 .findFirst();
489 
490         return CoreOp.ofMethod(om.get()).get();
491     }
492 }