1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 /*
 25  * @test
 26  * @run testng TestTransitiveInvokeModule
 27  */
 28 
 29 import org.testng.Assert;
 30 import org.testng.annotations.Test;
 31 
 32 import java.lang.invoke.MethodHandles;
 33 import java.lang.reflect.Executable;
 34 import java.lang.reflect.Method;
 35 import java.lang.reflect.code.OpTransformer;
 36 import java.lang.reflect.code.analysis.SSA;
 37 import java.lang.reflect.code.interpreter.Interpreter;
 38 import java.lang.reflect.code.op.CoreOp;
 39 import java.lang.reflect.code.type.MethodRef;
 40 import java.lang.runtime.CodeReflection;
 41 import java.util.*;
 42 import java.util.stream.Stream;
 43 
 44 public class TestTransitiveInvokeModule {
 45 
 46     @CodeReflection
 47     static void m(int i, List<Integer> l) {
 48         if (i < 0) {
 49             return;
 50         }
 51 
 52         n(i - 1, l);
 53     }
 54 
 55     @CodeReflection
 56     static void n(int i, List<Integer> l) {
 57         l.add(i);
 58         m(i - 1, l);
 59     }
 60 
 61     @Test
 62     public void test() {
 63         Optional<Method> om = Stream.of(TestTransitiveInvokeModule.class.getDeclaredMethods())
 64                 .filter(m -> m.getName().equals("m"))
 65                 .findFirst();
 66 
 67         CoreOp.ModuleOp module = createTransitiveInvokeModule(MethodHandles.lookup(), om.get());
 68         System.out.println(module.toText());
 69         module = module.transform(OpTransformer.LOWERING_TRANSFORMER);
 70         System.out.println(module.toText());
 71         module = SSA.transform(module);
 72         System.out.println(module.toText());
 73 
 74         module.functionTable().forEach((s, funcOp) -> {
 75             System.out.println(s + " -> " + funcOp);
 76         });
 77 
 78         List<Integer> r = new ArrayList<>();
 79         Interpreter.invoke(module.functionTable().firstEntry().getValue(), 10, r);
 80         Assert.assertEquals(r, List.of(9, 7, 5, 3, 1, -1));
 81     }
 82 
 83     static CoreOp.ModuleOp createTransitiveInvokeModule(MethodHandles.Lookup l,
 84                                                         Method m) {
 85         Optional<CoreOp.FuncOp> codeModel = m.getCodeModel();
 86         if (codeModel.isPresent()) {
 87             return createTransitiveInvokeModule(l, MethodRef.method(m), codeModel.get());
 88         } else {
 89             return CoreOp.module(List.of());
 90         }
 91     }
 92 
 93     static CoreOp.ModuleOp createTransitiveInvokeModule(MethodHandles.Lookup l,
 94                                                         MethodRef entryRef, CoreOp.FuncOp entry) {
 95         LinkedHashSet<MethodRef> funcsVisited = new LinkedHashSet<>();
 96         List<CoreOp.FuncOp> funcs = new ArrayList<>();
 97 
 98         record RefAndFunc(MethodRef r, CoreOp.FuncOp f) {
 99         }
100         Deque<RefAndFunc> work = new ArrayDeque<>();
101         work.push(new RefAndFunc(entryRef, entry));
102         while (!work.isEmpty()) {
103             RefAndFunc rf = work.pop();
104             if (!funcsVisited.add(rf.r)) {
105                 continue;
106             }
107 
108             CoreOp.FuncOp tf = rf.f.transform(rf.r.toString(), (block, op) -> {
109                 if (op instanceof CoreOp.InvokeOp iop) {
110                     MethodRef r = iop.invokeDescriptor();
111                     Method em = null;
112                     try {
113                         em = r.resolveToMethod(l);
114                     } catch (ReflectiveOperationException _) {
115                     }
116                     if (em instanceof Method m) {
117                         Optional<CoreOp.FuncOp> f = m.getCodeModel();
118                         if (f.isPresent()) {
119                             RefAndFunc call = new RefAndFunc(r, f.get());
120                             // Place model on work queue
121                             work.push(call);
122 
123                             // Replace invocation with function call
124                             block.op(CoreOp.funcCall(
125                                     call.r.toString(),
126                                     call.f.invokableType(),
127                                     block.context().getValues(iop.operands())));
128                             return block;
129                         }
130                     }
131                 }
132                 block.op(op);
133                 return block;
134             });
135             funcs.add(tf);
136         }
137 
138         return CoreOp.module(funcs);
139     }
140 }