import org.testng.Assert;
import org.testng.annotations.Test;

import java.io.IOException;
import java.io.OutputStream;
import java.io.StringWriter;
import java.io.UncheckedIOException;
import java.lang.invoke.MethodHandles;
import java.lang.reflect.Method;
import java.nio.charset.StandardCharsets;
import java.util.*;
import java.util.stream.Stream;

import jdk.incubator.code.CodeReflection;
import jdk.incubator.code.OpTransformer;
import jdk.incubator.code.interpreter.Interpreter;
import jdk.incubator.code.op.CoreOp;
import jdk.incubator.code.writer.OpWriter;
 16 
 17 /*
 18 * @test
 19 * @modules jdk.incubator.code
 20 * @run testng TestSwitchStatementOp
 21 * */
 22 public class TestSwitchStatementOp {
 23 
 24     @Test
 25     void testCaseConstantBehaviorIsSyntaxIndependent() {
 26         CoreOp.FuncOp ruleExpression = lower("caseConstantRuleExpression");
 27         CoreOp.FuncOp ruleBlock = lower("caseConstantRuleBlock");
 28         CoreOp.FuncOp statement = lower("caseConstantStatement");
 29 
 30         String[] args = {"FOO", "BAR", "BAZ", "OTHER"};
 31 
 32         for (String arg : args) {
 33             Assert.assertEquals(Interpreter.invoke(MethodHandles.lookup(), ruleExpression, arg), Interpreter.invoke(MethodHandles.lookup(), ruleBlock, arg));
 34             Assert.assertEquals(Interpreter.invoke(MethodHandles.lookup(), ruleExpression, arg), Interpreter.invoke(MethodHandles.lookup(), statement, arg));
 35         }
 36     }
 37 
 38     @CodeReflection
 39     public static String caseConstantRuleExpression(String r) {
 40         String s = "";
 41         switch (r) {
 42             case "FOO" -> s += "BAR";
 43             case "BAR" -> s += "BAZ";
 44             case "BAZ" -> s += "FOO";
 45             default -> s += "else";
 46         }
 47         return s;
 48     }
 49 
 50     @CodeReflection
 51     public static String caseConstantRuleBlock(String r) {
 52         String s = "";
 53         switch (r) {
 54             case "FOO" -> {
 55                 s += "BAR";
 56             }
 57             case "BAR" -> {
 58                 s += "BAZ";
 59             }
 60             case "BAZ" -> {
 61                 s += "FOO";
 62             }
 63             default -> {
 64                 s += "else";
 65             }
 66         }
 67         return s;
 68     }
 69 
 70     @CodeReflection
 71     private static String caseConstantStatement(String s) {
 72         String r = "";
 73         switch (s) {
 74             case "FOO":
 75                 r += "BAR";
 76                 break;
 77             case "BAR":
 78                 r += "BAZ";
 79                 break;
 80             case "BAZ":
 81                 r += "FOO";
 82                 break;
 83             default:
 84                 r += "else";
 85         }
 86         return r;
 87     }
 88 
 89     @Test
 90     void testCaseConstantMultiLabels() {
 91         CoreOp.FuncOp lmodel = lower("caseConstantMultiLabels");
 92         char[] args = {'a', 'e', 'i', 'o', 'u', 'j', 'p', 'g'};
 93         for (char arg : args) {
 94             Assert.assertEquals(Interpreter.invoke(MethodHandles.lookup(), lmodel, arg), caseConstantMultiLabels(arg));
 95         }
 96     }
 97 
 98     @CodeReflection
 99     private static String caseConstantMultiLabels(char c) {
100         String r = "";
101         switch (Character.toLowerCase(c)) {
102             case 'a', 'e', 'i', 'o', 'u':
103                 r += "vowel";
104                 break;
105             default:
106                 r += "consonant";
107         }
108         return r;
109     }
110 
111     @Test
112     void testCaseConstantThrow() {
113         CoreOp.FuncOp lmodel = lower("caseConstantThrow");
114 
115         Assert.assertThrows(IllegalArgumentException.class, () -> Interpreter.invoke(MethodHandles.lookup(), lmodel, 8));
116 
117         int[] args = {9, 10};
118         for (int arg : args) {
119             Assert.assertEquals(Interpreter.invoke(MethodHandles.lookup(), lmodel, arg), caseConstantThrow(arg));
120         }
121     }
122 
123     @CodeReflection
124     private static String caseConstantThrow(Integer i) {
125         String r = "";
126         switch (i) {
127             case 8 -> throw new IllegalArgumentException();
128             case 9 -> r += "Nine";
129             default -> r += "An integer";
130         }
131         return r;
132     }
133 
134     @Test
135     void testCaseConstantNullLabel() {
136         CoreOp.FuncOp lmodel = lower("caseConstantNullLabel");
137         String[] args = {null, "non null"};
138         for (String arg : args) {
139             Assert.assertEquals(Interpreter.invoke(MethodHandles.lookup(), lmodel, arg), caseConstantNullLabel(arg));
140         }
141     }
142 
143     @CodeReflection
144     private static String caseConstantNullLabel(String s) {
145         String r = "";
146         switch (s) {
147             case null -> r += "null";
148             default -> r += "non null";
149         }
150         return r;
151     }
152 
153     @Test
154     void testCaseConstantFallThrough() {
155         CoreOp.FuncOp lmodel = lower("caseConstantFallThrough");
156         char[] args = {'A', 'B', 'C'};
157         for (char arg : args) {
158             Assert.assertEquals(Interpreter.invoke(MethodHandles.lookup(), lmodel, arg), caseConstantFallThrough(arg));
159         }
160     }
161 
162     @CodeReflection
163     private static String caseConstantFallThrough(char c) {
164         String r = "";
165         switch (c) {
166             case 'A':
167             case 'B':
168                 r += "A or B";
169                 break;
170             default:
171                 r += "Neither A nor B";
172         }
173         return r;
174     }
175 
176     @Test
177     void testCaseConstantEnum() {
178         CoreOp.FuncOp lmodel = lower("caseConstantEnum");
179         for (Day day : Day.values()) {
180             Assert.assertEquals(Interpreter.invoke(MethodHandles.lookup(), lmodel, day), caseConstantEnum(day));
181         }
182     }
183 
184     enum Day {
185         MON, TUE, WED, THU, FRI, SAT, SUN
186     }
187     @CodeReflection
188     private static String caseConstantEnum(Day d) {
189         String r = "";
190         switch (d) {
191             case MON, FRI, SUN -> r += 6;
192             case TUE -> r += 7;
193             case THU, SAT -> r += 8;
194             case WED -> r += 9;
195         }
196         return r;
197     }
198 
199     @Test
200     void testCaseConstantOtherKindsOfExpr() {
201         CoreOp.FuncOp lmodel = lower("caseConstantOtherKindsOfExpr");
202         for (int i = 0; i < 14; i++) {
203             Assert.assertEquals(Interpreter.invoke(MethodHandles.lookup(), lmodel, i), caseConstantOtherKindsOfExpr(i));
204         }
205     }
206 
207     static class Constants {
208         static final int c1 = 12;
209     }
210     @CodeReflection
211     private static String caseConstantOtherKindsOfExpr(int i) {
212         String r = "";
213         final int eleven = 11;
214         switch (i) {
215             case 1 & 0xF -> r += 1;
216             case 4>>1 -> r += "2";
217             case (int) 3L -> r += 3;
218             case 2<<1 -> r += 4;
219             case 10 / 2 -> r += 5;
220             case 12 - 6 -> r += 6;
221             case 3 + 4 -> r += 7;
222             case 2 * 2 * 2 -> r += 8;
223             case 8 | 1 -> r += 9;
224             case (10) -> r += 10;
225             case eleven -> r += 11;
226             case Constants.c1 -> r += Constants.c1;
227             case 1 > 0 ? 13 : 133 -> r += 13;
228             default -> r += "an int";
229         }
230         return r;
231     }
232 
233     @Test
234     void testCaseConstantConv() {
235         CoreOp.FuncOp lmodel = lower("caseConstantConv");
236         for (short i = 1; i < 5; i++) {
237             Assert.assertEquals(Interpreter.invoke(MethodHandles.lookup(), lmodel, i), caseConstantConv(i));
238         }
239     }
240 
241     @CodeReflection
242     static String caseConstantConv(short a) {
243         final short s = 1;
244         final byte b = 2;
245         String r = "";
246         switch (a) {
247             case s -> r += "one"; // identity, short -> short
248             case b -> r += "two"; // widening primitive conversion, byte -> short
249             case 3 -> r += "three"; // narrowing primitive conversion, int -> short
250             default -> r += "else";
251         }
252         return r;
253     }
254 
255     @Test
256     void testCaseConstantConv2() {
257         CoreOp.FuncOp lmodel = lower("caseConstantConv2");
258         Byte[] args = {1, 2, 3};
259         for (Byte arg : args) {
260             Assert.assertEquals(Interpreter.invoke(MethodHandles.lookup(), lmodel, arg), caseConstantConv2(arg));
261         }
262     }
263 
264     @CodeReflection
265     static String caseConstantConv2(Byte a) {
266         final byte b = 2;
267         String r = "";
268         switch (a) {
269             case 1 -> r+= "one"; // narrowing primitive conversion followed by a boxing conversion, int -> bye -> Byte
270             case b -> r+= "two"; // boxing, byte -> Byte
271             default -> r+= "default";
272         }
273         return r;
274     }
275 
276     @Test
277     void testNonEnhancedSwStatNoDefault() {
278         CoreOp.FuncOp lmodel = lower("nonEnhancedSwStatNoDefault");
279         for (int i = 1; i < 4; i++) {
280             Assert.assertEquals(Interpreter.invoke(MethodHandles.lookup(), lmodel, i), nonEnhancedSwStatNoDefault(i));
281         }
282     }
283 
284     @CodeReflection
285     static String nonEnhancedSwStatNoDefault(int a) {
286         String r = "";
287         switch (a) {
288             case 1 -> r += "1";
289             case 2 -> r += 2;
290         }
291         return r;
292     }
293 
294     // no reason to test enhanced switch statement that has no default
295     // because we can't test for MatchException without separate compilation
296 
297     @Test
298     void testEnhancedSwStatUnconditionalPattern() {
299         CoreOp.FuncOp lmodel = lower("enhancedSwStatUnconditionalPattern");
300         String[] args = {"A", "B"};
301         for (String arg : args) {
302             Assert.assertEquals(Interpreter.invoke(MethodHandles.lookup(), lmodel, arg), enhancedSwStatUnconditionalPattern(arg));
303         }
304     }
305 
306     @CodeReflection
307     static String enhancedSwStatUnconditionalPattern(String s) {
308         String r = "";
309         switch (s) {
310             case "A" -> r += "A";
311             case Object o -> r += "obj";
312         }
313         return r;
314     }
315 
316     @Test
317     void testCasePatternBehaviorIsSyntaxIndependent() {
318         CoreOp.FuncOp ruleExpression = lower("casePatternRuleExpression");
319         CoreOp.FuncOp ruleBlock = lower("casePatternRuleBlock");
320         CoreOp.FuncOp statement = lower("casePatternStatement");
321 
322         Object[] args = {1, "2", 3L};
323 
324         for (Object arg : args) {
325             Assert.assertEquals(Interpreter.invoke(MethodHandles.lookup(), ruleExpression, arg), Interpreter.invoke(MethodHandles.lookup(), ruleBlock, arg));
326             Assert.assertEquals(Interpreter.invoke(MethodHandles.lookup(), ruleExpression, arg), Interpreter.invoke(MethodHandles.lookup(), statement, arg));
327         }
328     }
329 
330     @CodeReflection
331     private static String casePatternRuleExpression(Object o) {
332         String r = "";
333         switch (o) {
334             case Integer i -> r += "integer";
335             case String s -> r+= "string";
336             default -> r+= "else";
337         }
338         return r;
339     }
340 
341     @CodeReflection
342     private static String casePatternRuleBlock(Object o) {
343         String r = "";
344         switch (o) {
345             case Integer i -> {
346                 r += "integer";
347             }
348             case String s -> {
349                 r += "string";
350             }
351             default -> {
352                 r += "else";
353             }
354         }
355         return r;
356     }
357 
358     @CodeReflection
359     private static String casePatternStatement(Object o) {
360         String r = "";
361         switch (o) {
362             case Integer i:
363                 r += "integer";
364                 break;
365             case String s:
366                 r += "string";
367                 break;
368             default:
369                 r += "else";
370         }
371         return r;
372     }
373 
374     @Test
375     void testCasePatternThrow() {
376         CoreOp.FuncOp lmodel = lower("casePatternThrow");
377 
378         Object[] args = {Byte.MAX_VALUE, Short.MIN_VALUE, 0, 1L, 11f, 22d};
379         for (Object arg : args) {
380             Assert.assertThrows(IllegalArgumentException.class, () -> Interpreter.invoke(MethodHandles.lookup(), lmodel, arg));
381         }
382 
383         Object[] args2 = {"abc", List.of()};
384         for (Object arg : args2) {
385             Assert.assertEquals(Interpreter.invoke(MethodHandles.lookup(), lmodel, arg), casePatternThrow(arg));
386         }
387     }
388 
389     @CodeReflection
390     private static String casePatternThrow(Object o) {
391         String r = "";
392         switch (o) {
393             case Number n -> throw new IllegalArgumentException();
394             case String s -> r += "a string";
395             default -> r += o.getClass().getName();
396         }
397         return r;
398     }
399 
400     // @@@ when multi patterns is supported, we will test it
401 
402     @Test
403     void testCasePatternWithCaseConstant() {
404         CoreOp.FuncOp lmodel = lower("casePatternWithCaseConstant");
405         int[] args = {42, 43, -44, 0};
406         for (int arg : args) {
407             Assert.assertEquals(Interpreter.invoke(MethodHandles.lookup(), lmodel, arg), casePatternWithCaseConstant(arg));
408         }
409     }
410 
411     @CodeReflection
412     static String casePatternWithCaseConstant(Integer a) {
413         String r = "";
414         switch (a) {
415             case 42 -> r += "forty two";
416             // @@@ case int will not match, because of the way InstanceOfOp is interpreted
417             case Integer i when i > 0 -> r += "positive int";
418             case Integer i when i < 0 -> r += "negative int";
419             default -> r += "zero";
420         }
421         return r;
422     }
423 
424     @Test
425     void testCaseTypePattern() {
426         CoreOp.FuncOp lmodel = lower("caseTypePattern");
427         Object[] args = {"str", new ArrayList<>(), new int[]{}, new Stack[][]{}, new Collection[][][]{}, 8, 'x'};
428         for (Object arg : args) {
429             Assert.assertEquals(Interpreter.invoke(MethodHandles.lookup(), lmodel, arg), caseTypePattern(arg));
430         }
431     }
432 
433     @CodeReflection
434     static String caseTypePattern(Object o) {
435         String r = "";
436         switch (o) {
437             case String _ -> r+= "String"; // class
438             case RandomAccess _ -> r+= "RandomAccess"; // interface
439             case int[] _ -> r+= "int[]"; // array primitive
440             case Stack[][] _ -> r+= "Stack[][]"; // array class
441             case Collection[][][] _ -> r+= "Collection[][][]"; // array interface
442             case final Number n -> r+= "Number"; // final modifier
443             default -> r+= "something else";
444         }
445         return r;
446     }
447 
448     @Test
449     void testCaseRecordPattern() {
450         // @@@ new R(null) must match the pattern R(Number c), but it doesn't
451         // @@@ test with generic record
452         CoreOp.FuncOp lmodel = lower("caseRecordPattern");
453         Object[] args = {new R(8), new R(1.0), new R(2L), "abc"};
454         for (Object arg : args) {
455             Assert.assertEquals(Interpreter.invoke(MethodHandles.lookup(), lmodel, arg), caseRecordPattern(arg));
456         }
457     }
458 
459     record R(Number n) {}
460     @CodeReflection
461     static String caseRecordPattern(Object o) {
462         String r = "";
463         switch (o) {
464             case R(Number n) -> r += "R(_)";
465             default -> r+= "else";
466         }
467         return r;
468     }
469 
470     @Test
471     void testCasePatternGuard() {
472         CoreOp.FuncOp lmodel = lower("casePatternGuard");
473         Object[] args = {"c++", "java", new R(8), new R(2L), new R(3f), new R(4.0)};
474         for (Object arg : args) {
475             Assert.assertEquals(Interpreter.invoke(MethodHandles.lookup(), lmodel, arg), casePatternGuard(arg));
476         }
477     }
478 
479     @CodeReflection
480     static String casePatternGuard(Object obj) {
481         String r = "";
482         switch (obj) {
483             case String s when s.length() > 3 -> r += "str with length > %d".formatted(s.length());
484             case R(Number n) when n.getClass().equals(Double.class) -> r += "R(Double)";
485             default -> r += "else";
486         }
487         return r;
488     }
489 
490     @Test
491     void testDefaultCaseNotTheLast() {
492         CoreOp.FuncOp lmodel = lower("defaultCaseNotTheLast");
493         String[] args = {"something", "M", "A"};
494         for (String arg : args) {
495             Assert.assertEquals(Interpreter.invoke(MethodHandles.lookup(), lmodel, arg), defaultCaseNotTheLast(arg));
496         }
497     }
498 
499     @CodeReflection
500     static String defaultCaseNotTheLast(String s) {
501         String r = "";
502         switch (s) {
503             default -> r += "else";
504             case "M" -> r += "Mow";
505             case "A" -> r += "Aow";
506         }
507         return r;
508     }
509 
510     private static CoreOp.FuncOp lower(String methodName) {
511         return lower(getCodeModel(methodName));
512     }
513 
514     private static CoreOp.FuncOp lower(CoreOp.FuncOp f) {
515         writeModel(f, System.out, OpWriter.LocationOption.DROP_LOCATION);
516 
517         CoreOp.FuncOp lf = f.transform(OpTransformer.LOWERING_TRANSFORMER);
518         writeModel(lf, System.out, OpWriter.LocationOption.DROP_LOCATION);
519 
520         return lf;
521     }
522 
523     private static void writeModel(CoreOp.FuncOp f, OutputStream os, OpWriter.Option... options) {
524         StringWriter sw = new StringWriter();
525         new OpWriter(sw, options).writeOp(f);
526         try {
527             os.write(sw.toString().getBytes());
528         } catch (IOException e) {
529             throw new RuntimeException(e);
530         }
531     }
532 
533     private static CoreOp.FuncOp getCodeModel(String methodName) {
534         Optional<Method> om = Stream.of(TestSwitchStatementOp.class.getDeclaredMethods())
535                 .filter(m -> m.getName().equals(methodName))
536                 .findFirst();
537 
538         return CoreOp.ofMethod(om.get()).get();
539     }
540 }