1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 import org.testng.Assert;
 25 import org.testng.annotations.Test;
 26 
 27 import java.lang.invoke.MethodHandles;
 28 import jdk.incubator.code.OpTransformer;
 29 import jdk.incubator.code.dialect.core.CoreOp;
 30 import jdk.incubator.code.Op;
 31 import jdk.incubator.code.analysis.SSA;
 32 import jdk.incubator.code.interpreter.Interpreter;
 33 import java.lang.reflect.Method;
 34 import jdk.incubator.code.CodeReflection;
 35 import java.util.Optional;
 36 import java.util.function.IntSupplier;
 37 import java.util.stream.Stream;
 38 
 39 /*
 40  * @test
 41  * @modules jdk.incubator.code
 42  * @run testng TestSSA
 43  * @run testng/othervm -Dbabylon.ssa=cytron TestSSA
 44  */
 45 
 46 public class TestSSA {
 47 
 48     @CodeReflection
 49     static int ifelse(int a, int b, int n) {
 50         if (n < 10) {
 51             a += 1;
 52         } else {
 53             b += 2;
 54         }
 55         return a + b;
 56     }
 57 
 58     @Test
 59     public void testIfelse() throws Throwable {
 60         CoreOp.FuncOp f = getFuncOp("ifelse");
 61 
 62         CoreOp.FuncOp lf = generate(f);
 63 
 64         Assert.assertEquals((int) Interpreter.invoke(MethodHandles.lookup(), lf, 0, 0, 1), ifelse(0, 0, 1));
 65         Assert.assertEquals((int) Interpreter.invoke(MethodHandles.lookup(), lf, 0, 0, 11), ifelse(0, 0, 11));
 66     }
 67 
 68     @CodeReflection
 69     static int ifelseNested(int a, int b, int c, int d, int n) {
 70         if (n < 20) {
 71             if (n < 10) {
 72                 a += 1;
 73             } else {
 74                 b += 2;
 75             }
 76             c += 3;
 77         } else {
 78             if (n > 20) {
 79                 a += 4;
 80             } else {
 81                 b += 5;
 82             }
 83             d += 6;
 84         }
 85         return a + b + c + d;
 86     }
 87 
 88     @Test
 89     public void testIfelseNested() throws Throwable {
 90         CoreOp.FuncOp f = getFuncOp("ifelseNested");
 91 
 92         CoreOp.FuncOp lf = generate(f);
 93 
 94         for (int i : new int[]{1, 11, 20, 21}) {
 95             Assert.assertEquals((int) Interpreter.invoke(MethodHandles.lookup(), lf, 0, 0, 0, 0, i), ifelseNested(0, 0, 0, 0, i));
 96         }
 97     }
 98 
 99     @CodeReflection
100     static int loop(int n) {
101         int sum = 0;
102         for (int i = 0; i < n; i++) {
103             sum = sum + i;
104         }
105         return sum;
106     }
107 
108     @Test
109     public void testLoop() throws Throwable {
110         CoreOp.FuncOp f = getFuncOp("loop");
111 
112         CoreOp.FuncOp lf = generate(f);
113 
114         Assert.assertEquals((int) Interpreter.invoke(MethodHandles.lookup(), lf, 10), loop(10));
115     }
116 
117     @CodeReflection
118     static int nestedLoop(int n) {
119         int sum = 0;
120         for (int i = 0; i < n; i++) {
121             for (int j = 0; j < n; j++) {
122                 sum = sum + i + j;
123             }
124         }
125         return sum;
126     }
127 
128     @Test
129     public void testNestedLoop() {
130         CoreOp.FuncOp f = getFuncOp("nestedLoop");
131 
132         CoreOp.FuncOp lf = generate(f);
133 
134         Assert.assertEquals((int) Interpreter.invoke(MethodHandles.lookup(), lf, 10), nestedLoop(10));
135     }
136 
137     @CodeReflection
138     static int nestedLambdaCapture(int i) {
139         IntSupplier s = () -> {
140             int j = i + 1;
141             IntSupplier s2 = () -> i + j;
142             return s2.getAsInt() + i;
143         };
144         return s.getAsInt();
145     }
146 
147     @Test
148     public void testNestedLambdaCapture() {
149         CoreOp.FuncOp f = getFuncOp("nestedLambdaCapture");
150 
151         CoreOp.FuncOp lf = generate(f);
152 
153         Assert.assertEquals((int) Interpreter.invoke(MethodHandles.lookup(), lf, 10), nestedLambdaCapture(10));
154     }
155 
156     @CodeReflection
157     static int deadCode(int n) {
158         int factorial = 1;
159         int unused = factorial;
160         int unusedLoop = 0;
161         for (int i = 0; i < 4; i++) {
162             unusedLoop++;
163         }
164         while (n > 0) {
165             factorial *= n;
166             n--;
167             if (factorial == 0) {
168                 factorial = -1;
169                 int unusedNested = factorial;
170             }
171         }
172         return factorial;
173     }
174 
175     @Test
176     public void testDeadCode() {
177         CoreOp.FuncOp f = getFuncOp("deadCode");
178 
179         CoreOp.FuncOp lf = generate(f);
180 
181         Assert.assertEquals((int) Interpreter.invoke(MethodHandles.lookup(), lf, 10), deadCode(10));
182     }
183 
184     @CodeReflection
185     static int ifelseLoopNested(int n) {
186         int counter = 10;
187         while (n > 0) {
188             if (n % 2 == 0) {
189                 int sum = n;
190                 for (int i = 0; i < 5; i++) {
191                     if (sum > n / 2) {
192                         sum -= i;
193                     } else {
194                         sum += i;
195                         break;
196                     }
197                 }
198                 n += sum;
199             } else {
200                 int difference = (n % 3 == 0) ? n / 2 : n * 2;
201                 n -= difference;
202             }
203             counter--;
204         }
205         return n;
206     }
207 
208     @Test
209     public void testIfelseLoopNested() {
210         CoreOp.FuncOp f = getFuncOp("ifelseLoopNested");
211 
212         CoreOp.FuncOp lf = generate(f);
213 
214         Assert.assertEquals((int) Interpreter.invoke(MethodHandles.lookup(), lf, 10), ifelseLoopNested(10));
215     }
216 
217     @CodeReflection
218     static int violaJones(int x, int maxX, int length, int integral) {
219         int scale = 0;
220         scale++;
221         while (x > scale && scale < length) {
222         }
223         for (int i = 0; i < integral; i++) {
224             scale--;
225         }
226         return scale;
227     }
228 
229     @Test
230     public void testViolaJones() {
231         CoreOp.FuncOp f = getFuncOp("violaJones");
232 
233         CoreOp.FuncOp lf = generate(f);
234 
235         Assert.assertEquals((int) Interpreter.invoke(MethodHandles.lookup(), lf, 0, 1, 0, 0), violaJones(0, 1, 0, 0));
236     }
237 
238     @CodeReflection
239     static int violaJonesTwo(int x, int maxX, int length, int integral) {
240         int scale = 0, scale_extra = 1;
241         scale++;
242         int j = 9;
243         while (x > scale && scale < length) {
244             j = 1;
245         }
246         for (int i = 0; i < integral; i++) {
247             scale--;
248         }
249         return scale + scale_extra + j;
250     }
251 
252     @Test
253     public void testViolaJonesTwo() {
254         CoreOp.FuncOp f = getFuncOp("violaJonesTwo");
255 
256         CoreOp.FuncOp lf = generate(f);
257 
258         Assert.assertEquals((int) Interpreter.invoke(MethodHandles.lookup(), lf, 0, 1, 0, 0), violaJonesTwo(0, 1, 0, 0));
259     }
260 
261     @CodeReflection
262     static boolean binarySearch(int[] arr, int target) {
263         int l = 0;
264         int r = arr.length - 1;
265         while (l < r) {
266             int m = (r - l) / 2;
267             m += l;
268             if (arr[m] < target) {
269                 l = m + 1;
270             } else if (arr[m] > target) {
271                 r = m - 1;
272             } else {
273                 return true;
274             }
275         }
276         return false;
277     }
278 
279     @Test
280     public void testBinarySearch() {
281         CoreOp.FuncOp f = getFuncOp("binarySearch");
282 
283         CoreOp.FuncOp lf = generate(f);
284 
285         int[] arr = new int[]{1, 2, 4, 7, 11, 19, 21, 29, 30, 36};
286 
287         Assert.assertEquals((boolean) Interpreter.invoke(MethodHandles.lookup(), lf, arr, 4), binarySearch(arr, 4));
288     }
289 
290     @CodeReflection
291     static void quicksort(int[] arr, int lo, int hi) {
292         if (lo >= hi || lo < 0) {
293             return;
294         }
295 
296         int pivot = arr[hi];
297         int i = lo;
298         for (int j = lo; j < hi; j++) {
299             if (arr[j] <= pivot) {
300                 int temp = arr[i];
301                 arr[i] = arr[j];
302                 arr[j] = temp;
303                 i++;
304             }
305         }
306         int temp = arr[i];
307         arr[i] = arr[hi];
308         arr[hi] = temp;
309 
310         quicksort(arr, lo, i - 1);
311         quicksort(arr, i + 1, hi);
312     }
313 
314     @Test
315     public void testQuicksort() {
316         CoreOp.FuncOp f = getFuncOp("quicksort");
317 
318         CoreOp.FuncOp lf = generate(f);
319 
320         int[] arr1 = new int[]{5, 2, 7, 45, 34, 14, 0, 27, 43, 11, 38, 56, 81};
321         int[] arr2 = new int[]{2, 11, 45, 34, 0, 27, 38, 56, 7, 43, 14, 5, 81};
322 
323         Interpreter.invoke(MethodHandles.lookup(), lf, arr1, 0, arr1.length - 1);
324         quicksort(arr2, 0, arr2.length - 1);
325         Assert.assertEquals(arr1, arr2);
326     }
327 
328     static CoreOp.FuncOp generate(CoreOp.FuncOp f) {
329         System.out.println(f.toText());
330 
331         CoreOp.FuncOp lf = f.transform(OpTransformer.LOWERING_TRANSFORMER);
332         System.out.println(lf.toText());
333 
334         lf = SSA.transform(lf);
335         System.out.println(lf.toText());
336         return lf;
337     }
338 
339     static CoreOp.FuncOp getFuncOp(String name) {
340         Optional<Method> om = Stream.of(TestSSA.class.getDeclaredMethods())
341                 .filter(m -> m.getName().equals(name))
342                 .findFirst();
343 
344         Method m = om.get();
345         return Op.ofMethod(m).get();
346     }
347 }