import org.testng.Assert;
import org.testng.annotations.Test;

import java.io.IOException;
import java.io.OutputStream;
import java.io.StringWriter;
import java.io.UncheckedIOException;
import java.lang.invoke.MethodHandles;
import java.lang.reflect.Method;
import java.lang.reflect.code.OpTransformer;
import java.lang.reflect.code.interpreter.Interpreter;
import java.lang.reflect.code.op.CoreOp;
import java.lang.reflect.code.writer.OpWriter;
import java.lang.runtime.CodeReflection;
import java.nio.charset.StandardCharsets;
import java.util.*;
import java.util.stream.Stream;
 16 
 17 /*
 18  * @test
 19  * @run testng TestSwitchExpressionOp
 20  */
 21 public class TestSwitchExpressionOp {
 22 
 23     @Test
 24     void testCasePatternGuard() {
 25         CoreOp.FuncOp lmodel = lower("casePatternGuard");
 26         Object[] args = {"c++", "java", new R(8), new R(2L), new R(3f), new R(4.0)};
 27         for (Object arg : args) {
 28             Assert.assertEquals(Interpreter.invoke(MethodHandles.lookup(), lmodel, arg), casePatternGuard(arg));
 29         }
 30     }
 31     @CodeReflection
 32     static String casePatternGuard(Object obj) {
 33         return switch (obj) {
 34             case String s when s.length() > 3 -> "str with length > %d".formatted(s.length());
 35             case R(Number n) when n.getClass().equals(Double.class) -> "R(Double)";
 36             default -> "else";
 37         };
 38     }
 39 
 40     @Test
 41     void testCaseRecordPattern() {
 42         // @@@ new R(null) must match the pattern R(Number c), but it doesn't
 43         // @@@ test with generic record
 44         CoreOp.FuncOp lmodel = lower("caseRecordPattern");
 45         Object[] args = {new R(8), new R(1.0), new R(2L), "abc"};
 46         for (Object arg : args) {
 47             Assert.assertEquals(Interpreter.invoke(MethodHandles.lookup(), lmodel, arg), caseRecordPattern(arg));
 48         }
 49     }
 50     record R(Number n) {}
 51     @CodeReflection
 52     static String caseRecordPattern(Object o) {
 53         return switch (o) {
 54             case R(Number _) -> "R(_)";
 55             default -> "else";
 56         };
 57     }
 58     @Test
 59     void testCaseTypePattern() {
 60         CoreOp.FuncOp lmodel = lower("caseTypePattern");
 61         Object[] args = {"str", new ArrayList<>(), new int[]{}, new Stack[][]{}, new Collection[][][]{}, 8, 'x'};
 62         for (Object arg : args) {
 63             Assert.assertEquals(Interpreter.invoke(lmodel, arg), caseTypePattern(arg));
 64         }
 65     }
 66     @CodeReflection
 67     static String caseTypePattern(Object o) {
 68         return switch (o) {
 69             case String _ -> "String"; // class
 70             case RandomAccess _ -> "RandomAccess"; // interface
 71             case int[] _ -> "int[]"; // array primitive
 72             case Stack[][] _ -> "Stack[][]"; // array class
 73             case Collection[][][] _ -> "Collection[][][]"; // array interface
 74             case final Number n -> "Number"; // final modifier
 75             default -> "something else";
 76         };
 77     }
 78 
 79     @Test
 80     void testCasePatternWithCaseConstant() {
 81         CoreOp.FuncOp lmodel = lower("casePatternWithCaseConstant");
 82         int[] args = {42, 43, -44, 0};
 83         for (int arg : args) {
 84             Assert.assertEquals(Interpreter.invoke(lmodel, arg), casePatternWithCaseConstant(arg));
 85         }
 86     }
 87 
 88     @CodeReflection
 89     static String casePatternWithCaseConstant(Integer a) {
 90         return switch (a) {
 91             case 42 -> "forty two";
 92             // @@@ case int will not match, because of the way InstanceOfOp is interpreted
 93             case Integer i when i > 0 -> "positive int";
 94             case Integer i when i < 0 -> "negative int";
 95             default -> "zero";
 96         };
 97     }
 98 
 99     // @Test
100     void testCasePatternMultiLabel() {
101         CoreOp.FuncOp lmodel = lower("casePatternMultiLabel");
102         Object[] args = {(byte) 1, (short) 2, 'A', 3, 4L, 5f, 6d, true, "str"};
103         for (Object arg : args) {
104             Assert.assertEquals(Interpreter.invoke(lmodel, arg), casePatternMultiLabel(arg));
105         }
106     }
107     // @CodeReflection
108     // code model for such as code is not supported
109     // @@@ support this case and uncomment its test
110     private static String casePatternMultiLabel(Object o) {
111         return switch (o) {
112             case Integer _, Long _, Character _, Byte _, Short _-> "integral type";
113             default -> "non integral type";
114         };
115     }
116 
117     @Test
118     void testCasePatternThrow() {
119         CoreOp.FuncOp lmodel = lower("casePatternThrow");
120 
121         Object[] args = {Byte.MAX_VALUE, Short.MIN_VALUE, 0, 1L, 11f, 22d};
122         for (Object arg : args) {
123             Assert.assertThrows(IllegalArgumentException.class, () -> Interpreter.invoke(lmodel, arg));
124         }
125 
126         Object[] args2 = {"abc", List.of()};
127         for (Object arg : args2) {
128             Assert.assertEquals(Interpreter.invoke(lmodel, arg), casePatternThrow(arg));
129         }
130     }
131 
132     @CodeReflection
133     private static String casePatternThrow(Object o) {
134         return switch (o) {
135             case Number n -> throw new IllegalArgumentException();
136             case String s -> "a string";
137             default -> o.getClass().getName();
138         };
139     }
140 
141     @Test
142     void testCasePatternBehaviorIsSyntaxIndependent() {
143         CoreOp.FuncOp ruleExpression = lower("casePatternRuleExpression");
144         CoreOp.FuncOp ruleBlock = lower("casePatternRuleBlock");
145         CoreOp.FuncOp statement = lower("casePatternStatement");
146 
147         Object[] args = {1, "2", 3L};
148 
149         for (Object arg : args) {
150             Assert.assertEquals(Interpreter.invoke(ruleExpression, arg), Interpreter.invoke(ruleBlock, arg));
151             Assert.assertEquals(Interpreter.invoke(ruleExpression, arg), Interpreter.invoke(statement, arg));
152         }
153     }
154 
155     @CodeReflection
156     private static String casePatternRuleExpression(Object o) {
157         return switch (o) {
158             case Integer i -> "integer";
159             case String s -> "string";
160             default -> "not integer nor string";
161         };
162     }
163 
164     @CodeReflection
165     private static String casePatternRuleBlock(Object o) {
166         return switch (o) {
167             case Integer i -> {
168                 yield "integer";
169             }
170             case String s -> {
171                 yield "string";
172             }
173             default -> {
174                 yield "not integer nor string";
175             }
176         };
177     }
178 
179     @CodeReflection
180     private static String casePatternStatement(Object o) {
181         return switch (o) {
182             case Integer i: yield "integer";
183             case String s: yield "string";
184             default: yield "not integer nor string";
185         };
186     }
187 
188     @Test
189     void testCaseConstantOtherKindsOfExpr() {
190         CoreOp.FuncOp lmodel = lower("caseConstantOtherKindsOfExpr");
191         for (int i = 0; i < 14; i++) {
192             Assert.assertEquals(Interpreter.invoke(MethodHandles.lookup(), lmodel, i), caseConstantOtherKindsOfExpr(i));
193         }
194     }
195 
196     static class Constants {
197         static final int c1 = 12;
198     }
199 
200     @CodeReflection
201     private static String caseConstantOtherKindsOfExpr(int i) {
202         final int eleven = 11;
203         return switch (i) {
204             case 1 & 0xF -> "1";
205             case 4>>1 -> "2";
206             case (int) 3L -> "3";
207             case 2<<1 -> "4";
208             case 10 / 2 -> "5";
209             case 12 - 6 -> "6";
210             case 3 + 4 -> "7";
211             case 2 * 2 * 2 -> "8";
212             case 8 | 1 -> "9";
213             case (10) -> "10";
214             case eleven -> "11";
215             case Constants.c1 -> String.valueOf(Constants.c1);
216             case 1 > 0 ? 13 : 133 -> "13";
217             default -> "an int";
218         };
219     }
220 
221     @Test
222     void testCaseConstantEnum() {
223         CoreOp.FuncOp lmodel = lower("caseConstantEnum");
224         for (Day day : Day.values()) {
225             Assert.assertEquals(Interpreter.invoke(MethodHandles.lookup(), lmodel, day), caseConstantEnum(day));
226         }
227     }
228 
229     enum Day {
230         MON, TUE, WED, THU, FRI, SAT, SUN
231     }
232 
233     @CodeReflection
234     private static int caseConstantEnum(Day d) {
235         return switch (d) {
236             case MON, FRI, SUN -> 6;
237             case TUE -> 7;
238             case THU, SAT -> 8;
239             case WED -> 9;
240         };
241     }
242 
243     @Test
244     void testCaseConstantFallThrough() {
245         CoreOp.FuncOp lmodel = lower("caseConstantFallThrough");
246         char[] args = {'A', 'B', 'C'};
247         for (char arg : args) {
248             Assert.assertEquals(Interpreter.invoke(lmodel, arg), caseConstantFallThrough(arg));
249         }
250     }
251 
252     @CodeReflection
253     private static String caseConstantFallThrough(char c) {
254         return switch (c) {
255             case 'A':
256             case 'B':
257                 yield "A or B";
258             default:
259                 yield "Neither A nor B";
260         };
261     }
262 
263     // @CodeReflection
264     // compiler code doesn't support case null, default
265     // @@@ support such as case and test the switch expression lowering for this case
266     private static String caseConstantNullAndDefault(String s) {
267         return switch (s) {
268             case "abc" -> "alphabet";
269             case null, default -> "null or default";
270         };
271     }
272 
273     @Test
274     void testCaseConstantNullLabel() {
275         CoreOp.FuncOp lmodel = lower("caseConstantNullLabel");
276         String[] args = {null, "non null"};
277         for (String arg : args) {
278             Assert.assertEquals(Interpreter.invoke(lmodel, arg), caseConstantNullLabel(arg));
279         }
280     }
281 
282     @CodeReflection
283     private static String caseConstantNullLabel(String s) {
284         return switch (s) {
285             case null -> "null";
286             default -> "non null";
287         };
288     }
289 
290     @Test
291     void testCaseConstantThrow() {
292         CoreOp.FuncOp lmodel = lower("caseConstantThrow");
293         Assert.assertThrows(IllegalArgumentException.class, () -> Interpreter.invoke(lmodel, 8));
294         int[] args = {9, 10};
295         for (int arg : args) {
296             Assert.assertEquals(Interpreter.invoke(lmodel, arg), caseConstantThrow(arg));
297         }
298     }
299 
300     @CodeReflection
301     private static String caseConstantThrow(Integer i) {
302         return switch (i) {
303             case 8 -> throw new IllegalArgumentException();
304             case 9 -> "NINE";
305             default -> "An integer";
306         };
307     }
308 
309     @Test
310     void testCaseConstantMultiLabels() {
311         CoreOp.FuncOp lmodel = lower("caseConstantMultiLabels");
312         char[] args = {'a', 'e', 'i', 'o', 'u', 'j', 'p', 'g'};
313         for (char arg : args) {
314             Assert.assertEquals(Interpreter.invoke(lmodel, arg), caseConstantMultiLabels(arg));
315         }
316     }
317 
318     @CodeReflection
319     private static String caseConstantMultiLabels(char c) {
320         return switch (Character.toLowerCase(c)) {
321             case 'a', 'e', 'i', 'o', 'u': yield "vowel";
322             default: yield "consonant";
323         };
324     }
325 
326     @Test
327     void testCaseConstantBehaviorIsSyntaxIndependent() {
328         CoreOp.FuncOp ruleExpression = lower("caseConstantRuleExpression");
329         CoreOp.FuncOp ruleBlock = lower("caseConstantRuleBlock");
330         CoreOp.FuncOp statement = lower("caseConstantStatement");
331 
332         String[] args = {"FOO", "BAR", "BAZ", "OTHER"};
333 
334         for (String arg : args) {
335             Assert.assertEquals(Interpreter.invoke(ruleExpression, arg), Interpreter.invoke(ruleBlock, arg));
336             Assert.assertEquals(Interpreter.invoke(ruleExpression, arg), Interpreter.invoke(statement, arg));
337         }
338     }
339 
340     @CodeReflection
341     public static String caseConstantRuleExpression(String r) {
342         return switch (r) {
343             case "FOO" -> "BAR";
344             case "BAR" -> "BAZ";
345             case "BAZ" -> "FOO";
346             default -> "";
347         };
348     }
349 
350     @CodeReflection
351     public static String caseConstantRuleBlock(String r) {
352         return switch (r) {
353             case "FOO" -> {
354                 yield "BAR";
355             }
356             case "BAR" -> {
357                 yield "BAZ";
358             }
359             case "BAZ" -> {
360                 yield "FOO";
361             }
362             default -> {
363                 yield "";
364             }
365         };
366     }
367 
368     @CodeReflection
369     private static String caseConstantStatement(String s) {
370         return switch (s) {
371             case "FOO": yield "BAR";
372             case "BAR": yield "BAZ";
373             case "BAZ": yield "FOO";
374             default: yield "";
375         };
376     }
377 
378     @Test
379     void testCaseConstantConv() {
380         CoreOp.FuncOp lmodel = lower("caseConstantConv");
381         short[] args = {1, 2, 3, 4};
382         for (short arg : args) {
383             Assert.assertEquals(Interpreter.invoke(lmodel, arg), caseConstantConv(arg));
384         }
385     }
386 
387     @CodeReflection
388     static String caseConstantConv(short a) {
389         final short s = 1;
390         final byte b = 2;
391         return switch (a) {
392             case s -> "one"; // identity
393             case b -> "three"; // widening primitive conversion
394             case 3 -> "two"; // narrowing primitive conversion
395             default -> "default";
396         };
397     }
398 
399     @Test
400     void testCaseConstantConv2() {
401         CoreOp.FuncOp lmodel = lower("caseConstantConv2");
402         Byte[] args = {1, 2, 3};
403         for (Byte arg : args) {
404             Assert.assertEquals(Interpreter.invoke(lmodel, arg), caseConstantConv2(arg));
405         }
406     }
407 
408     @CodeReflection
409     static String caseConstantConv2(Byte a) {
410         final byte b = 2;
411         return switch (a) {
412             case 1 -> "one"; // narrowing primitive conversion followed by a boxing conversion
413             case b -> "two"; // boxing
414             default -> "default";
415         };
416     }
417 
418     @Test
419     void testUnconditionalPattern() {
420         CoreOp.FuncOp lmodel = lower("unconditionalPattern");
421         String[] args = {"A", "X"};
422         for (String arg : args) {
423             Assert.assertEquals(Interpreter.invoke(lmodel, arg), unconditionalPattern(arg));
424         }
425     }
426 
427     @CodeReflection
428     static String unconditionalPattern(String s) {
429         return switch (s) {
430             case "A" -> "A";
431             case Object o -> "default";
432         };
433     }
434 
435 
436     @Test
437     void testDefaultCaseNotTheLast() {
438         CoreOp.FuncOp lmodel = lower("defaultCaseNotTheLast");
439         String[] args = {"something", "M", "A"};
440         for (String arg : args) {
441             Assert.assertEquals(Interpreter.invoke(lmodel, arg), defaultCaseNotTheLast(arg));
442         }
443     }
444 
445     @CodeReflection
446     static String defaultCaseNotTheLast(String s) {
447         return switch (s) {
448             default -> "else";
449             case "M" -> "Mow";
450             case "A" -> "Aow";
451         };
452     }
453 
454     // we are not testing switch expr that has no default,
455     // because to test for MatchException we need to set up separate compilation
456     // in compiler tests we are checking that the code model contains a default case that throws MatchException
457     // that should be enough
458 
459     private static CoreOp.FuncOp lower(String methodName) {
460         return lower(getCodeModel(methodName));
461     }
462 
463     private static CoreOp.FuncOp lower(CoreOp.FuncOp f) {
464         writeModel(f, System.out, OpWriter.LocationOption.DROP_LOCATION);
465 
466         CoreOp.FuncOp lf = f.transform(OpTransformer.LOWERING_TRANSFORMER);
467         writeModel(lf, System.out, OpWriter.LocationOption.DROP_LOCATION);
468 
469         return lf;
470     }
471 
472     private static void writeModel(CoreOp.FuncOp f, OutputStream os, OpWriter.Option... options) {
473         StringWriter sw = new StringWriter();
474         new OpWriter(sw, options).writeOp(f);
475         try {
476             os.write(sw.toString().getBytes());
477         } catch (IOException e) {
478             throw new RuntimeException(e);
479         }
480     }
481 
482     private static CoreOp.FuncOp getCodeModel(String methodName) {
483         Optional<Method> om = Stream.of(TestSwitchExpressionOp.class.getDeclaredMethods())
484                 .filter(m -> m.getName().equals(methodName))
485                 .findFirst();
486 
487         return om.get().getCodeModel().get();
488     }
489 }