1 import org.testng.Assert;
  2 import org.testng.annotations.Test;
  3 
  4 import java.io.IOException;
  5 import java.io.OutputStream;
  6 import java.io.StringWriter;
  7 import java.lang.invoke.MethodHandles;
  8 import java.lang.reflect.Method;
  9 import jdk.incubator.code.OpTransformer;
 10 import jdk.incubator.code.interpreter.Interpreter;
 11 import jdk.incubator.code.op.CoreOp;
 12 import jdk.incubator.code.writer.OpWriter;
 13 import jdk.incubator.code.CodeReflection;
 14 import java.util.*;
 15 import java.util.stream.Stream;
 16 
 17 /*
 18  * @test
 19  * @modules jdk.incubator.code
 20  * @run testng TestSwitchExpressionOp
 21  */
 22 public class TestSwitchExpressionOp {
 23 
 24     @Test
 25     void testCasePatternGuard() {
 26         CoreOp.FuncOp lmodel = lower("casePatternGuard");
 27         Object[] args = {"c++", "java", new R(8), new R(2L), new R(3f), new R(4.0)};
 28         for (Object arg : args) {
 29             Assert.assertEquals(Interpreter.invoke(MethodHandles.lookup(), lmodel, arg), casePatternGuard(arg));
 30         }
 31     }
 32     @CodeReflection
 33     static String casePatternGuard(Object obj) {
 34         return switch (obj) {
 35             case String s when s.length() > 3 -> "str with length > %d".formatted(s.length());
 36             case R(Number n) when n.getClass().equals(Double.class) -> "R(Double)";
 37             default -> "else";
 38         };
 39     }
 40 
 41     @Test
 42     void testCaseRecordPattern() {
 43         // @@@ new R(null) must match the pattern R(Number c), but it doesn't
 44         // @@@ test with generic record
 45         CoreOp.FuncOp lmodel = lower("caseRecordPattern");
 46         Object[] args = {new R(8), new R(1.0), new R(2L), "abc"};
 47         for (Object arg : args) {
 48             Assert.assertEquals(Interpreter.invoke(MethodHandles.lookup(), lmodel, arg), caseRecordPattern(arg));
 49         }
 50     }
 51     record R(Number n) {}
 52     @CodeReflection
 53     static String caseRecordPattern(Object o) {
 54         return switch (o) {
 55             case R(Number _) -> "R(_)";
 56             default -> "else";
 57         };
 58     }
 59     @Test
 60     void testCaseTypePattern() {
 61         CoreOp.FuncOp lmodel = lower("caseTypePattern");
 62         Object[] args = {"str", new ArrayList<>(), new int[]{}, new Stack[][]{}, new Collection[][][]{}, 8, 'x'};
 63         for (Object arg : args) {
 64             Assert.assertEquals(Interpreter.invoke(MethodHandles.lookup(), lmodel, arg), caseTypePattern(arg));
 65         }
 66     }
 67     @CodeReflection
 68     static String caseTypePattern(Object o) {
 69         return switch (o) {
 70             case String _ -> "String"; // class
 71             case RandomAccess _ -> "RandomAccess"; // interface
 72             case int[] _ -> "int[]"; // array primitive
 73             case Stack[][] _ -> "Stack[][]"; // array class
 74             case Collection[][][] _ -> "Collection[][][]"; // array interface
 75             case final Number n -> "Number"; // final modifier
 76             default -> "something else";
 77         };
 78     }
 79 
 80     @Test
 81     void testCasePatternWithCaseConstant() {
 82         CoreOp.FuncOp lmodel = lower("casePatternWithCaseConstant");
 83         int[] args = {42, 43, -44, 0};
 84         for (int arg : args) {
 85             Assert.assertEquals(Interpreter.invoke(MethodHandles.lookup(), lmodel, arg), casePatternWithCaseConstant(arg));
 86         }
 87     }
 88 
 89     @CodeReflection
 90     static String casePatternWithCaseConstant(Integer a) {
 91         return switch (a) {
 92             case 42 -> "forty two";
 93             // @@@ case int will not match, because of the way InstanceOfOp is interpreted
 94             case Integer i when i > 0 -> "positive int";
 95             case Integer i when i < 0 -> "negative int";
 96             default -> "zero";
 97         };
 98     }
 99 
100     // @Test
101     void testCasePatternMultiLabel() {
102         CoreOp.FuncOp lmodel = lower("casePatternMultiLabel");
103         Object[] args = {(byte) 1, (short) 2, 'A', 3, 4L, 5f, 6d, true, "str"};
104         for (Object arg : args) {
105             Assert.assertEquals(Interpreter.invoke(MethodHandles.lookup(), lmodel, arg), casePatternMultiLabel(arg));
106         }
107     }
108     // @CodeReflection
109     // code model for such as code is not supported
110     // @@@ support this case and uncomment its test
111     private static String casePatternMultiLabel(Object o) {
112         return switch (o) {
113             case Integer _, Long _, Character _, Byte _, Short _-> "integral type";
114             default -> "non integral type";
115         };
116     }
117 
118     @Test
119     void testCasePatternThrow() {
120         CoreOp.FuncOp lmodel = lower("casePatternThrow");
121 
122         Object[] args = {Byte.MAX_VALUE, Short.MIN_VALUE, 0, 1L, 11f, 22d};
123         for (Object arg : args) {
124             Assert.assertThrows(IllegalArgumentException.class, () -> Interpreter.invoke(MethodHandles.lookup(), lmodel, arg));
125         }
126 
127         Object[] args2 = {"abc", List.of()};
128         for (Object arg : args2) {
129             Assert.assertEquals(Interpreter.invoke(MethodHandles.lookup(), lmodel, arg), casePatternThrow(arg));
130         }
131     }
132 
133     @CodeReflection
134     private static String casePatternThrow(Object o) {
135         return switch (o) {
136             case Number n -> throw new IllegalArgumentException();
137             case String s -> "a string";
138             default -> o.getClass().getName();
139         };
140     }
141 
142     @Test
143     void testCasePatternBehaviorIsSyntaxIndependent() {
144         CoreOp.FuncOp ruleExpression = lower("casePatternRuleExpression");
145         CoreOp.FuncOp ruleBlock = lower("casePatternRuleBlock");
146         CoreOp.FuncOp statement = lower("casePatternStatement");
147 
148         Object[] args = {1, "2", 3L};
149 
150         for (Object arg : args) {
151             Assert.assertEquals(Interpreter.invoke(MethodHandles.lookup(), ruleExpression, arg), Interpreter.invoke(MethodHandles.lookup(), ruleBlock, arg));
152             Assert.assertEquals(Interpreter.invoke(MethodHandles.lookup(), ruleExpression, arg), Interpreter.invoke(MethodHandles.lookup(), statement, arg));
153         }
154     }
155 
156     @CodeReflection
157     private static String casePatternRuleExpression(Object o) {
158         return switch (o) {
159             case Integer i -> "integer";
160             case String s -> "string";
161             default -> "not integer nor string";
162         };
163     }
164 
165     @CodeReflection
166     private static String casePatternRuleBlock(Object o) {
167         return switch (o) {
168             case Integer i -> {
169                 yield "integer";
170             }
171             case String s -> {
172                 yield "string";
173             }
174             default -> {
175                 yield "not integer nor string";
176             }
177         };
178     }
179 
180     @CodeReflection
181     private static String casePatternStatement(Object o) {
182         return switch (o) {
183             case Integer i: yield "integer";
184             case String s: yield "string";
185             default: yield "not integer nor string";
186         };
187     }
188 
189     @Test
190     void testCaseConstantOtherKindsOfExpr() {
191         CoreOp.FuncOp lmodel = lower("caseConstantOtherKindsOfExpr");
192         for (int i = 0; i < 14; i++) {
193             Assert.assertEquals(Interpreter.invoke(MethodHandles.lookup(), lmodel, i), caseConstantOtherKindsOfExpr(i));
194         }
195     }
196 
197     static class Constants {
198         static final int c1 = 12;
199     }
200 
201     @CodeReflection
202     private static String caseConstantOtherKindsOfExpr(int i) {
203         final int eleven = 11;
204         return switch (i) {
205             case 1 & 0xF -> "1";
206             case 4>>1 -> "2";
207             case (int) 3L -> "3";
208             case 2<<1 -> "4";
209             case 10 / 2 -> "5";
210             case 12 - 6 -> "6";
211             case 3 + 4 -> "7";
212             case 2 * 2 * 2 -> "8";
213             case 8 | 1 -> "9";
214             case (10) -> "10";
215             case eleven -> "11";
216             case Constants.c1 -> String.valueOf(Constants.c1);
217             case 1 > 0 ? 13 : 133 -> "13";
218             default -> "an int";
219         };
220     }
221 
222     @Test
223     void testCaseConstantEnum() {
224         CoreOp.FuncOp lmodel = lower("caseConstantEnum");
225         for (Day day : Day.values()) {
226             Assert.assertEquals(Interpreter.invoke(MethodHandles.lookup(), lmodel, day), caseConstantEnum(day));
227         }
228     }
229 
230     enum Day {
231         MON, TUE, WED, THU, FRI, SAT, SUN
232     }
233 
234     @CodeReflection
235     private static int caseConstantEnum(Day d) {
236         return switch (d) {
237             case MON, FRI, SUN -> 6;
238             case TUE -> 7;
239             case THU, SAT -> 8;
240             case WED -> 9;
241         };
242     }
243 
244     @Test
245     void testCaseConstantFallThrough() {
246         CoreOp.FuncOp lmodel = lower("caseConstantFallThrough");
247         char[] args = {'A', 'B', 'C'};
248         for (char arg : args) {
249             Assert.assertEquals(Interpreter.invoke(MethodHandles.lookup(), lmodel, arg), caseConstantFallThrough(arg));
250         }
251     }
252 
253     @CodeReflection
254     private static String caseConstantFallThrough(char c) {
255         return switch (c) {
256             case 'A':
257             case 'B':
258                 yield "A or B";
259             default:
260                 yield "Neither A nor B";
261         };
262     }
263 
264     // @CodeReflection
265     // compiler code doesn't support case null, default
266     // @@@ support such as case and test the switch expression lowering for this case
267     private static String caseConstantNullAndDefault(String s) {
268         return switch (s) {
269             case "abc" -> "alphabet";
270             case null, default -> "null or default";
271         };
272     }
273 
274     @Test
275     void testCaseConstantNullLabel() {
276         CoreOp.FuncOp lmodel = lower("caseConstantNullLabel");
277         String[] args = {null, "non null"};
278         for (String arg : args) {
279             Assert.assertEquals(Interpreter.invoke(MethodHandles.lookup(), lmodel, arg), caseConstantNullLabel(arg));
280         }
281     }
282 
283     @CodeReflection
284     private static String caseConstantNullLabel(String s) {
285         return switch (s) {
286             case null -> "null";
287             default -> "non null";
288         };
289     }
290 
291     @Test
292     void testCaseConstantThrow() {
293         CoreOp.FuncOp lmodel = lower("caseConstantThrow");
294         Assert.assertThrows(IllegalArgumentException.class, () -> Interpreter.invoke(MethodHandles.lookup(), lmodel, 8));
295         int[] args = {9, 10};
296         for (int arg : args) {
297             Assert.assertEquals(Interpreter.invoke(MethodHandles.lookup(), lmodel, arg), caseConstantThrow(arg));
298         }
299     }
300 
301     @CodeReflection
302     private static String caseConstantThrow(Integer i) {
303         return switch (i) {
304             case 8 -> throw new IllegalArgumentException();
305             case 9 -> "NINE";
306             default -> "An integer";
307         };
308     }
309 
310     @Test
311     void testCaseConstantMultiLabels() {
312         CoreOp.FuncOp lmodel = lower("caseConstantMultiLabels");
313         char[] args = {'a', 'e', 'i', 'o', 'u', 'j', 'p', 'g'};
314         for (char arg : args) {
315             Assert.assertEquals(Interpreter.invoke(MethodHandles.lookup(), lmodel, arg), caseConstantMultiLabels(arg));
316         }
317     }
318 
319     @CodeReflection
320     private static String caseConstantMultiLabels(char c) {
321         return switch (Character.toLowerCase(c)) {
322             case 'a', 'e', 'i', 'o', 'u': yield "vowel";
323             default: yield "consonant";
324         };
325     }
326 
327     @Test
328     void testCaseConstantBehaviorIsSyntaxIndependent() {
329         CoreOp.FuncOp ruleExpression = lower("caseConstantRuleExpression");
330         CoreOp.FuncOp ruleBlock = lower("caseConstantRuleBlock");
331         CoreOp.FuncOp statement = lower("caseConstantStatement");
332 
333         String[] args = {"FOO", "BAR", "BAZ", "OTHER"};
334 
335         for (String arg : args) {
336             Assert.assertEquals(Interpreter.invoke(MethodHandles.lookup(), ruleExpression, arg), Interpreter.invoke(MethodHandles.lookup(), ruleBlock, arg));
337             Assert.assertEquals(Interpreter.invoke(MethodHandles.lookup(), ruleExpression, arg), Interpreter.invoke(MethodHandles.lookup(), statement, arg));
338         }
339     }
340 
341     @CodeReflection
342     public static String caseConstantRuleExpression(String r) {
343         return switch (r) {
344             case "FOO" -> "BAR";
345             case "BAR" -> "BAZ";
346             case "BAZ" -> "FOO";
347             default -> "";
348         };
349     }
350 
351     @CodeReflection
352     public static String caseConstantRuleBlock(String r) {
353         return switch (r) {
354             case "FOO" -> {
355                 yield "BAR";
356             }
357             case "BAR" -> {
358                 yield "BAZ";
359             }
360             case "BAZ" -> {
361                 yield "FOO";
362             }
363             default -> {
364                 yield "";
365             }
366         };
367     }
368 
369     @CodeReflection
370     private static String caseConstantStatement(String s) {
371         return switch (s) {
372             case "FOO": yield "BAR";
373             case "BAR": yield "BAZ";
374             case "BAZ": yield "FOO";
375             default: yield "";
376         };
377     }
378 
379     @Test
380     void testCaseConstantConv() {
381         CoreOp.FuncOp lmodel = lower("caseConstantConv");
382         short[] args = {1, 2, 3, 4};
383         for (short arg : args) {
384             Assert.assertEquals(Interpreter.invoke(MethodHandles.lookup(), lmodel, arg), caseConstantConv(arg));
385         }
386     }
387 
388     @CodeReflection
389     static String caseConstantConv(short a) {
390         final short s = 1;
391         final byte b = 2;
392         return switch (a) {
393             case s -> "one"; // identity
394             case b -> "three"; // widening primitive conversion
395             case 3 -> "two"; // narrowing primitive conversion
396             default -> "default";
397         };
398     }
399 
400     @Test
401     void testCaseConstantConv2() {
402         CoreOp.FuncOp lmodel = lower("caseConstantConv2");
403         Byte[] args = {1, 2, 3};
404         for (Byte arg : args) {
405             Assert.assertEquals(Interpreter.invoke(MethodHandles.lookup(), lmodel, arg), caseConstantConv2(arg));
406         }
407     }
408 
409     @CodeReflection
410     static String caseConstantConv2(Byte a) {
411         final byte b = 2;
412         return switch (a) {
413             case 1 -> "one"; // narrowing primitive conversion followed by a boxing conversion
414             case b -> "two"; // boxing
415             default -> "default";
416         };
417     }
418 
419     @Test
420     void testUnconditionalPattern() {
421         CoreOp.FuncOp lmodel = lower("unconditionalPattern");
422         String[] args = {"A", "X"};
423         for (String arg : args) {
424             Assert.assertEquals(Interpreter.invoke(MethodHandles.lookup(), lmodel, arg), unconditionalPattern(arg));
425         }
426     }
427 
428     @CodeReflection
429     static String unconditionalPattern(String s) {
430         return switch (s) {
431             case "A" -> "A";
432             case Object o -> "default";
433         };
434     }
435 
436 
437     @Test
438     void testDefaultCaseNotTheLast() {
439         CoreOp.FuncOp lmodel = lower("defaultCaseNotTheLast");
440         String[] args = {"something", "M", "A"};
441         for (String arg : args) {
442             Assert.assertEquals(Interpreter.invoke(MethodHandles.lookup(), lmodel, arg), defaultCaseNotTheLast(arg));
443         }
444     }
445 
446     @CodeReflection
447     static String defaultCaseNotTheLast(String s) {
448         return switch (s) {
449             default -> "else";
450             case "M" -> "Mow";
451             case "A" -> "Aow";
452         };
453     }
454 
455     // we are not testing switch expr that has no default,
456     // because to test for MatchException we need to set up separate compilation
457     // in compiler tests we are checking that the code model contains a default case that throws MatchException
458     // that should be enough
459 
460     private static CoreOp.FuncOp lower(String methodName) {
461         return lower(getCodeModel(methodName));
462     }
463 
464     private static CoreOp.FuncOp lower(CoreOp.FuncOp f) {
465         writeModel(f, System.out, OpWriter.LocationOption.DROP_LOCATION);
466 
467         CoreOp.FuncOp lf = f.transform(OpTransformer.LOWERING_TRANSFORMER);
468         writeModel(lf, System.out, OpWriter.LocationOption.DROP_LOCATION);
469 
470         return lf;
471     }
472 
473     private static void writeModel(CoreOp.FuncOp f, OutputStream os, OpWriter.Option... options) {
474         StringWriter sw = new StringWriter();
475         new OpWriter(sw, options).writeOp(f);
476         try {
477             os.write(sw.toString().getBytes());
478         } catch (IOException e) {
479             throw new RuntimeException(e);
480         }
481     }
482 
483     private static CoreOp.FuncOp getCodeModel(String methodName) {
484         Optional<Method> om = Stream.of(TestSwitchExpressionOp.class.getDeclaredMethods())
485                 .filter(m -> m.getName().equals(methodName))
486                 .findFirst();
487 
488         return CoreOp.ofMethod(om.get()).get();
489     }
490 }