1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 import jdk.incubator.code.Reflect;
 25 import jdk.incubator.code.CodeTransformer;
 26 import jdk.incubator.code.Op;
 27 import jdk.incubator.code.analysis.SSA;
 28 import jdk.incubator.code.dialect.core.CoreOp;
 29 import jdk.incubator.code.interpreter.Interpreter;
 30 import org.junit.jupiter.api.Assertions;
 31 import org.junit.jupiter.api.Test;
 32 
 33 import java.lang.invoke.MethodHandles;
 34 import java.lang.reflect.Method;
 35 import java.util.Optional;
 36 import java.util.function.IntSupplier;
 37 import java.util.stream.Stream;
 38 
 39 /*
 40  * @test
 41  * @modules jdk.incubator.code
 42  * @run junit TestSSA
 43  * @run main Unreflect TestSSA
 44  * @run junit TestSSA
 45  * @run junit/othervm -Dbabylon.ssa=cytron TestSSA
 46  */
 47 
 48 public class TestSSA {
 49 
 50     @Reflect
 51     static int ifelse(int a, int b, int n) {
 52         if (n < 10) {
 53             a += 1;
 54         } else {
 55             b += 2;
 56         }
 57         return a + b;
 58     }
 59 
 60     @Test
 61     public void testIfelse() throws Throwable {
 62         CoreOp.FuncOp f = getFuncOp("ifelse");
 63 
 64         CoreOp.FuncOp lf = generate(f);
 65 
 66         Assertions.assertEquals(ifelse(0, 0, 1), (int) Interpreter.invoke(MethodHandles.lookup(), lf, 0, 0, 1));
 67         Assertions.assertEquals(ifelse(0, 0, 11), (int) Interpreter.invoke(MethodHandles.lookup(), lf, 0, 0, 11));
 68     }
 69 
 70     @Reflect
 71     static int ifelseNested(int a, int b, int c, int d, int n) {
 72         if (n < 20) {
 73             if (n < 10) {
 74                 a += 1;
 75             } else {
 76                 b += 2;
 77             }
 78             c += 3;
 79         } else {
 80             if (n > 20) {
 81                 a += 4;
 82             } else {
 83                 b += 5;
 84             }
 85             d += 6;
 86         }
 87         return a + b + c + d;
 88     }
 89 
 90     @Test
 91     public void testIfelseNested() throws Throwable {
 92         CoreOp.FuncOp f = getFuncOp("ifelseNested");
 93 
 94         CoreOp.FuncOp lf = generate(f);
 95 
 96         for (int i : new int[]{1, 11, 20, 21}) {
 97             Assertions.assertEquals(ifelseNested(0, 0, 0, 0, i), (int) Interpreter.invoke(MethodHandles.lookup(), lf, 0, 0, 0, 0, i));
 98         }
 99     }
100 
101     @Reflect
102     static int loop(int n) {
103         int sum = 0;
104         for (int i = 0; i < n; i++) {
105             sum = sum + i;
106         }
107         return sum;
108     }
109 
110     @Test
111     public void testLoop() throws Throwable {
112         CoreOp.FuncOp f = getFuncOp("loop");
113 
114         CoreOp.FuncOp lf = generate(f);
115 
116         Assertions.assertEquals(loop(10), (int) Interpreter.invoke(MethodHandles.lookup(), lf, 10));
117     }
118 
119     @Reflect
120     static int nestedLoop(int n) {
121         int sum = 0;
122         for (int i = 0; i < n; i++) {
123             for (int j = 0; j < n; j++) {
124                 sum = sum + i + j;
125             }
126         }
127         return sum;
128     }
129 
130     @Test
131     public void testNestedLoop() {
132         CoreOp.FuncOp f = getFuncOp("nestedLoop");
133 
134         CoreOp.FuncOp lf = generate(f);
135 
136         Assertions.assertEquals(nestedLoop(10), (int) Interpreter.invoke(MethodHandles.lookup(), lf, 10));
137     }
138 
139     @Reflect
140     static int nestedLambdaCapture(int i) {
141         IntSupplier s = () -> {
142             int j = i + 1;
143             IntSupplier s2 = () -> i + j;
144             return s2.getAsInt() + i;
145         };
146         return s.getAsInt();
147     }
148 
149     @Test
150     public void testNestedLambdaCapture() {
151         CoreOp.FuncOp f = getFuncOp("nestedLambdaCapture");
152 
153         CoreOp.FuncOp lf = generate(f);
154 
155         Assertions.assertEquals(nestedLambdaCapture(10), (int) Interpreter.invoke(MethodHandles.lookup(), lf, 10));
156     }
157 
158     @Reflect
159     static int deadCode(int n) {
160         int factorial = 1;
161         int unused = factorial;
162         int unusedLoop = 0;
163         for (int i = 0; i < 4; i++) {
164             unusedLoop++;
165         }
166         while (n > 0) {
167             factorial *= n;
168             n--;
169             if (factorial == 0) {
170                 factorial = -1;
171                 int unusedNested = factorial;
172             }
173         }
174         return factorial;
175     }
176 
177     @Test
178     public void testDeadCode() {
179         CoreOp.FuncOp f = getFuncOp("deadCode");
180 
181         CoreOp.FuncOp lf = generate(f);
182 
183         Assertions.assertEquals(deadCode(10), (int) Interpreter.invoke(MethodHandles.lookup(), lf, 10));
184     }
185 
186     @Reflect
187     static int ifelseLoopNested(int n) {
188         int counter = 10;
189         while (n > 0) {
190             if (n % 2 == 0) {
191                 int sum = n;
192                 for (int i = 0; i < 5; i++) {
193                     if (sum > n / 2) {
194                         sum -= i;
195                     } else {
196                         sum += i;
197                         break;
198                     }
199                 }
200                 n += sum;
201             } else {
202                 int difference = (n % 3 == 0) ? n / 2 : n * 2;
203                 n -= difference;
204             }
205             counter--;
206         }
207         return n;
208     }
209 
210     @Test
211     public void testIfelseLoopNested() {
212         CoreOp.FuncOp f = getFuncOp("ifelseLoopNested");
213 
214         CoreOp.FuncOp lf = generate(f);
215 
216         Assertions.assertEquals(ifelseLoopNested(10), (int) Interpreter.invoke(MethodHandles.lookup(), lf, 10));
217     }
218 
219     @Reflect
220     static int violaJones(int x, int maxX, int length, int integral) {
221         int scale = 0;
222         scale++;
223         while (x > scale && scale < length) {
224         }
225         for (int i = 0; i < integral; i++) {
226             scale--;
227         }
228         return scale;
229     }
230 
231     @Test
232     public void testViolaJones() {
233         CoreOp.FuncOp f = getFuncOp("violaJones");
234 
235         CoreOp.FuncOp lf = generate(f);
236 
237         Assertions.assertEquals(violaJones(0, 1, 0, 0), (int) Interpreter.invoke(MethodHandles.lookup(), lf, 0, 1, 0, 0));
238     }
239 
240     @Reflect
241     static int violaJonesTwo(int x, int maxX, int length, int integral) {
242         int scale = 0, scale_extra = 1;
243         scale++;
244         int j = 9;
245         while (x > scale && scale < length) {
246             j = 1;
247         }
248         for (int i = 0; i < integral; i++) {
249             scale--;
250         }
251         return scale + scale_extra + j;
252     }
253 
254     @Test
255     public void testViolaJonesTwo() {
256         CoreOp.FuncOp f = getFuncOp("violaJonesTwo");
257 
258         CoreOp.FuncOp lf = generate(f);
259 
260         Assertions.assertEquals(violaJonesTwo(0, 1, 0, 0), (int) Interpreter.invoke(MethodHandles.lookup(), lf, 0, 1, 0, 0));
261     }
262 
263     @Reflect
264     static boolean binarySearch(int[] arr, int target) {
265         int l = 0;
266         int r = arr.length - 1;
267         while (l < r) {
268             int m = (r - l) / 2;
269             m += l;
270             if (arr[m] < target) {
271                 l = m + 1;
272             } else if (arr[m] > target) {
273                 r = m - 1;
274             } else {
275                 return true;
276             }
277         }
278         return false;
279     }
280 
281     @Test
282     public void testBinarySearch() {
283         CoreOp.FuncOp f = getFuncOp("binarySearch");
284 
285         CoreOp.FuncOp lf = generate(f);
286 
287         int[] arr = new int[]{1, 2, 4, 7, 11, 19, 21, 29, 30, 36};
288 
289         Assertions.assertEquals(binarySearch(arr, 4), (boolean) Interpreter.invoke(MethodHandles.lookup(), lf, arr, 4));
290     }
291 
292     @Reflect
293     static void quicksort(int[] arr, int lo, int hi) {
294         if (lo >= hi || lo < 0) {
295             return;
296         }
297 
298         int pivot = arr[hi];
299         int i = lo;
300         for (int j = lo; j < hi; j++) {
301             if (arr[j] <= pivot) {
302                 int temp = arr[i];
303                 arr[i] = arr[j];
304                 arr[j] = temp;
305                 i++;
306             }
307         }
308         int temp = arr[i];
309         arr[i] = arr[hi];
310         arr[hi] = temp;
311 
312         quicksort(arr, lo, i - 1);
313         quicksort(arr, i + 1, hi);
314     }
315 
316     @Test
317     public void testQuicksort() {
318         CoreOp.FuncOp f = getFuncOp("quicksort");
319 
320         CoreOp.FuncOp lf = generate(f);
321 
322         int[] arr1 = new int[]{5, 2, 7, 45, 34, 14, 0, 27, 43, 11, 38, 56, 81};
323         int[] arr2 = new int[]{2, 11, 45, 34, 0, 27, 38, 56, 7, 43, 14, 5, 81};
324 
325         Interpreter.invoke(MethodHandles.lookup(), lf, arr1, 0, arr1.length - 1);
326         quicksort(arr2, 0, arr2.length - 1);
327         Assertions.assertArrayEquals(arr2, arr1);
328     }
329 
330     static CoreOp.FuncOp generate(CoreOp.FuncOp f) {
331         System.out.println(f.toText());
332 
333         CoreOp.FuncOp lf = f.transform(CodeTransformer.LOWERING_TRANSFORMER);
334         System.out.println(lf.toText());
335 
336         lf = SSA.transform(lf);
337         System.out.println(lf.toText());
338         return lf;
339     }
340 
341     static CoreOp.FuncOp getFuncOp(String name) {
342         Optional<Method> om = Stream.of(TestSSA.class.getDeclaredMethods())
343                 .filter(m -> m.getName().equals(name))
344                 .findFirst();
345 
346         Method m = om.get();
347         return Op.ofMethod(m).get();
348     }
349 }