1 /*
  2  * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 import java.lang.classfile.Attributes;
 25 import java.lang.classfile.ClassFile;
 26 import java.lang.classfile.ClassModel;
 27 import java.lang.classfile.CodeModel;
 28 import java.lang.classfile.MethodModel;
 29 import java.lang.classfile.MethodTransform;
 30 import java.lang.classfile.TypeKind;
 31 import java.lang.classfile.constantpool.ClassEntry;
 32 import java.lang.classfile.instruction.InvokeDynamicInstruction;
 33 import java.lang.constant.ClassDesc;
 34 import java.lang.constant.ConstantDesc;
 35 import java.lang.constant.ConstantDescs;
 36 import java.lang.constant.DirectMethodHandleDesc;
 37 import java.lang.constant.DynamicCallSiteDesc;
 38 import java.lang.constant.MethodHandleDesc;
 39 import java.lang.constant.MethodTypeDesc;
 40 import java.lang.invoke.CallSite;
 41 import java.lang.invoke.ConstantCallSite;
 42 import java.lang.invoke.LambdaConversionException;
 43 import java.lang.invoke.MethodHandle;
 44 import java.lang.invoke.MethodHandles;
 45 import java.lang.invoke.MethodType;
 46 import java.lang.reflect.AccessFlag;
 47 import java.lang.reflect.Method;
 48 import java.lang.reflect.Modifier;
 49 import java.nio.file.Files;
 50 import java.nio.file.Path;
 51 import java.util.*;
 52 import java.util.stream.Stream;
 53 
 54 import jdk.incubator.code.CodeTransformer;
 55 import jdk.incubator.code.Op;
 56 import jdk.incubator.code.Reflect;
 57 import jdk.incubator.code.bytecode.BytecodeGenerator;
 58 import jdk.incubator.code.dialect.core.CoreOp;
 59 import jdk.incubator.code.dialect.core.CoreType;
 60 import jdk.incubator.code.dialect.java.JavaOp;
 61 import jdk.incubator.code.runtime.ReflectableLambdaMetafactory;
 62 
 63 public final class Unreflect {
 64 
 65     static final ClassDesc CD_Reflect = Reflect.class.describeConstable().get();
 66     static final ClassDesc CD_Unreflect = Unreflect.class.describeConstable().get();
 67     static final ClassDesc CD_ReflectableLambdaMetafactory = ReflectableLambdaMetafactory.class.describeConstable().get();
 68 
 69     static boolean isReflective(MethodModel mm) {
 70         return mm.findAttribute(Attributes.runtimeVisibleAnnotations())
 71                  .map(aa -> aa.annotations().stream().anyMatch(a -> a.classSymbol().equals(CD_Reflect)))
 72                  .orElse(false);
 73     }
 74 
 75     static byte[] transform(ClassModel clm) {
 76         return ClassFile.of(ClassFile.ConstantPoolSharingOption.NEW_POOL).transformClass(clm, (clb, cle) -> {
 77             if (cle instanceof MethodModel mm) {
 78                 if (isReflective(mm)) {
 79                     clb.transformMethod(mm, MethodTransform.dropping(me -> me instanceof CodeModel)
 80                             .andThen(MethodTransform.endHandler(mb -> mb.withCode(cob -> {
 81                                 MethodTypeDesc mts = mm.methodTypeSymbol();
 82                                 boolean hasReceiver = !mm.flags().has(AccessFlag.STATIC);
 83                                 if (hasReceiver) {
 84                                     cob.aload(cob.receiverSlot());
 85                                 }
 86                                 for (int i = 0; i < mts.parameterCount(); i++) {
 87                                     cob.loadLocal(TypeKind.from(mts.parameterType(i)), cob.parameterSlot(i));
 88                                 }
 89                                 cob.invokedynamic(DynamicCallSiteDesc.of(
 90                                         ConstantDescs.ofCallsiteBootstrap(CD_Unreflect, "unreflect", ConstantDescs.CD_CallSite),
 91                                         mm.methodName().stringValue(),
 92                                         hasReceiver ? mts.insertParameterTypes(0, clm.thisClass().asSymbol()) : mts));
 93                                 cob.return_(TypeKind.from(mts.returnType()));
 94                             }))));
 95                 } else {
 96                     clb.transformMethod(mm, MethodTransform.transformingCode((cob, coe) -> {
 97                         DirectMethodHandleDesc bsm;
 98                         if (coe instanceof InvokeDynamicInstruction i
 99                                 && (bsm = i.bootstrapMethod()).owner().equals(CD_ReflectableLambdaMetafactory)) {
100                             // redirect metafactory and altMetafactory
101                             cob.invokedynamic(DynamicCallSiteDesc.of(
102                                     MethodHandleDesc.ofMethod(DirectMethodHandleDesc.Kind.STATIC,
103                                                               CD_Unreflect,
104                                                               bsm.methodName(),
105                                                               bsm.invocationType()),
106                                     i.name().stringValue(),
107                                     MethodTypeDesc.ofDescriptor(i.type().stringValue()),
108                                     i.bootstrapArgs().toArray(ConstantDesc[]::new)));
109                         } else {
110                             cob.with(coe);
111                         }
112                     }));
113                 }
114             } else {
115                 clb.with(cle);
116             }
117         });
118     }
119 
120     public static CallSite unreflect(MethodHandles.Lookup caller,
121                                      String methodName,
122                                      MethodType methodType) throws NoSuchMethodException {
123         for (Method m : caller.lookupClass().getDeclaredMethods()) {
124             int firstParam = (m.getModifiers() & Modifier.STATIC) == 0 ? 1 : 0;
125             if (m.getName().equals(methodName)
126                     && m.getReturnType() == methodType.returnType()
127                     && m.getParameterCount() == methodType.parameterCount() - firstParam
128                     && Arrays.equals(m.getParameterTypes(), 0, m.getParameterCount(),
129                                      methodType.parameterArray(), firstParam, methodType.parameterCount())) {
130                 return new ConstantCallSite(BytecodeGenerator.generate(caller, Op.ofMethod(m).orElseThrow()));
131             }
132         }
133         throw new NoSuchMethodException(caller.lookupClass().getName() + "." + methodName + methodType);
134     }
135 
136     public static CallSite metafactory(MethodHandles.Lookup caller,
137                                        String interfaceMethodName,
138                                        MethodType factoryType,
139                                        MethodType interfaceMethodType,
140                                        MethodHandle implementation,
141                                        MethodType dynamicMethodType) throws LambdaConversionException {
142         return ReflectableLambdaMetafactory.metafactory(caller,
143                                                         interfaceMethodName,
144                                                         factoryType,
145                                                         interfaceMethodType,
146                                                         unreflectLambdaImplementation(caller, interfaceMethodName),
147                                                         dynamicMethodType);
148     }
149 
150     public static CallSite altMetafactory(MethodHandles.Lookup caller,
151                                           String interfaceMethodName,
152                                           MethodType factoryType,
153                                           Object... args) throws LambdaConversionException {
154         args[1] = unreflectLambdaImplementation(caller, interfaceMethodName);
155         return ReflectableLambdaMetafactory.altMetafactory(caller,
156                                                            interfaceMethodName,
157                                                            factoryType,
158                                                            args);
159     }
160 
161     static MethodHandle unreflectLambdaImplementation(MethodHandles.Lookup caller, String interfaceMethodName)
162             throws LambdaConversionException {
163         try {
164             MethodHandle opHandle = caller.findStatic(caller.lookupClass(),
165                                                       interfaceMethodName.split("=")[1],
166                                                       MethodType.methodType(Op.class));
167             return BytecodeGenerator.generate(caller, unquoteLambda((CoreOp.FuncOp)opHandle.invoke()));
168         } catch (Throwable t) {
169             throw new LambdaConversionException(t);
170         }
171     }
172 
173     // flat QuotedOp and LambdaOp
174     static CoreOp.FuncOp unquoteLambda(CoreOp.FuncOp funcOp) {
175         int capturedValues = funcOp.parameters().size();
176         List<Op> ops = funcOp.body().entryBlock().ops();
177         JavaOp.LambdaOp lambda = (JavaOp.LambdaOp)((CoreOp.QuotedOp)ops.get(ops.size() - 2)).quotedOp();
178         return CoreOp.func(funcOp.funcName(), CoreType.functionType(
179                 lambda.body().yieldType(),
180                 Stream.of(funcOp.invokableType().parameterTypes(),
181                           lambda.invokableType().parameterTypes()).flatMap(List::stream).toList())).body(bb -> {
182             bb.context().mapBlock(funcOp.body().entryBlock(), bb.entryBlock());
183             bb.context().mapValues(funcOp.parameters(), bb.parameters().subList(0, capturedValues));
184             for (int i = 0; i < ops.size() - 2; i++) {
185                 Op o = ops.get(i);
186                 bb.context().mapValue(o.result(), bb.op(o));
187             }
188             bb.body(lambda.body(),
189                     bb.parameters().subList(capturedValues, bb.parameters().size()),
190                     bb.context(),
191                     CodeTransformer.COPYING_TRANSFORMER);
192         });
193     }
194 
195     public static void main(String[] args) throws Exception {
196         // process class files from arguments
197         var toUnreflect = new ArrayDeque<>(List.of(args));
198         var done = new HashSet<String>();
199         while (!toUnreflect.isEmpty()) {
200             String arg = toUnreflect.pop();
201             if (!arg.endsWith(".class")) arg += ".class";
202             if (done.add(arg)) {
203                 System.out.println("unreflecting " + arg);
204                 Path clsFile = Path.of(Unreflect.class.getResource(arg).toURI());
205                 ClassModel clm = ClassFile.of().parse(Files.readAllBytes(clsFile));
206                 // unreflect all nest members
207                 clm.findAttribute(Attributes.nestMembers())
208                         .ifPresent(nma -> toUnreflect.addAll(
209                                 nma.nestMembers().stream().map(ClassEntry::asInternalName).toList()));
210                 Files.write(clsFile, transform(clm));
211             }
212         }
213     }
214 }