1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 /*
 25  * @test
 26  * @modules jdk.incubator.code
 27  * @run junit TestStringConcatTransform
 28  * @run junit/othervm -Dbabylon.ssa=cytron TestStringConcatTransform
 29  */
 30 
 31 import jdk.incubator.code.CodeReflection;
 32 import jdk.incubator.code.Op;
 33 import jdk.incubator.code.OpTransformer;
 34 import jdk.incubator.code.analysis.SSA;
 35 import jdk.incubator.code.analysis.StringConcatTransformer;
 36 import jdk.incubator.code.dialect.core.CoreOp;
 37 import jdk.incubator.code.interpreter.Interpreter;
 38 import org.junit.jupiter.api.Assertions;
 39 import org.junit.jupiter.api.Test;
 40 import org.junit.jupiter.params.ParameterizedTest;
 41 import org.junit.jupiter.params.provider.MethodSource;
 42 
 43 import java.lang.invoke.MethodHandles;
 44 import java.lang.reflect.InvocationTargetException;
 45 import java.lang.reflect.Method;
 46 import java.util.Arrays;
 47 import java.util.HashMap;
 48 import java.util.Map;
 49 
 50 public class TestStringConcatTransform {
 51 
 52     static final String TESTSTR = "TESTING STRING";
 53 
 54     static final Map<Class<?>, Object> valMap;
 55 
 56     static {
 57         valMap = new HashMap<>();
 58         valMap.put(byte.class, (byte) 42);
 59         valMap.put(short.class, (short) 42);
 60         valMap.put(int.class, 42);
 61         valMap.put(long.class, (long) 42);
 62         valMap.put(float.class, 42f);
 63         valMap.put(double.class, 42d);
 64         valMap.put(char.class, 'z');
 65         valMap.put(boolean.class, false);
 66 
 67         valMap.put(Byte.class, (byte) 42);
 68         valMap.put(Short.class, (short) 42);
 69         valMap.put(Integer.class, 42);
 70         valMap.put(Long.class, (long) 42);
 71         valMap.put(Float.class, 42f);
 72         valMap.put(Double.class, 42d);
 73         valMap.put(Character.class, 'z');
 74         valMap.put(Boolean.class, false);
 75 
 76         valMap.put(Object.class, new Object() {
 77             @Override
 78             public String toString() {
 79                 return "I'm a test string.";
 80             }
 81         });
 82         valMap.put(TestObject.class, new TestObject());
 83         valMap.put(String.class, TESTSTR);
 84         valMap.put(StringBuilder.class, new StringBuilder("test"));
 85     }
 86 
 87     public static final class TestObject {
 88         TestObject() {
 89         }
 90 
 91         @Override
 92         public String toString() {
 93             return "TestObject String";
 94         }
 95     }
 96 
 97     @ParameterizedTest
 98     @MethodSource("getClassMethods")
 99     public void testModelTransform(Method method) {
100         CoreOp.FuncOp model = Op.ofMethod(method).orElseThrow();
101         CoreOp.FuncOp f_transformed = model.transform(new StringConcatTransformer());
102         Object[] args = prepArgs(method);
103 
104         System.out.println(model.toText());
105         System.out.println(f_transformed.toText());
106 
107         var interpreted = Interpreter.invoke(MethodHandles.lookup(), model, args);
108         var transformed_interpreted = Interpreter.invoke(MethodHandles.lookup(), f_transformed, args);
109 
110         Assertions.assertEquals(transformed_interpreted, interpreted);
111 
112     }
113 
114     @ParameterizedTest
115     @MethodSource("getClassMethods")
116     public void testSSAModelTransform(Method method) {
117         Object[] args = prepArgs(method);
118         testStringConcat(method, args);
119     }
120 
121     //Testing to make sure StringBuilders aren't caught up in the concat transformation
122     @Test
123     public void testStringBuilderUnchanged() {
124         Method method;
125 
126         try {
127             method = TestStringConcatTransform.class.getMethod("stringBuilderArgCheck", String.class, String.class, StringBuilder.class);
128         } catch (NoSuchMethodException e) {
129            throw new RuntimeException(e);
130         }
131         Object[] args = {"Foo", "Bar", new StringBuilder("test")};
132         testStringConcat(method, args);
133 
134         Assertions.assertEquals(args[2].toString(), "test");
135     }
136 
137     private void testStringConcat(Method method, Object[] args) {
138         CoreOp.FuncOp model = Op.ofMethod(method).orElseThrow();
139         CoreOp.FuncOp transformed_model = model.transform(new StringConcatTransformer());
140         CoreOp.FuncOp ssa_model = generateSSA(model);
141         CoreOp.FuncOp ssa_transformed_model = ssa_model.transform(new StringConcatTransformer());
142 
143         var model_interpreted = Interpreter.invoke(MethodHandles.lookup(), model, args);
144         var transformed_model_interpreted = Interpreter.invoke(MethodHandles.lookup(), transformed_model, args);
145         var ssa_interpreted = Interpreter.invoke(MethodHandles.lookup(), ssa_model, args);
146         var ssa_transformed_interpreted = Interpreter.invoke(MethodHandles.lookup(), ssa_transformed_model, args);
147         Object jvm_interpreted;
148         try {
149             jvm_interpreted = method.invoke(null, args);
150         } catch (IllegalAccessException | InvocationTargetException e) {
151             throw new RuntimeException(e);
152         }
153         Assertions.assertEquals(transformed_model_interpreted, model_interpreted);
154         Assertions.assertEquals(ssa_interpreted, transformed_model_interpreted);
155         Assertions.assertEquals(ssa_transformed_interpreted, ssa_interpreted);
156         Assertions.assertEquals(jvm_interpreted, ssa_transformed_interpreted);
157 
158     }
159 
160     public static Object[] prepArgs(Method m) {
161         Class<?>[] argTypes = m.getParameterTypes();
162         Object[] args = new Object[argTypes.length];
163         for (int i = 0; i < argTypes.length; i++) {
164             args[i] = valMap.get(argTypes[i]);
165         }
166         return args;
167     }
168 
169     public static Object[][] getClassMethods() {
170         return getTestMethods(TestStringConcatTransform.class);
171     }
172 
173     public static Object[][] getTestMethods(Class<?> clazz) {
174         Object[][] res = Arrays.stream(clazz.getMethods())
175                 .filter((method) -> method.isAnnotationPresent(CodeReflection.class))
176                 .map(m -> new Object[]{m})
177                 .toArray(Object[][]::new);
178         return res;
179     }
180 
181     static CoreOp.FuncOp generateSSA(CoreOp.FuncOp f) {
182         CoreOp.FuncOp lf = f.transform(OpTransformer.LOWERING_TRANSFORMER);
183         lf = SSA.transform(lf);
184         System.out.println(lf.toText());
185         return lf;
186     }
187 
188     @CodeReflection
189     public static String intConcat(int i, String s) {
190         return i + s + "hello" + 52;
191     }
192 
193 
194     @CodeReflection
195     public static String intConcatAssignment(int i, String s) {
196         String s1 = i + s;
197         return s1 + "hello" + 52;
198     }
199 
200     @CodeReflection
201     public static String intConcatExprAssignment(int i, String s) {
202         String r;
203         String inter = i + (r = s + "hello") + 52;
204         return r + inter;
205     }
206 
207     @CodeReflection
208     public static String intConcatWideExpr(int i, String s) {
209         String s1 = i + s;
210         return s1 + "hello" + 52 + "world" + 26 + "!";
211     }
212 
213     @CodeReflection
214     public static String intConcatDoubVar(int i, String s) {
215         String r;
216         String inter = i + (r = s + "hello") + 52;
217         String inter2 = i + (r = s + "hello") + 52 + inter;
218         return r + inter2;
219     }
220 
221     @CodeReflection
222     public static String intConcatNestedSplit(int i, String s) {
223         String q, r;
224         String inter = i + (q = r = s + "hello") + 52;
225         return q + r + inter;
226     }
227 
228     @CodeReflection
229     public static String nonLeftAssociativeTree(String a, String b, String c, String d) {
230         String s = (a + b) + (c + d);
231         return s;
232     }
233 
234     @CodeReflection
235     public static String stringBuilderCheck(String a, String d) {
236         StringBuilder sb = new StringBuilder("test");
237         String s = sb + a;
238         String t = s + d;
239         return t;
240     }
241 
242     @CodeReflection
243     public static String stringBuilderArgCheck(String a, String d, StringBuilder c) {
244         StringBuilder sb = c;
245         String s = sb + a;
246         String t = s + d;
247         return t;
248     }
249 
250     @CodeReflection
251     public static String leftAssociativeTree(String a, String b, String c, String d) {
252         String s = ((a + b) + c) + d;
253         return s;
254     }
255 
256     @CodeReflection
257     public static String rightAssociativeTree(String a, String b, String c, String d) {
258         String s = (a + (b + (c + d)));
259         return s;
260     }
261 
262     @CodeReflection
263     public static String widenPrimitives(short a, byte b, int c, int d) {
264         String s = (a + (b + (c + d + "hi")));
265         return s;
266     }
267 }
268