1 /*
2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25 package hat;
26
27
28 import hat.backend.Backend;
29
30 import optkl.util.carriers.CommonCarrier;
31 import optkl.ifacemapper.BufferTracker;
32 import optkl.ifacemapper.MappableIface;
33
34
35 import java.lang.foreign.Arena;
36 import java.lang.invoke.MethodHandles;
37 import java.lang.reflect.Method;
38
39 import jdk.incubator.code.Reflect;
40 import jdk.incubator.code.Op;
41 import jdk.incubator.code.Quoted;
42 import jdk.incubator.code.dialect.java.JavaOp;
43
import java.util.HashMap;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.ServiceLoader;
import java.util.function.Consumer;
import java.util.function.Predicate;
49
50 import static hat.backend.Backend.FIRST;
51 import static optkl.OpTkl.getQuotedCapturedValues;
52 import static optkl.OpTkl.getTargetInvokeOp;
53 import static optkl.OpTkl.methodOrThrow;
54
55
/**
 * This class provides the developer-facing view of HAT, and wraps a {@link Backend} capable of
 * performing <b>NDRange</b>-style execution.
 * <p>
 * An Accelerator is given a {@link MethodHandles.Lookup} with visibility to the
 * compute code to be executed.
 * <p>
 * It is also given either a {@link Backend} directly
 * <pre>
 * Accelerator accelerator =
 *     new Accelerator(MethodHandles.lookup(),
 *                     new JavaMultiThreadedBackend());
 * </pre>
 * or a {@code java.util.function.Predicate<Backend>}, which can be used to select the required
 * {@code Backend} from those loaded via Java's ServiceLoader mechanism
 * <pre>
 * Accelerator accelerator =
 *     new Accelerator(MethodHandles.lookup(),
 *                     be -> be.name().startsWith("OpenCL"));
 * </pre>
 *
 * @author Gary Frost
 */
80 public class Accelerator implements CommonCarrier, BufferTracker {
81
82 private MethodHandles.Lookup lookup;
83 @Override public MethodHandles.Lookup lookup(){return lookup;}
84 public final Backend backend;
85
86
87 private final Map<Method, hat.ComputeContext> cache = new HashMap<>();
88
89 public KernelContext range(NDRange<?,?> ndRange) {
90 return new KernelContext(ndRange);
91 }
92
93 protected Accelerator(MethodHandles.Lookup lookup, ServiceLoader.Provider<Backend> provider) {
94 this(lookup, provider.get());
95 }
96 public Accelerator(MethodHandles.Lookup lookup) {
97 this(lookup, FIRST);
98 }
99
100 /**
101 * @param lookup
102 * @param backend
103 */
104 public Accelerator(MethodHandles.Lookup lookup, Backend backend) {
105 this.lookup = lookup;
106 this.backend = backend;
107 }
108
109 /**
110 * @param lookup
111 * @param backendPredicate
112 */
113 public Accelerator(MethodHandles.Lookup lookup, Predicate<Backend> backendPredicate) {
114 this(lookup, Backend.getBackend(backendPredicate));
115 }
116
117 @Override
118 public void preMutate(MappableIface b) {
119 if (backend instanceof BufferTracker) {
120 ((BufferTracker) backend).preMutate(b);
121 }
122 }
123
124 @Override
125 public void postMutate(MappableIface b) {
126 if (backend instanceof BufferTracker) {
127 ((BufferTracker) backend).postMutate(b);
128 }
129 }
130
131 @Override
132 public void preAccess(MappableIface b) {
133 if (backend instanceof BufferTracker) {
134 ((BufferTracker) backend).preAccess(b);
135 }
136 }
137
138 @Override
139 public void postAccess(MappableIface b) {
140 if (backend instanceof BufferTracker) {
141 ((BufferTracker) backend).postAccess(b);
142 }
143 }
144
145 @Override
146 public Arena arena() {
147 return backend.arena();
148 }
149
150 /**
151 * An interface used for wrapping the compute entrypoint of work to be performed by the Accelerator.
152 * <p/>
153 * So given a ComputeClass such as...
154 * <pre>
155 * public class MyComputeClass {
156 * @ Reflect
157 * public static void addDeltaKernel(KernelContext kc, S32Array arrayOfInt, int delta) {
158 * arrayOfInt.array(kc.x, arrayOfInt.array(kc.x)+delta);
159 * }
160 *
161 * @ Reflect
162 * static public void doSomeWork(final ComputeContext cc, S32Array arrayOfInt) {
163 * }
164 * }
165 * </pre>
166 * The accelerator will be passed the doSomeWork entrypoint, wrapped in a {@code Compute}
167 * <pre>
168 * accelerator.compute(cc ->
169 * MyCompute.doSomeWork(cc, arrayOfInt)
170 * );
171 * </pre>
172 */
173 @Reflect
174 @FunctionalInterface
175 public interface Compute extends Consumer<ComputeContext> {
176 }
177
178 // convenience
179 public Config config(){
180 return backend.config();
181 }
182
183 /**
184 * This method provides the Accelerator with the {@code Compute Entrypoint} from a Compute class.
185 * <p>
186 * The entrypoint is wrapped in a {@link Compute} lambda.
187 *
188 * <pre>
189 * accelerator.compute(cc ->
190 * MyCompute.doSomeWork(cc, intArray)
191 * )
192 * </pre>
193 */
194 public void compute(Compute compute) {
195 Quoted quoted = Op.ofQuotable(compute).orElseThrow();
196 JavaOp.LambdaOp lambda = (JavaOp.LambdaOp) quoted.op();
197 Method method = methodOrThrow(lookup,getTargetInvokeOp(this.lookup,lambda, ComputeContext.class));
198 // Create (or get cached) a compute context which closes over compute entrypoint and reachable kernels.
199 // The models of all compute and kernel methods are passed to the backend during creation
200 // The backend may well mutate the models.
201 // It will also use this opportunity to generate ISA specific code for the kernels.
202 ComputeContext computeContext = cache.computeIfAbsent(method, (_) -> new ComputeContext(this, method));
203 // Here we get the captured values from the lambda
204 Object[] args = getQuotedCapturedValues(lambda, quoted, method);
205 args[0] = computeContext;
206 // now ask the backend to execute
207 backend.dispatchCompute(computeContext, args);
208 }
209 }