1 /*
  2  * Copyright (c) 2024, 2026, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 import jdk.incubator.code.Reflect;
 25 import jdk.incubator.code.CodeTransformer;
 26 import jdk.incubator.code.dialect.core.CoreOp;
 27 import jdk.incubator.code.extern.OpWriter;
 28 import org.junit.jupiter.api.Assertions;
 29 import org.junit.jupiter.api.Test;
 30 
 31 import java.io.IOException;
 32 import java.io.OutputStream;
 33 import java.io.StringWriter;
 34 import java.lang.invoke.MethodHandles;
 35 import java.lang.reflect.Method;
 36 import java.util.*;
 37 import java.util.stream.Stream;
 38 
 39 /*
 40  * @test
 41  * @modules jdk.incubator.code
 42  * @library lib
 43  * @run junit TestSwitchStatementOp
 44  * @run main Unreflect TestSwitchStatementOp
 45  * @run junit TestSwitchStatementOp
 46  *
 47  */
 48 public class TestSwitchStatementOp {
 49 
 50     @Test
 51     void testCaseConstantBehaviorIsSyntaxIndependent() {
 52         CoreOp.FuncOp ruleExpression = lower("caseConstantRuleExpression");
 53         CoreOp.FuncOp ruleBlock = lower("caseConstantRuleBlock");
 54         CoreOp.FuncOp statement = lower("caseConstantStatement");
 55 
 56         String[] args = {"FOO", "BAR", "BAZ", "OTHER"};
 57 
 58         for (String arg : args) {
 59             Assertions.assertEquals(Interpreter.invoke(MethodHandles.lookup(), ruleBlock, arg), Interpreter.invoke(MethodHandles.lookup(), ruleExpression, arg));
 60             Assertions.assertEquals(Interpreter.invoke(MethodHandles.lookup(), statement, arg), Interpreter.invoke(MethodHandles.lookup(), ruleExpression, arg));
 61         }
 62     }
 63 
 64     @Reflect
 65     public static String caseConstantRuleExpression(String r) {
 66         String s = "";
 67         switch (r) {
 68             case "FOO" -> s += "BAR";
 69             case "BAR" -> s += "BAZ";
 70             case "BAZ" -> s += "FOO";
 71             default -> s += "else";
 72         }
 73         return s;
 74     }
 75 
 76     @Reflect
 77     public static String caseConstantRuleBlock(String r) {
 78         String s = "";
 79         switch (r) {
 80             case "FOO" -> {
 81                 s += "BAR";
 82             }
 83             case "BAR" -> {
 84                 s += "BAZ";
 85             }
 86             case "BAZ" -> {
 87                 s += "FOO";
 88             }
 89             default -> {
 90                 s += "else";
 91             }
 92         }
 93         return s;
 94     }
 95 
 96     @Reflect
 97     private static String caseConstantStatement(String s) {
 98         String r = "";
 99         switch (s) {
100             case "FOO":
101                 r += "BAR";
102                 break;
103             case "BAR":
104                 r += "BAZ";
105                 break;
106             case "BAZ":
107                 r += "FOO";
108                 break;
109             default:
110                 r += "else";
111         }
112         return r;
113     }
114 
115     @Test
116     void testCaseConstantMultiLabels() {
117         CoreOp.FuncOp lmodel = lower("caseConstantMultiLabels");
118         char[] args = {'a', 'e', 'i', 'o', 'u', 'j', 'p', 'g'};
119         for (char arg : args) {
120             Assertions.assertEquals(caseConstantMultiLabels(arg), Interpreter.invoke(MethodHandles.lookup(), lmodel, arg));
121         }
122     }
123 
124     @Reflect
125     private static String caseConstantMultiLabels(char c) {
126         String r = "";
127         switch (Character.toLowerCase(c)) {
128             case 'a', 'e', 'i', 'o', 'u':
129                 r += "vowel";
130                 break;
131             default:
132                 r += "consonant";
133         }
134         return r;
135     }
136 
137     @Test
138     void testCaseConstantThrow() {
139         CoreOp.FuncOp lmodel = lower("caseConstantThrow");
140 
141         Assertions.assertThrows(IllegalArgumentException.class, () -> Interpreter.invoke(MethodHandles.lookup(), lmodel, 8));
142 
143         int[] args = {9, 10};
144         for (int arg : args) {
145             Assertions.assertEquals(caseConstantThrow(arg), Interpreter.invoke(MethodHandles.lookup(), lmodel, arg));
146         }
147     }
148 
149     @Reflect
150     private static String caseConstantThrow(Integer i) {
151         String r = "";
152         switch (i) {
153             case 8 -> throw new IllegalArgumentException();
154             case 9 -> r += "Nine";
155             default -> r += "An integer";
156         }
157         return r;
158     }
159 
160     @Test
161     void testCaseConstantNullLabel() {
162         CoreOp.FuncOp lmodel = lower("caseConstantNullLabel");
163         String[] args = {null, "non null"};
164         for (String arg : args) {
165             Assertions.assertEquals(caseConstantNullLabel(arg), Interpreter.invoke(MethodHandles.lookup(), lmodel, arg));
166         }
167     }
168 
169     @Reflect
170     private static String caseConstantNullLabel(String s) {
171         String r = "";
172         switch (s) {
173             case null -> r += "null";
174             default -> r += "non null";
175         }
176         return r;
177     }
178 
179     @Test
180     void testCaseConstantFallThrough() {
181         CoreOp.FuncOp lmodel = lower("caseConstantFallThrough");
182         char[] args = {'A', 'B', 'C'};
183         for (char arg : args) {
184             Assertions.assertEquals(caseConstantFallThrough(arg), Interpreter.invoke(MethodHandles.lookup(), lmodel, arg));
185         }
186     }
187 
188     @Reflect
189     private static String caseConstantFallThrough(char c) {
190         String r = "";
191         switch (c) {
192             case 'A':
193             case 'B':
194                 r += "A or B";
195                 break;
196             default:
197                 r += "Neither A nor B";
198         }
199         return r;
200     }
201 
202     @Test
203     void testCaseConstantEnum() {
204         CoreOp.FuncOp lmodel = lower("caseConstantEnum");
205         for (Day day : Day.values()) {
206             Assertions.assertEquals(caseConstantEnum(day), Interpreter.invoke(MethodHandles.lookup(), lmodel, day));
207         }
208     }
209 
210     enum Day {
211         MON, TUE, WED, THU, FRI, SAT, SUN
212     }
213     @Reflect
214     private static String caseConstantEnum(Day d) {
215         String r = "";
216         switch (d) {
217             case MON, FRI, SUN -> r += 6;
218             case TUE -> r += 7;
219             case THU, SAT -> r += 8;
220             case WED -> r += 9;
221         }
222         return r;
223     }
224 
225     @Test
226     void testCaseConstantOtherKindsOfExpr() {
227         CoreOp.FuncOp lmodel = lower("caseConstantOtherKindsOfExpr");
228         for (int i = 0; i < 14; i++) {
229             Assertions.assertEquals(caseConstantOtherKindsOfExpr(i), Interpreter.invoke(MethodHandles.lookup(), lmodel, i));
230         }
231     }
232 
233     static class Constants {
234         static final int c1 = 12;
235     }
236     @Reflect
237     private static String caseConstantOtherKindsOfExpr(int i) {
238         String r = "";
239         final int eleven = 11;
240         switch (i) {
241             case 1 & 0xF -> r += 1;
242             case 4>>1 -> r += "2";
243             case (int) 3L -> r += 3;
244             case 2<<1 -> r += 4;
245             case 10 / 2 -> r += 5;
246             case 12 - 6 -> r += 6;
247             case 3 + 4 -> r += 7;
248             case 2 * 2 * 2 -> r += 8;
249             case 8 | 1 -> r += 9;
250             case (10) -> r += 10;
251             case eleven -> r += 11;
252             case Constants.c1 -> r += Constants.c1;
253             case 1 > 0 ? 13 : 133 -> r += 13;
254             default -> r += "an int";
255         }
256         return r;
257     }
258 
259     @Test
260     void testCaseConstantConv() {
261         CoreOp.FuncOp lmodel = lower("caseConstantConv");
262         for (short i = 1; i < 5; i++) {
263             Assertions.assertEquals(caseConstantConv(i), Interpreter.invoke(MethodHandles.lookup(), lmodel, i));
264         }
265     }
266 
267     @Reflect
268     static String caseConstantConv(short a) {
269         final short s = 1;
270         final byte b = 2;
271         String r = "";
272         switch (a) {
273             case s -> r += "one"; // identity, short -> short
274             case b -> r += "two"; // widening primitive conversion, byte -> short
275             case 3 -> r += "three"; // narrowing primitive conversion, int -> short
276             default -> r += "else";
277         }
278         return r;
279     }
280 
281     @Test
282     void testCaseConstantConv2() {
283         CoreOp.FuncOp lmodel = lower("caseConstantConv2");
284         Byte[] args = {1, 2, 3};
285         for (Byte arg : args) {
286             Assertions.assertEquals(caseConstantConv2(arg), Interpreter.invoke(MethodHandles.lookup(), lmodel, arg));
287         }
288     }
289 
290     @Reflect
291     static String caseConstantConv2(Byte a) {
292         final byte b = 2;
293         String r = "";
294         switch (a) {
295             case 1 -> r+= "one"; // narrowing primitive conversion followed by a boxing conversion, int -> bye -> Byte
296             case b -> r+= "two"; // boxing, byte -> Byte
297             default -> r+= "default";
298         }
299         return r;
300     }
301 
302     @Test
303     void testNonEnhancedSwStatNoDefault() {
304         CoreOp.FuncOp lmodel = lower("nonEnhancedSwStatNoDefault");
305         for (int i = 1; i < 4; i++) {
306             Assertions.assertEquals(nonEnhancedSwStatNoDefault(i), Interpreter.invoke(MethodHandles.lookup(), lmodel, i));
307         }
308     }
309 
310     @Reflect
311     static String nonEnhancedSwStatNoDefault(int a) {
312         String r = "";
313         switch (a) {
314             case 1 -> r += "1";
315             case 2 -> r += 2;
316         }
317         return r;
318     }
319 
320     // no reason to test enhanced switch statement that has no default
321     // because we can't test for MatchException without separate compilation
322 
323     @Test
324     void testEnhancedSwStatUnconditionalPattern() {
325         CoreOp.FuncOp lmodel = lower("enhancedSwStatUnconditionalPattern");
326         String[] args = {"A", "B"};
327         for (String arg : args) {
328             Assertions.assertEquals(enhancedSwStatUnconditionalPattern(arg), Interpreter.invoke(MethodHandles.lookup(), lmodel, arg));
329         }
330     }
331 
332     @Reflect
333     static String enhancedSwStatUnconditionalPattern(String s) {
334         String r = "";
335         switch (s) {
336             case "A" -> r += "A";
337             case Object o -> r += "obj";
338         }
339         return r;
340     }
341 
342     @Test
343     void testCasePatternBehaviorIsSyntaxIndependent() {
344         CoreOp.FuncOp ruleExpression = lower("casePatternRuleExpression");
345         CoreOp.FuncOp ruleBlock = lower("casePatternRuleBlock");
346         CoreOp.FuncOp statement = lower("casePatternStatement");
347 
348         Object[] args = {1, "2", 3L};
349 
350         for (Object arg : args) {
351             Assertions.assertEquals(Interpreter.invoke(MethodHandles.lookup(), ruleBlock, arg), Interpreter.invoke(MethodHandles.lookup(), ruleExpression, arg));
352             Assertions.assertEquals(Interpreter.invoke(MethodHandles.lookup(), statement, arg), Interpreter.invoke(MethodHandles.lookup(), ruleExpression, arg));
353         }
354     }
355 
356     @Reflect
357     private static String casePatternRuleExpression(Object o) {
358         String r = "";
359         switch (o) {
360             case Integer i -> r += "integer";
361             case String s -> r+= "string";
362             default -> r+= "else";
363         }
364         return r;
365     }
366 
367     @Reflect
368     private static String casePatternRuleBlock(Object o) {
369         String r = "";
370         switch (o) {
371             case Integer i -> {
372                 r += "integer";
373             }
374             case String s -> {
375                 r += "string";
376             }
377             default -> {
378                 r += "else";
379             }
380         }
381         return r;
382     }
383 
384     @Reflect
385     private static String casePatternStatement(Object o) {
386         String r = "";
387         switch (o) {
388             case Integer i:
389                 r += "integer";
390                 break;
391             case String s:
392                 r += "string";
393                 break;
394             default:
395                 r += "else";
396         }
397         return r;
398     }
399 
400     @Test
401     void testCasePatternThrow() {
402         CoreOp.FuncOp lmodel = lower("casePatternThrow");
403 
404         Object[] args = {Byte.MAX_VALUE, Short.MIN_VALUE, 0, 1L, 11f, 22d};
405         for (Object arg : args) {
406             Assertions.assertThrows(IllegalArgumentException.class, () -> Interpreter.invoke(MethodHandles.lookup(), lmodel, arg));
407         }
408 
409         Object[] args2 = {"abc", List.of()};
410         for (Object arg : args2) {
411             Assertions.assertEquals(casePatternThrow(arg), Interpreter.invoke(MethodHandles.lookup(), lmodel, arg));
412         }
413     }
414 
415     @Reflect
416     private static String casePatternThrow(Object o) {
417         String r = "";
418         switch (o) {
419             case Number n -> throw new IllegalArgumentException();
420             case String s -> r += "a string";
421             default -> r += o.getClass().getName();
422         }
423         return r;
424     }
425 
426     // @@@ when multi patterns is supported, we will test it
427 
428     @Test
429     void testCasePatternWithCaseConstant() {
430         CoreOp.FuncOp lmodel = lower("casePatternWithCaseConstant");
431         int[] args = {42, 43, -44, 0};
432         for (int arg : args) {
433             Assertions.assertEquals(casePatternWithCaseConstant(arg), Interpreter.invoke(MethodHandles.lookup(), lmodel, arg));
434         }
435     }
436 
437     @Reflect
438     static String casePatternWithCaseConstant(Integer a) {
439         String r = "";
440         switch (a) {
441             case 42 -> r += "forty two";
442             // @@@ case int will not match, because of the way InstanceOfOp is interpreted
443             case Integer i when i > 0 -> r += "positive int";
444             case Integer i when i < 0 -> r += "negative int";
445             default -> r += "zero";
446         }
447         return r;
448     }
449 
450     @Test
451     void testCaseTypePattern() {
452         CoreOp.FuncOp lmodel = lower("caseTypePattern");
453         Object[] args = {"str", new ArrayList<>(), new int[]{}, new Stack[][]{}, new Collection[][][]{}, 8, 'x'};
454         for (Object arg : args) {
455             Assertions.assertEquals(caseTypePattern(arg), Interpreter.invoke(MethodHandles.lookup(), lmodel, arg));
456         }
457     }
458 
459     @Reflect
460     static String caseTypePattern(Object o) {
461         String r = "";
462         switch (o) {
463             case String _ -> r+= "String"; // class
464             case RandomAccess _ -> r+= "RandomAccess"; // interface
465             case int[] _ -> r+= "int[]"; // array primitive
466             case Stack[][] _ -> r+= "Stack[][]"; // array class
467             case Collection[][][] _ -> r+= "Collection[][][]"; // array interface
468             case final Number n -> r+= "Number"; // final modifier
469             default -> r+= "something else";
470         }
471         return r;
472     }
473 
474     @Test
475     void testCaseRecordPattern() {
476         // @@@ new R(null) must match the pattern R(Number c), but it doesn't
477         // @@@ test with generic record
478         CoreOp.FuncOp lmodel = lower("caseRecordPattern");
479         Object[] args = {new R(8), new R(1.0), new R(2L), "abc"};
480         for (Object arg : args) {
481             Assertions.assertEquals(caseRecordPattern(arg), Interpreter.invoke(MethodHandles.lookup(), lmodel, arg));
482         }
483     }
484 
485     record R(Number n) {}
486     @Reflect
487     static String caseRecordPattern(Object o) {
488         String r = "";
489         switch (o) {
490             case R(Number n) -> r += "R(_)";
491             default -> r+= "else";
492         }
493         return r;
494     }
495 
496     @Test
497     void testCasePatternGuard() {
498         CoreOp.FuncOp lmodel = lower("casePatternGuard");
499         Object[] args = {"c++", "java", new R(8), new R(2L), new R(3f), new R(4.0)};
500         for (Object arg : args) {
501             Assertions.assertEquals(casePatternGuard(arg), Interpreter.invoke(MethodHandles.lookup(), lmodel, arg));
502         }
503     }
504 
505     @Reflect
506     static String casePatternGuard(Object obj) {
507         String r = "";
508         switch (obj) {
509             case String s when s.length() > 3 -> r += "str with length > %d".formatted(s.length());
510             case R(Number n) when n.getClass().equals(Double.class) -> r += "R(Double)";
511             default -> r += "else";
512         }
513         return r;
514     }
515 
516     @Test
517     void testDefaultCaseNotTheLast() {
518         CoreOp.FuncOp lmodel = lower("defaultCaseNotTheLast");
519         String[] args = {"something", "M", "A"};
520         for (String arg : args) {
521             Assertions.assertEquals(defaultCaseNotTheLast(arg), Interpreter.invoke(MethodHandles.lookup(), lmodel, arg));
522         }
523     }
524 
525     @Reflect
526     static String defaultCaseNotTheLast(String s) {
527         String r = "";
528         switch (s) {
529             default -> r += "else";
530             case "M" -> r += "Mow";
531             case "A" -> r += "Aow";
532         }
533         return r;
534     }
535 
536     @Reflect
537     static String caseConstantPrimitiveWrapperSelector(Integer i) {
538         String r = "";
539         switch (i) {
540             case 1 -> r += "one";
541             case 2, 3 -> r += "two or three";
542             default -> r += "else";
543         };
544         return r;
545     }
546 
547     @Test
548     void testCaseConstantPrimitiveWrapperSelector() {
549         CoreOp.FuncOp lf = lower("caseConstantPrimitiveWrapperSelector");
550         Integer[] args = {1, 2, 3, 4};
551         for (Integer a : args) {
552             Assertions.assertEquals(caseConstantPrimitiveWrapperSelector(a),
553                     Interpreter.invoke(MethodHandles.lookup(), lf, a));
554         }
555     }
556 
557     @Reflect
558     static String constantLabelCasted(int i) {
559         String r = "";
560         switch (i) {
561             case (byte) 1 -> r += "one";
562             default -> r += "not one";
563         };
564         return r;
565     }
566 
567     @Test
568     void testConstantLabelCasted() {
569         CoreOp.FuncOp lf = lower("constantLabelCasted");
570         int[] args = {-1, 1};
571         for (int a : args) {
572             Assertions.assertEquals(constantLabelCasted(a), Interpreter.invoke(MethodHandles.lookup(), lf, a));
573         }
574     }
575 
576     @Reflect
577     static String caseConstantStringLiteral(String s) {
578         String r = "";
579         switch (s) {
580             case "1" -> r += "one";
581             case "2", "3" -> r+= "two or three";
582             default -> r += "else";
583         };
584         return r;
585     }
586 
587     @Test
588     void testCeaseConstantStringLiteral() {
589         CoreOp.FuncOp lf = lower("caseConstantStringLiteral");
590         String[] args = {"1", "2", "3", ""};
591         for (String a : args) {
592             Assertions.assertEquals(caseConstantStringLiteral(a), Interpreter.invoke(MethodHandles.lookup(), lf, a));
593         }
594     }
595 
596     @Test
597     void testTryAndSwitch() {
598         CoreOp.FuncOp lmodel = lower("tryAndSwitch");
599         String[] args = {"A", "B"};
600         for (String arg : args) {
601             Assertions.assertEquals(tryAndSwitch(arg), Interpreter.invoke(MethodHandles.lookup(), lmodel, arg));
602         }
603     }
604 
605     @Reflect
606     private static List<String> tryAndSwitch(String s) {
607         List<String> r = new ArrayList<>();
608         try {
609             switch (s) {
610                 case "A":
611                     r.add("A");
612                     return r;
613                 default:
614                     r.add("B");
615             }
616         } finally {
617             r.add("finally");
618         }
619         return r;
620     }
621 
622     private static CoreOp.FuncOp lower(String methodName) {
623         return lower(getCodeModel(methodName));
624     }
625 
626     private static CoreOp.FuncOp lower(CoreOp.FuncOp f) {
627         writeModel(f, System.out, OpWriter.LocationOption.DROP_LOCATION);
628 
629         CoreOp.FuncOp lf = f.transform(CodeTransformer.LOWERING_TRANSFORMER);
630         writeModel(lf, System.out, OpWriter.LocationOption.DROP_LOCATION);
631 
632         return lf;
633     }
634 
635     private static void writeModel(CoreOp.FuncOp f, OutputStream os, OpWriter.Option... options) {
636         StringWriter sw = new StringWriter();
637         new OpWriter(sw, options).writeOp(f);
638         try {
639             os.write(sw.toString().getBytes());
640         } catch (IOException e) {
641             throw new RuntimeException(e);
642         }
643     }
644 
645     private static CoreOp.FuncOp getCodeModel(String methodName) {
646         Optional<Method> om = Stream.of(TestSwitchStatementOp.class.getDeclaredMethods())
647                 .filter(m -> m.getName().equals(methodName))
648                 .findFirst();
649 
650         return CoreOp.ofMethod(om.get()).get();
651     }
652 }