import org.testng.Assert;
import org.testng.annotations.Test;

import java.io.IOException;
import java.io.OutputStream;
import java.io.StringWriter;
import java.lang.invoke.MethodHandles;
import java.lang.reflect.Method;
import java.lang.reflect.code.OpTransformer;
import java.lang.reflect.code.interpreter.Interpreter;
import java.lang.reflect.code.op.CoreOp;
import java.lang.reflect.code.writer.OpWriter;
import java.lang.runtime.CodeReflection;
import java.util.*;
import java.util.stream.Stream;

/*
 * @test
 * @run testng TestSwitchStatementOp
 */
 21 public class TestSwitchStatementOp {
 22 
 23     @Test
 24     void testCaseConstantBehaviorIsSyntaxIndependent() {
 25         CoreOp.FuncOp ruleExpression = lower("caseConstantRuleExpression");
 26         CoreOp.FuncOp ruleBlock = lower("caseConstantRuleBlock");
 27         CoreOp.FuncOp statement = lower("caseConstantStatement");
 28 
 29         String[] args = {"FOO", "BAR", "BAZ", "OTHER"};
 30 
 31         for (String arg : args) {
 32             Assert.assertEquals(Interpreter.invoke(ruleExpression, arg), Interpreter.invoke(ruleBlock, arg));
 33             Assert.assertEquals(Interpreter.invoke(ruleExpression, arg), Interpreter.invoke(statement, arg));
 34         }
 35     }
 36 
 37     @CodeReflection
 38     public static String caseConstantRuleExpression(String r) {
 39         String s = "";
 40         switch (r) {
 41             case "FOO" -> s += "BAR";
 42             case "BAR" -> s += "BAZ";
 43             case "BAZ" -> s += "FOO";
 44             default -> s += "else";
 45         }
 46         return s;
 47     }
 48 
 49     @CodeReflection
 50     public static String caseConstantRuleBlock(String r) {
 51         String s = "";
 52         switch (r) {
 53             case "FOO" -> {
 54                 s += "BAR";
 55             }
 56             case "BAR" -> {
 57                 s += "BAZ";
 58             }
 59             case "BAZ" -> {
 60                 s += "FOO";
 61             }
 62             default -> {
 63                 s += "else";
 64             }
 65         }
 66         return s;
 67     }
 68 
 69     @CodeReflection
 70     private static String caseConstantStatement(String s) {
 71         String r = "";
 72         switch (s) {
 73             case "FOO":
 74                 r += "BAR";
 75                 break;
 76             case "BAR":
 77                 r += "BAZ";
 78                 break;
 79             case "BAZ":
 80                 r += "FOO";
 81                 break;
 82             default:
 83                 r += "else";
 84         }
 85         return r;
 86     }
 87 
 88     @Test
 89     void testCaseConstantMultiLabels() {
 90         CoreOp.FuncOp lmodel = lower("caseConstantMultiLabels");
 91         char[] args = {'a', 'e', 'i', 'o', 'u', 'j', 'p', 'g'};
 92         for (char arg : args) {
 93             Assert.assertEquals(Interpreter.invoke(lmodel, arg), caseConstantMultiLabels(arg));
 94         }
 95     }
 96 
 97     @CodeReflection
 98     private static String caseConstantMultiLabels(char c) {
 99         String r = "";
100         switch (Character.toLowerCase(c)) {
101             case 'a', 'e', 'i', 'o', 'u':
102                 r += "vowel";
103                 break;
104             default:
105                 r += "consonant";
106         }
107         return r;
108     }
109 
110     @Test
111     void testCaseConstantThrow() {
112         CoreOp.FuncOp lmodel = lower("caseConstantThrow");
113 
114         Assert.assertThrows(IllegalArgumentException.class, () -> Interpreter.invoke(lmodel, 8));
115 
116         int[] args = {9, 10};
117         for (int arg : args) {
118             Assert.assertEquals(Interpreter.invoke(lmodel, arg), caseConstantThrow(arg));
119         }
120     }
121 
122     @CodeReflection
123     private static String caseConstantThrow(Integer i) {
124         String r = "";
125         switch (i) {
126             case 8 -> throw new IllegalArgumentException();
127             case 9 -> r += "Nine";
128             default -> r += "An integer";
129         }
130         return r;
131     }
132 
133     @Test
134     void testCaseConstantNullLabel() {
135         CoreOp.FuncOp lmodel = lower("caseConstantNullLabel");
136         String[] args = {null, "non null"};
137         for (String arg : args) {
138             Assert.assertEquals(Interpreter.invoke(lmodel, arg), caseConstantNullLabel(arg));
139         }
140     }
141 
142     @CodeReflection
143     private static String caseConstantNullLabel(String s) {
144         String r = "";
145         switch (s) {
146             case null -> r += "null";
147             default -> r += "non null";
148         }
149         return r;
150     }
151 
152     @Test
153     void testCaseConstantFallThrough() {
154         CoreOp.FuncOp lmodel = lower("caseConstantFallThrough");
155         char[] args = {'A', 'B', 'C'};
156         for (char arg : args) {
157             Assert.assertEquals(Interpreter.invoke(lmodel, arg), caseConstantFallThrough(arg));
158         }
159     }
160 
161     @CodeReflection
162     private static String caseConstantFallThrough(char c) {
163         String r = "";
164         switch (c) {
165             case 'A':
166             case 'B':
167                 r += "A or B";
168                 break;
169             default:
170                 r += "Neither A nor B";
171         }
172         return r;
173     }
174 
175     @Test
176     void testCaseConstantEnum() {
177         CoreOp.FuncOp lmodel = lower("caseConstantEnum");
178         for (Day day : Day.values()) {
179             Assert.assertEquals(Interpreter.invoke(MethodHandles.lookup(), lmodel, day), caseConstantEnum(day));
180         }
181     }
182 
183     enum Day {
184         MON, TUE, WED, THU, FRI, SAT, SUN
185     }
186     @CodeReflection
187     private static String caseConstantEnum(Day d) {
188         String r = "";
189         switch (d) {
190             case MON, FRI, SUN -> r += 6;
191             case TUE -> r += 7;
192             case THU, SAT -> r += 8;
193             case WED -> r += 9;
194         }
195         return r;
196     }
197 
198     @Test
199     void testCaseConstantOtherKindsOfExpr() {
200         CoreOp.FuncOp lmodel = lower("caseConstantOtherKindsOfExpr");
201         for (int i = 0; i < 14; i++) {
202             Assert.assertEquals(Interpreter.invoke(MethodHandles.lookup(), lmodel, i), caseConstantOtherKindsOfExpr(i));
203         }
204     }
205 
206     static class Constants {
207         static final int c1 = 12;
208     }
209     @CodeReflection
210     private static String caseConstantOtherKindsOfExpr(int i) {
211         String r = "";
212         final int eleven = 11;
213         switch (i) {
214             case 1 & 0xF -> r += 1;
215             case 4>>1 -> r += "2";
216             case (int) 3L -> r += 3;
217             case 2<<1 -> r += 4;
218             case 10 / 2 -> r += 5;
219             case 12 - 6 -> r += 6;
220             case 3 + 4 -> r += 7;
221             case 2 * 2 * 2 -> r += 8;
222             case 8 | 1 -> r += 9;
223             case (10) -> r += 10;
224             case eleven -> r += 11;
225             case Constants.c1 -> r += Constants.c1;
226             case 1 > 0 ? 13 : 133 -> r += 13;
227             default -> r += "an int";
228         }
229         return r;
230     }
231 
232     @Test
233     void testCaseConstantConv() {
234         CoreOp.FuncOp lmodel = lower("caseConstantConv");
235         for (short i = 1; i < 5; i++) {
236             Assert.assertEquals(Interpreter.invoke(MethodHandles.lookup(), lmodel, i), caseConstantConv(i));
237         }
238     }
239 
240     @CodeReflection
241     static String caseConstantConv(short a) {
242         final short s = 1;
243         final byte b = 2;
244         String r = "";
245         switch (a) {
246             case s -> r += "one"; // identity, short -> short
247             case b -> r += "two"; // widening primitive conversion, byte -> short
248             case 3 -> r += "three"; // narrowing primitive conversion, int -> short
249             default -> r += "else";
250         }
251         return r;
252     }
253 
254     @Test
255     void testCaseConstantConv2() {
256         CoreOp.FuncOp lmodel = lower("caseConstantConv2");
257         Byte[] args = {1, 2, 3};
258         for (Byte arg : args) {
259             Assert.assertEquals(Interpreter.invoke(MethodHandles.lookup(), lmodel, arg), caseConstantConv2(arg));
260         }
261     }
262 
263     @CodeReflection
264     static String caseConstantConv2(Byte a) {
265         final byte b = 2;
266         String r = "";
267         switch (a) {
268             case 1 -> r+= "one"; // narrowing primitive conversion followed by a boxing conversion, int -> bye -> Byte
269             case b -> r+= "two"; // boxing, byte -> Byte
270             default -> r+= "default";
271         }
272         return r;
273     }
274 
275     @Test
276     void testNonEnhancedSwStatNoDefault() {
277         CoreOp.FuncOp lmodel = lower("nonEnhancedSwStatNoDefault");
278         for (int i = 1; i < 4; i++) {
279             Assert.assertEquals(Interpreter.invoke(lmodel, i), nonEnhancedSwStatNoDefault(i));
280         }
281     }
282 
283     @CodeReflection
284     static String nonEnhancedSwStatNoDefault(int a) {
285         String r = "";
286         switch (a) {
287             case 1 -> r += "1";
288             case 2 -> r += 2;
289         }
290         return r;
291     }
292 
293     // no reason to test enhanced switch statement that has no default
294     // because we can't test for MatchException without separate compilation
295 
296     @Test
297     void testEnhancedSwStatUnconditionalPattern() {
298         CoreOp.FuncOp lmodel = lower("enhancedSwStatUnconditionalPattern");
299         String[] args = {"A", "B"};
300         for (String arg : args) {
301             Assert.assertEquals(Interpreter.invoke(lmodel, arg), enhancedSwStatUnconditionalPattern(arg));
302         }
303     }
304 
305     @CodeReflection
306     static String enhancedSwStatUnconditionalPattern(String s) {
307         String r = "";
308         switch (s) {
309             case "A" -> r += "A";
310             case Object o -> r += "obj";
311         }
312         return r;
313     }
314 
315     @Test
316     void testCasePatternBehaviorIsSyntaxIndependent() {
317         CoreOp.FuncOp ruleExpression = lower("casePatternRuleExpression");
318         CoreOp.FuncOp ruleBlock = lower("casePatternRuleBlock");
319         CoreOp.FuncOp statement = lower("casePatternStatement");
320 
321         Object[] args = {1, "2", 3L};
322 
323         for (Object arg : args) {
324             Assert.assertEquals(Interpreter.invoke(ruleExpression, arg), Interpreter.invoke(ruleBlock, arg));
325             Assert.assertEquals(Interpreter.invoke(ruleExpression, arg), Interpreter.invoke(statement, arg));
326         }
327     }
328 
329     @CodeReflection
330     private static String casePatternRuleExpression(Object o) {
331         String r = "";
332         switch (o) {
333             case Integer i -> r += "integer";
334             case String s -> r+= "string";
335             default -> r+= "else";
336         }
337         return r;
338     }
339 
340     @CodeReflection
341     private static String casePatternRuleBlock(Object o) {
342         String r = "";
343         switch (o) {
344             case Integer i -> {
345                 r += "integer";
346             }
347             case String s -> {
348                 r += "string";
349             }
350             default -> {
351                 r += "else";
352             }
353         }
354         return r;
355     }
356 
357     @CodeReflection
358     private static String casePatternStatement(Object o) {
359         String r = "";
360         switch (o) {
361             case Integer i:
362                 r += "integer";
363                 break;
364             case String s:
365                 r += "string";
366                 break;
367             default:
368                 r += "else";
369         }
370         return r;
371     }
372 
373     @Test
374     void testCasePatternThrow() {
375         CoreOp.FuncOp lmodel = lower("casePatternThrow");
376 
377         Object[] args = {Byte.MAX_VALUE, Short.MIN_VALUE, 0, 1L, 11f, 22d};
378         for (Object arg : args) {
379             Assert.assertThrows(IllegalArgumentException.class, () -> Interpreter.invoke(lmodel, arg));
380         }
381 
382         Object[] args2 = {"abc", List.of()};
383         for (Object arg : args2) {
384             Assert.assertEquals(Interpreter.invoke(lmodel, arg), casePatternThrow(arg));
385         }
386     }
387 
388     @CodeReflection
389     private static String casePatternThrow(Object o) {
390         String r = "";
391         switch (o) {
392             case Number n -> throw new IllegalArgumentException();
393             case String s -> r += "a string";
394             default -> r += o.getClass().getName();
395         }
396         return r;
397     }
398 
399     // @@@ when multi patterns is supported, we will test it
400 
401     @Test
402     void testCasePatternWithCaseConstant() {
403         CoreOp.FuncOp lmodel = lower("casePatternWithCaseConstant");
404         int[] args = {42, 43, -44, 0};
405         for (int arg : args) {
406             Assert.assertEquals(Interpreter.invoke(lmodel, arg), casePatternWithCaseConstant(arg));
407         }
408     }
409 
410     @CodeReflection
411     static String casePatternWithCaseConstant(Integer a) {
412         String r = "";
413         switch (a) {
414             case 42 -> r += "forty two";
415             // @@@ case int will not match, because of the way InstanceOfOp is interpreted
416             case Integer i when i > 0 -> r += "positive int";
417             case Integer i when i < 0 -> r += "negative int";
418             default -> r += "zero";
419         }
420         return r;
421     }
422 
423     @Test
424     void testCaseTypePattern() {
425         CoreOp.FuncOp lmodel = lower("caseTypePattern");
426         Object[] args = {"str", new ArrayList<>(), new int[]{}, new Stack[][]{}, new Collection[][][]{}, 8, 'x'};
427         for (Object arg : args) {
428             Assert.assertEquals(Interpreter.invoke(lmodel, arg), caseTypePattern(arg));
429         }
430     }
431 
432     @CodeReflection
433     static String caseTypePattern(Object o) {
434         String r = "";
435         switch (o) {
436             case String _ -> r+= "String"; // class
437             case RandomAccess _ -> r+= "RandomAccess"; // interface
438             case int[] _ -> r+= "int[]"; // array primitive
439             case Stack[][] _ -> r+= "Stack[][]"; // array class
440             case Collection[][][] _ -> r+= "Collection[][][]"; // array interface
441             case final Number n -> r+= "Number"; // final modifier
442             default -> r+= "something else";
443         }
444         return r;
445     }
446 
447     @Test
448     void testCaseRecordPattern() {
449         // @@@ new R(null) must match the pattern R(Number c), but it doesn't
450         // @@@ test with generic record
451         CoreOp.FuncOp lmodel = lower("caseRecordPattern");
452         Object[] args = {new R(8), new R(1.0), new R(2L), "abc"};
453         for (Object arg : args) {
454             Assert.assertEquals(Interpreter.invoke(MethodHandles.lookup(), lmodel, arg), caseRecordPattern(arg));
455         }
456     }
457 
458     record R(Number n) {}
459     @CodeReflection
460     static String caseRecordPattern(Object o) {
461         String r = "";
462         switch (o) {
463             case R(Number n) -> r += "R(_)";
464             default -> r+= "else";
465         }
466         return r;
467     }
468 
469     @Test
470     void testCasePatternGuard() {
471         CoreOp.FuncOp lmodel = lower("casePatternGuard");
472         Object[] args = {"c++", "java", new R(8), new R(2L), new R(3f), new R(4.0)};
473         for (Object arg : args) {
474             Assert.assertEquals(Interpreter.invoke(MethodHandles.lookup(), lmodel, arg), casePatternGuard(arg));
475         }
476     }
477 
478     @CodeReflection
479     static String casePatternGuard(Object obj) {
480         String r = "";
481         switch (obj) {
482             case String s when s.length() > 3 -> r += "str with length > %d".formatted(s.length());
483             case R(Number n) when n.getClass().equals(Double.class) -> r += "R(Double)";
484             default -> r += "else";
485         }
486         return r;
487     }
488 
489     @Test
490     void testDefaultCaseNotTheLast() {
491         CoreOp.FuncOp lmodel = lower("defaultCaseNotTheLast");
492         String[] args = {"something", "M", "A"};
493         for (String arg : args) {
494             Assert.assertEquals(Interpreter.invoke(lmodel, arg), defaultCaseNotTheLast(arg));
495         }
496     }
497 
498     @CodeReflection
499     static String defaultCaseNotTheLast(String s) {
500         String r = "";
501         switch (s) {
502             default -> r += "else";
503             case "M" -> r += "Mow";
504             case "A" -> r += "Aow";
505         }
506         return r;
507     }
508 
509     private static CoreOp.FuncOp lower(String methodName) {
510         return lower(getCodeModel(methodName));
511     }
512 
513     private static CoreOp.FuncOp lower(CoreOp.FuncOp f) {
514         writeModel(f, System.out, OpWriter.LocationOption.DROP_LOCATION);
515 
516         CoreOp.FuncOp lf = f.transform(OpTransformer.LOWERING_TRANSFORMER);
517         writeModel(lf, System.out, OpWriter.LocationOption.DROP_LOCATION);
518 
519         return lf;
520     }
521 
522     private static void writeModel(CoreOp.FuncOp f, OutputStream os, OpWriter.Option... options) {
523         StringWriter sw = new StringWriter();
524         new OpWriter(sw, options).writeOp(f);
525         try {
526             os.write(sw.toString().getBytes());
527         } catch (IOException e) {
528             throw new RuntimeException(e);
529         }
530     }
531 
532     private static CoreOp.FuncOp getCodeModel(String methodName) {
533         Optional<Method> om = Stream.of(TestSwitchStatementOp.class.getDeclaredMethods())
534                 .filter(m -> m.getName().equals(methodName))
535                 .findFirst();
536 
537         return om.get().getCodeModel().get();
538     }
539 }