1 package oracle.code.onnx.ir;
  2 
  3 import jdk.incubator.code.Op;
  4 import jdk.incubator.code.extern.ExternalizedOp;
  5 import jdk.incubator.code.extern.OpFactory;
  6 
  7 import java.lang.annotation.ElementType;
  8 import java.lang.annotation.Retention;
  9 import java.lang.annotation.RetentionPolicy;
 10 import java.lang.annotation.Target;
 11 import java.lang.invoke.MethodHandle;
 12 import java.lang.invoke.MethodHandles;
 13 import java.lang.reflect.Constructor;
 14 import java.lang.reflect.Method;
 15 import java.lang.reflect.Modifier;
 16 import java.util.HashMap;
 17 import java.util.Map;
 18 import java.util.function.Function;
 19 
 20 public final class OpFactoryHelper {
 21 
 22     /**
 23      * An operation declaration annotation.
 24      * <p>
 25      * This annotation may be declared on a concrete class implementing an {@link Op operation} whose name is a constant
 26      * that can be declared as this attribute's value.
 27      * <p>
 28      * Tooling can process declarations of this annotation to build a factory for constructing operations from their name.
 29      */
 30     @Retention(RetentionPolicy.RUNTIME)
 31     @Target(ElementType.TYPE)
 32     public @interface OpDeclaration {
 33         /**
 34          * {@return the operation name}
 35          */
 36         String value();
 37     }
 38 
 39     /**
 40      * A class value for lazily computing an operation factory for {@link Op operation} classes
 41      * annotated with {@link OpFactoryHelper.OpDeclaration} and enclosed within a given class to compute over.
 42      * <p>
 43      * Each enclosed class annotated with {@code OpDeclaration} must declare a public static method named {@code create}
 44      * with one parameter type of {@link ExternalizedOp} and return type that is the concrete class type.
 45      * Alternatively, the concrete class must declare public constructor with one parameter type of
 46      * {@link ExternalizedOp}.
 47      */
 48     public static final ClassValue<OpFactory> OP_FACTORY = new ClassValue<>() {
 49         @Override
 50         protected OpFactory computeValue(Class<?> c) {
 51             // @@@ See https://bugs.openjdk.org/browse/JDK-8321207
 52             final Map<String, Class<? extends Op>> opMapping = createOpMapping(c);
 53 
 54             return def -> {
 55                 var opClass = opMapping.get(def.name());
 56                 if (opClass == null) {
 57                     return null;
 58                 }
 59 
 60                 Op op = constructOp(opClass, def);
 61                 // Set location if available
 62                 if (op != null && def.location() != null) {
 63                     op.setLocation(def.location());
 64                 }
 65                 return op;
 66             };
 67         }
 68     };
 69 
 70     private static Map<String, Class<? extends Op>> createOpMapping(Class<?> opClasses) {
 71         Map<String, Class<? extends Op>> mapping = new HashMap<>();
 72         for (Class<?> opClass : opClasses.getNestMembers()) {
 73             if (opClass.isAnnotationPresent(OpDeclaration.class)) {
 74                 if (!Modifier.isPublic(opClass.getModifiers())) {
 75                     throw new InternalError("Operation class not public: " + opClass.getName());
 76                 }
 77 
 78                 if (!Op.class.isAssignableFrom(opClass)) {
 79                     throw new InternalError("Operation class is not assignable to Op: " + opClass);
 80                 }
 81 
 82                 MethodHandle handle = getOpConstructorMethodHandle(opClass);
 83                 if (handle == null) {
 84                     throw new InternalError("Operation constructor for operation class not found: " + opClass.getName());
 85                 }
 86 
 87                 if (!Op.class.isAssignableFrom(handle.type().returnType())) {
 88                     throw new InternalError("Operation constructor does not return an Op: " + handle);
 89                 }
 90 
 91                 String opName = opClass.getAnnotation(OpDeclaration.class).value();
 92                 @SuppressWarnings("unchecked")
 93                 var opClassCast = (Class<Op>) opClass;
 94                 mapping.put(opName, opClassCast);
 95             }
 96         }
 97         return mapping;
 98     }
 99 
100     private static MethodHandle getOpConstructorMethodHandle(Class<?> opClass) {
101         Method method = null;
102         try {
103             method = opClass.getMethod("create", ExternalizedOp.class);
104         } catch (NoSuchMethodException e) {
105         }
106 
107         if (method != null) {
108             if (!Modifier.isStatic(method.getModifiers())) {
109                 throw new InternalError("Operation constructor is not a static method: " + method);
110             }
111 
112             try {
113                 return MethodHandles.publicLookup().unreflect(method);
114             } catch (IllegalAccessException e) {
115                 throw new InternalError("Inaccessible operation constructor for operation: " +
116                         method);
117             }
118         }
119 
120         Constructor<?> constructor;
121         try {
122             constructor = opClass.getConstructor(ExternalizedOp.class);
123         } catch (NoSuchMethodException e) {
124             return null;
125         }
126 
127         try {
128             return MethodHandles.publicLookup().unreflectConstructor(constructor);
129         } catch (IllegalAccessException e) {
130             throw new InternalError("Inaccessible operation constructor for operation: " +
131                     constructor);
132         }
133     }
134 
135     private static Op constructOp(Class<? extends Op> opClass, ExternalizedOp opDef) {
136         class Enclosed {
137             private static final ClassValue<Function<ExternalizedOp, Op>> OP_CONSTRUCTOR = new ClassValue<>() {
138                 @Override
139                 protected Function<ExternalizedOp, Op> computeValue(Class<?> opClass) {
140                     final MethodHandle opConstructorMH = getOpConstructorMethodHandle(opClass);
141                     assert opConstructorMH != null;
142 
143                     return operationDefinition -> {
144                         try {
145                             return (Op) opConstructorMH.invoke(operationDefinition);
146                         } catch (RuntimeException | Error e) {
147                             throw e;
148                         } catch (Throwable t) {
149                             throw new RuntimeException(t);
150                         }
151                     };
152                 }
153             };
154         }
155         return Enclosed.OP_CONSTRUCTOR.get(opClass).apply(opDef);
156     }
157 }