import jdk.incubator.code.Reflect;
import jdk.incubator.code.CodeTransformer;
import jdk.incubator.code.dialect.core.CoreOp;
import jdk.incubator.code.extern.OpWriter;
import jdk.incubator.code.interpreter.Interpreter;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

import java.io.IOException;
import java.io.OutputStream;
import java.io.StringWriter;
import java.io.UncheckedIOException;
import java.lang.invoke.MethodHandles;
import java.lang.reflect.Method;
import java.nio.charset.StandardCharsets;
import java.util.*;
import java.util.stream.Stream;
 16 
 17 /*
 18  * @test
 19  * @modules jdk.incubator.code
 20  * @run junit TestSwitchStatementOp
 21  * @run main Unreflect TestSwitchStatementOp
 22  * @run junit TestSwitchStatementOp
 23  *
 24  */
 25 public class TestSwitchStatementOp {
 26 
 27     @Test
 28     void testCaseConstantBehaviorIsSyntaxIndependent() {
 29         CoreOp.FuncOp ruleExpression = lower("caseConstantRuleExpression");
 30         CoreOp.FuncOp ruleBlock = lower("caseConstantRuleBlock");
 31         CoreOp.FuncOp statement = lower("caseConstantStatement");
 32 
 33         String[] args = {"FOO", "BAR", "BAZ", "OTHER"};
 34 
 35         for (String arg : args) {
 36             Assertions.assertEquals(Interpreter.invoke(MethodHandles.lookup(), ruleBlock, arg), Interpreter.invoke(MethodHandles.lookup(), ruleExpression, arg));
 37             Assertions.assertEquals(Interpreter.invoke(MethodHandles.lookup(), statement, arg), Interpreter.invoke(MethodHandles.lookup(), ruleExpression, arg));
 38         }
 39     }
 40 
 41     @Reflect
 42     public static String caseConstantRuleExpression(String r) {
 43         String s = "";
 44         switch (r) {
 45             case "FOO" -> s += "BAR";
 46             case "BAR" -> s += "BAZ";
 47             case "BAZ" -> s += "FOO";
 48             default -> s += "else";
 49         }
 50         return s;
 51     }
 52 
 53     @Reflect
 54     public static String caseConstantRuleBlock(String r) {
 55         String s = "";
 56         switch (r) {
 57             case "FOO" -> {
 58                 s += "BAR";
 59             }
 60             case "BAR" -> {
 61                 s += "BAZ";
 62             }
 63             case "BAZ" -> {
 64                 s += "FOO";
 65             }
 66             default -> {
 67                 s += "else";
 68             }
 69         }
 70         return s;
 71     }
 72 
 73     @Reflect
 74     private static String caseConstantStatement(String s) {
 75         String r = "";
 76         switch (s) {
 77             case "FOO":
 78                 r += "BAR";
 79                 break;
 80             case "BAR":
 81                 r += "BAZ";
 82                 break;
 83             case "BAZ":
 84                 r += "FOO";
 85                 break;
 86             default:
 87                 r += "else";
 88         }
 89         return r;
 90     }
 91 
 92     @Test
 93     void testCaseConstantMultiLabels() {
 94         CoreOp.FuncOp lmodel = lower("caseConstantMultiLabels");
 95         char[] args = {'a', 'e', 'i', 'o', 'u', 'j', 'p', 'g'};
 96         for (char arg : args) {
 97             Assertions.assertEquals(caseConstantMultiLabels(arg), Interpreter.invoke(MethodHandles.lookup(), lmodel, arg));
 98         }
 99     }
100 
101     @Reflect
102     private static String caseConstantMultiLabels(char c) {
103         String r = "";
104         switch (Character.toLowerCase(c)) {
105             case 'a', 'e', 'i', 'o', 'u':
106                 r += "vowel";
107                 break;
108             default:
109                 r += "consonant";
110         }
111         return r;
112     }
113 
114     @Test
115     void testCaseConstantThrow() {
116         CoreOp.FuncOp lmodel = lower("caseConstantThrow");
117 
118         Assertions.assertThrows(IllegalArgumentException.class, () -> Interpreter.invoke(MethodHandles.lookup(), lmodel, 8));
119 
120         int[] args = {9, 10};
121         for (int arg : args) {
122             Assertions.assertEquals(caseConstantThrow(arg), Interpreter.invoke(MethodHandles.lookup(), lmodel, arg));
123         }
124     }
125 
126     @Reflect
127     private static String caseConstantThrow(Integer i) {
128         String r = "";
129         switch (i) {
130             case 8 -> throw new IllegalArgumentException();
131             case 9 -> r += "Nine";
132             default -> r += "An integer";
133         }
134         return r;
135     }
136 
137     @Test
138     void testCaseConstantNullLabel() {
139         CoreOp.FuncOp lmodel = lower("caseConstantNullLabel");
140         String[] args = {null, "non null"};
141         for (String arg : args) {
142             Assertions.assertEquals(caseConstantNullLabel(arg), Interpreter.invoke(MethodHandles.lookup(), lmodel, arg));
143         }
144     }
145 
146     @Reflect
147     private static String caseConstantNullLabel(String s) {
148         String r = "";
149         switch (s) {
150             case null -> r += "null";
151             default -> r += "non null";
152         }
153         return r;
154     }
155 
156     @Test
157     void testCaseConstantFallThrough() {
158         CoreOp.FuncOp lmodel = lower("caseConstantFallThrough");
159         char[] args = {'A', 'B', 'C'};
160         for (char arg : args) {
161             Assertions.assertEquals(caseConstantFallThrough(arg), Interpreter.invoke(MethodHandles.lookup(), lmodel, arg));
162         }
163     }
164 
165     @Reflect
166     private static String caseConstantFallThrough(char c) {
167         String r = "";
168         switch (c) {
169             case 'A':
170             case 'B':
171                 r += "A or B";
172                 break;
173             default:
174                 r += "Neither A nor B";
175         }
176         return r;
177     }
178 
179     @Test
180     void testCaseConstantEnum() {
181         CoreOp.FuncOp lmodel = lower("caseConstantEnum");
182         for (Day day : Day.values()) {
183             Assertions.assertEquals(caseConstantEnum(day), Interpreter.invoke(MethodHandles.lookup(), lmodel, day));
184         }
185     }
186 
187     enum Day {
188         MON, TUE, WED, THU, FRI, SAT, SUN
189     }
190     @Reflect
191     private static String caseConstantEnum(Day d) {
192         String r = "";
193         switch (d) {
194             case MON, FRI, SUN -> r += 6;
195             case TUE -> r += 7;
196             case THU, SAT -> r += 8;
197             case WED -> r += 9;
198         }
199         return r;
200     }
201 
202     @Test
203     void testCaseConstantOtherKindsOfExpr() {
204         CoreOp.FuncOp lmodel = lower("caseConstantOtherKindsOfExpr");
205         for (int i = 0; i < 14; i++) {
206             Assertions.assertEquals(caseConstantOtherKindsOfExpr(i), Interpreter.invoke(MethodHandles.lookup(), lmodel, i));
207         }
208     }
209 
210     static class Constants {
211         static final int c1 = 12;
212     }
213     @Reflect
214     private static String caseConstantOtherKindsOfExpr(int i) {
215         String r = "";
216         final int eleven = 11;
217         switch (i) {
218             case 1 & 0xF -> r += 1;
219             case 4>>1 -> r += "2";
220             case (int) 3L -> r += 3;
221             case 2<<1 -> r += 4;
222             case 10 / 2 -> r += 5;
223             case 12 - 6 -> r += 6;
224             case 3 + 4 -> r += 7;
225             case 2 * 2 * 2 -> r += 8;
226             case 8 | 1 -> r += 9;
227             case (10) -> r += 10;
228             case eleven -> r += 11;
229             case Constants.c1 -> r += Constants.c1;
230             case 1 > 0 ? 13 : 133 -> r += 13;
231             default -> r += "an int";
232         }
233         return r;
234     }
235 
236     @Test
237     void testCaseConstantConv() {
238         CoreOp.FuncOp lmodel = lower("caseConstantConv");
239         for (short i = 1; i < 5; i++) {
240             Assertions.assertEquals(caseConstantConv(i), Interpreter.invoke(MethodHandles.lookup(), lmodel, i));
241         }
242     }
243 
244     @Reflect
245     static String caseConstantConv(short a) {
246         final short s = 1;
247         final byte b = 2;
248         String r = "";
249         switch (a) {
250             case s -> r += "one"; // identity, short -> short
251             case b -> r += "two"; // widening primitive conversion, byte -> short
252             case 3 -> r += "three"; // narrowing primitive conversion, int -> short
253             default -> r += "else";
254         }
255         return r;
256     }
257 
258     @Test
259     void testCaseConstantConv2() {
260         CoreOp.FuncOp lmodel = lower("caseConstantConv2");
261         Byte[] args = {1, 2, 3};
262         for (Byte arg : args) {
263             Assertions.assertEquals(caseConstantConv2(arg), Interpreter.invoke(MethodHandles.lookup(), lmodel, arg));
264         }
265     }
266 
267     @Reflect
268     static String caseConstantConv2(Byte a) {
269         final byte b = 2;
270         String r = "";
271         switch (a) {
272             case 1 -> r+= "one"; // narrowing primitive conversion followed by a boxing conversion, int -> bye -> Byte
273             case b -> r+= "two"; // boxing, byte -> Byte
274             default -> r+= "default";
275         }
276         return r;
277     }
278 
279     @Test
280     void testNonEnhancedSwStatNoDefault() {
281         CoreOp.FuncOp lmodel = lower("nonEnhancedSwStatNoDefault");
282         for (int i = 1; i < 4; i++) {
283             Assertions.assertEquals(nonEnhancedSwStatNoDefault(i), Interpreter.invoke(MethodHandles.lookup(), lmodel, i));
284         }
285     }
286 
287     @Reflect
288     static String nonEnhancedSwStatNoDefault(int a) {
289         String r = "";
290         switch (a) {
291             case 1 -> r += "1";
292             case 2 -> r += 2;
293         }
294         return r;
295     }
296 
297     // no reason to test enhanced switch statement that has no default
298     // because we can't test for MatchException without separate compilation
299 
300     @Test
301     void testEnhancedSwStatUnconditionalPattern() {
302         CoreOp.FuncOp lmodel = lower("enhancedSwStatUnconditionalPattern");
303         String[] args = {"A", "B"};
304         for (String arg : args) {
305             Assertions.assertEquals(enhancedSwStatUnconditionalPattern(arg), Interpreter.invoke(MethodHandles.lookup(), lmodel, arg));
306         }
307     }
308 
309     @Reflect
310     static String enhancedSwStatUnconditionalPattern(String s) {
311         String r = "";
312         switch (s) {
313             case "A" -> r += "A";
314             case Object o -> r += "obj";
315         }
316         return r;
317     }
318 
319     @Test
320     void testCasePatternBehaviorIsSyntaxIndependent() {
321         CoreOp.FuncOp ruleExpression = lower("casePatternRuleExpression");
322         CoreOp.FuncOp ruleBlock = lower("casePatternRuleBlock");
323         CoreOp.FuncOp statement = lower("casePatternStatement");
324 
325         Object[] args = {1, "2", 3L};
326 
327         for (Object arg : args) {
328             Assertions.assertEquals(Interpreter.invoke(MethodHandles.lookup(), ruleBlock, arg), Interpreter.invoke(MethodHandles.lookup(), ruleExpression, arg));
329             Assertions.assertEquals(Interpreter.invoke(MethodHandles.lookup(), statement, arg), Interpreter.invoke(MethodHandles.lookup(), ruleExpression, arg));
330         }
331     }
332 
333     @Reflect
334     private static String casePatternRuleExpression(Object o) {
335         String r = "";
336         switch (o) {
337             case Integer i -> r += "integer";
338             case String s -> r+= "string";
339             default -> r+= "else";
340         }
341         return r;
342     }
343 
344     @Reflect
345     private static String casePatternRuleBlock(Object o) {
346         String r = "";
347         switch (o) {
348             case Integer i -> {
349                 r += "integer";
350             }
351             case String s -> {
352                 r += "string";
353             }
354             default -> {
355                 r += "else";
356             }
357         }
358         return r;
359     }
360 
361     @Reflect
362     private static String casePatternStatement(Object o) {
363         String r = "";
364         switch (o) {
365             case Integer i:
366                 r += "integer";
367                 break;
368             case String s:
369                 r += "string";
370                 break;
371             default:
372                 r += "else";
373         }
374         return r;
375     }
376 
377     @Test
378     void testCasePatternThrow() {
379         CoreOp.FuncOp lmodel = lower("casePatternThrow");
380 
381         Object[] args = {Byte.MAX_VALUE, Short.MIN_VALUE, 0, 1L, 11f, 22d};
382         for (Object arg : args) {
383             Assertions.assertThrows(IllegalArgumentException.class, () -> Interpreter.invoke(MethodHandles.lookup(), lmodel, arg));
384         }
385 
386         Object[] args2 = {"abc", List.of()};
387         for (Object arg : args2) {
388             Assertions.assertEquals(casePatternThrow(arg), Interpreter.invoke(MethodHandles.lookup(), lmodel, arg));
389         }
390     }
391 
392     @Reflect
393     private static String casePatternThrow(Object o) {
394         String r = "";
395         switch (o) {
396             case Number n -> throw new IllegalArgumentException();
397             case String s -> r += "a string";
398             default -> r += o.getClass().getName();
399         }
400         return r;
401     }
402 
403     // @@@ when multi patterns is supported, we will test it
404 
405     @Test
406     void testCasePatternWithCaseConstant() {
407         CoreOp.FuncOp lmodel = lower("casePatternWithCaseConstant");
408         int[] args = {42, 43, -44, 0};
409         for (int arg : args) {
410             Assertions.assertEquals(casePatternWithCaseConstant(arg), Interpreter.invoke(MethodHandles.lookup(), lmodel, arg));
411         }
412     }
413 
414     @Reflect
415     static String casePatternWithCaseConstant(Integer a) {
416         String r = "";
417         switch (a) {
418             case 42 -> r += "forty two";
419             // @@@ case int will not match, because of the way InstanceOfOp is interpreted
420             case Integer i when i > 0 -> r += "positive int";
421             case Integer i when i < 0 -> r += "negative int";
422             default -> r += "zero";
423         }
424         return r;
425     }
426 
427     @Test
428     void testCaseTypePattern() {
429         CoreOp.FuncOp lmodel = lower("caseTypePattern");
430         Object[] args = {"str", new ArrayList<>(), new int[]{}, new Stack[][]{}, new Collection[][][]{}, 8, 'x'};
431         for (Object arg : args) {
432             Assertions.assertEquals(caseTypePattern(arg), Interpreter.invoke(MethodHandles.lookup(), lmodel, arg));
433         }
434     }
435 
436     @Reflect
437     static String caseTypePattern(Object o) {
438         String r = "";
439         switch (o) {
440             case String _ -> r+= "String"; // class
441             case RandomAccess _ -> r+= "RandomAccess"; // interface
442             case int[] _ -> r+= "int[]"; // array primitive
443             case Stack[][] _ -> r+= "Stack[][]"; // array class
444             case Collection[][][] _ -> r+= "Collection[][][]"; // array interface
445             case final Number n -> r+= "Number"; // final modifier
446             default -> r+= "something else";
447         }
448         return r;
449     }
450 
451     @Test
452     void testCaseRecordPattern() {
453         // @@@ new R(null) must match the pattern R(Number c), but it doesn't
454         // @@@ test with generic record
455         CoreOp.FuncOp lmodel = lower("caseRecordPattern");
456         Object[] args = {new R(8), new R(1.0), new R(2L), "abc"};
457         for (Object arg : args) {
458             Assertions.assertEquals(caseRecordPattern(arg), Interpreter.invoke(MethodHandles.lookup(), lmodel, arg));
459         }
460     }
461 
462     record R(Number n) {}
463     @Reflect
464     static String caseRecordPattern(Object o) {
465         String r = "";
466         switch (o) {
467             case R(Number n) -> r += "R(_)";
468             default -> r+= "else";
469         }
470         return r;
471     }
472 
473     @Test
474     void testCasePatternGuard() {
475         CoreOp.FuncOp lmodel = lower("casePatternGuard");
476         Object[] args = {"c++", "java", new R(8), new R(2L), new R(3f), new R(4.0)};
477         for (Object arg : args) {
478             Assertions.assertEquals(casePatternGuard(arg), Interpreter.invoke(MethodHandles.lookup(), lmodel, arg));
479         }
480     }
481 
482     @Reflect
483     static String casePatternGuard(Object obj) {
484         String r = "";
485         switch (obj) {
486             case String s when s.length() > 3 -> r += "str with length > %d".formatted(s.length());
487             case R(Number n) when n.getClass().equals(Double.class) -> r += "R(Double)";
488             default -> r += "else";
489         }
490         return r;
491     }
492 
493     @Test
494     void testDefaultCaseNotTheLast() {
495         CoreOp.FuncOp lmodel = lower("defaultCaseNotTheLast");
496         String[] args = {"something", "M", "A"};
497         for (String arg : args) {
498             Assertions.assertEquals(defaultCaseNotTheLast(arg), Interpreter.invoke(MethodHandles.lookup(), lmodel, arg));
499         }
500     }
501 
502     @Reflect
503     static String defaultCaseNotTheLast(String s) {
504         String r = "";
505         switch (s) {
506             default -> r += "else";
507             case "M" -> r += "Mow";
508             case "A" -> r += "Aow";
509         }
510         return r;
511     }
512 
513     @Test
514     void testTryAndSwitch() {
515         CoreOp.FuncOp lmodel = lower("tryAndSwitch");
516         String[] args = {"A", "B"};
517         for (String arg : args) {
518             Assertions.assertEquals(tryAndSwitch(arg), Interpreter.invoke(MethodHandles.lookup(), lmodel, arg));
519         }
520     }
521 
522     @Reflect
523     private static List<String> tryAndSwitch(String s) {
524         List<String> r = new ArrayList<>();
525         try {
526             switch (s) {
527                 case "A":
528                     r.add("A");
529                     return r;
530                 default:
531                     r.add("B");
532             }
533         } finally {
534             r.add("finally");
535         }
536         return r;
537     }
538 
539     private static CoreOp.FuncOp lower(String methodName) {
540         return lower(getCodeModel(methodName));
541     }
542 
543     private static CoreOp.FuncOp lower(CoreOp.FuncOp f) {
544         writeModel(f, System.out, OpWriter.LocationOption.DROP_LOCATION);
545 
546         CoreOp.FuncOp lf = f.transform(CodeTransformer.LOWERING_TRANSFORMER);
547         writeModel(lf, System.out, OpWriter.LocationOption.DROP_LOCATION);
548 
549         return lf;
550     }
551 
552     private static void writeModel(CoreOp.FuncOp f, OutputStream os, OpWriter.Option... options) {
553         StringWriter sw = new StringWriter();
554         new OpWriter(sw, options).writeOp(f);
555         try {
556             os.write(sw.toString().getBytes());
557         } catch (IOException e) {
558             throw new RuntimeException(e);
559         }
560     }
561 
562     private static CoreOp.FuncOp getCodeModel(String methodName) {
563         Optional<Method> om = Stream.of(TestSwitchStatementOp.class.getDeclaredMethods())
564                 .filter(m -> m.getName().equals(methodName))
565                 .findFirst();
566 
567         return CoreOp.ofMethod(om.get()).get();
568     }
569 }