1 package oracle.code.onnx.ir;
2
3 import jdk.incubator.code.Op;
4 import jdk.incubator.code.extern.ExternalizedOp;
5 import jdk.incubator.code.extern.OpFactory;
6
7 import java.lang.annotation.ElementType;
8 import java.lang.annotation.Retention;
9 import java.lang.annotation.RetentionPolicy;
10 import java.lang.annotation.Target;
11 import java.lang.invoke.MethodHandle;
12 import java.lang.invoke.MethodHandles;
13 import java.lang.reflect.Constructor;
14 import java.lang.reflect.Method;
15 import java.lang.reflect.Modifier;
16 import java.util.HashMap;
17 import java.util.Map;
18 import java.util.function.Function;
19
20 public final class OpFactoryHelper {
21
22 /**
23 * An operation declaration annotation.
24 * <p>
25 * This annotation may be declared on a concrete class implementing an {@link Op operation} whose name is a constant
26 * that can be declared as this attribute's value.
27 * <p>
28 * Tooling can process declarations of this annotation to build a factory for constructing operations from their name.
29 */
30 @Retention(RetentionPolicy.RUNTIME)
31 @Target(ElementType.TYPE)
32 public @interface OpDeclaration {
33 /**
34 * {@return the operation name}
35 */
36 String value();
37 }
38
39 /**
40 * A class value for lazily computing an operation factory for {@link Op operation} classes
41 * annotated with {@link OpFactoryHelper.OpDeclaration} and enclosed within a given class to compute over.
42 * <p>
43 * Each enclosed class annotated with {@code OpDeclaration} must declare a public static method named {@code create}
44 * with one parameter type of {@link ExternalizedOp} and return type that is the concrete class type.
45 * Alternatively, the concrete class must declare public constructor with one parameter type of
46 * {@link ExternalizedOp}.
47 */
48 public static final ClassValue<OpFactory> OP_FACTORY = new ClassValue<>() {
49 @Override
50 protected OpFactory computeValue(Class<?> c) {
51 // @@@ See https://bugs.openjdk.org/browse/JDK-8321207
52 final Map<String, Class<? extends Op>> opMapping = createOpMapping(c);
53
54 return def -> {
55 var opClass = opMapping.get(def.name());
56 if (opClass == null) {
57 return null;
58 }
59
60 Op op = constructOp(opClass, def);
61 // Set location if available
62 if (op != null && def.location() != null) {
63 op.setLocation(def.location());
64 }
65 return op;
66 };
67 }
68 };
69
70 private static Map<String, Class<? extends Op>> createOpMapping(Class<?> opClasses) {
71 Map<String, Class<? extends Op>> mapping = new HashMap<>();
72 for (Class<?> opClass : opClasses.getNestMembers()) {
73 if (opClass.isAnnotationPresent(OpDeclaration.class)) {
74 if (!Modifier.isPublic(opClass.getModifiers())) {
75 throw new InternalError("Operation class not public: " + opClass.getName());
76 }
77
78 if (!Op.class.isAssignableFrom(opClass)) {
79 throw new InternalError("Operation class is not assignable to Op: " + opClass);
80 }
81
82 MethodHandle handle = getOpConstructorMethodHandle(opClass);
83 if (handle == null) {
84 throw new InternalError("Operation constructor for operation class not found: " + opClass.getName());
85 }
86
87 if (!Op.class.isAssignableFrom(handle.type().returnType())) {
88 throw new InternalError("Operation constructor does not return an Op: " + handle);
89 }
90
91 String opName = opClass.getAnnotation(OpDeclaration.class).value();
92 @SuppressWarnings("unchecked")
93 var opClassCast = (Class<Op>) opClass;
94 mapping.put(opName, opClassCast);
95 }
96 }
97 return mapping;
98 }
99
100 private static MethodHandle getOpConstructorMethodHandle(Class<?> opClass) {
101 Method method = null;
102 try {
103 method = opClass.getMethod("create", ExternalizedOp.class);
104 } catch (NoSuchMethodException e) {
105 }
106
107 if (method != null) {
108 if (!Modifier.isStatic(method.getModifiers())) {
109 throw new InternalError("Operation constructor is not a static method: " + method);
110 }
111
112 try {
113 return MethodHandles.publicLookup().unreflect(method);
114 } catch (IllegalAccessException e) {
115 throw new InternalError("Inaccessible operation constructor for operation: " +
116 method);
117 }
118 }
119
120 Constructor<?> constructor;
121 try {
122 constructor = opClass.getConstructor(ExternalizedOp.class);
123 } catch (NoSuchMethodException e) {
124 return null;
125 }
126
127 try {
128 return MethodHandles.publicLookup().unreflectConstructor(constructor);
129 } catch (IllegalAccessException e) {
130 throw new InternalError("Inaccessible operation constructor for operation: " +
131 constructor);
132 }
133 }
134
135 private static Op constructOp(Class<? extends Op> opClass, ExternalizedOp opDef) {
136 class Enclosed {
137 private static final ClassValue<Function<ExternalizedOp, Op>> OP_CONSTRUCTOR = new ClassValue<>() {
138 @Override
139 protected Function<ExternalizedOp, Op> computeValue(Class<?> opClass) {
140 final MethodHandle opConstructorMH = getOpConstructorMethodHandle(opClass);
141 assert opConstructorMH != null;
142
143 return operationDefinition -> {
144 try {
145 return (Op) opConstructorMH.invoke(operationDefinition);
146 } catch (RuntimeException | Error e) {
147 throw e;
148 } catch (Throwable t) {
149 throw new RuntimeException(t);
150 }
151 };
152 }
153 };
154 }
155 return Enclosed.OP_CONSTRUCTOR.get(opClass).apply(opDef);
156 }
157 }