1 /*
2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25 package hat;
26
27
28 import hat.backend.Backend;
29
30 import optkl.util.carriers.CommonCarrier;
31 import optkl.ifacemapper.BufferTracker;
32 import optkl.ifacemapper.MappableIface;
33
34
35 import java.lang.foreign.Arena;
36 import java.lang.invoke.MethodHandles;
37 import java.lang.reflect.Method;
38
39 import jdk.incubator.code.Reflect;
40 import jdk.incubator.code.Op;
41 import jdk.incubator.code.Quoted;
42 import jdk.incubator.code.dialect.java.JavaOp;
43
44 import java.util.HashMap;
45 import java.util.Map;
46 import java.util.ServiceLoader;
47 import java.util.function.Consumer;
48 import java.util.function.Predicate;
49
50 import static hat.backend.Backend.FIRST;
51 import static optkl.OpHelper.Invoke.getTargetInvoke;
52 import static optkl.OpHelper.Lambda.lambda;
53
54
/**
 * This class provides the developer-facing view of HAT, and wraps a {@link hat.backend.Backend Backend} capable of
 * <b>NDRange</b>-style execution.
 * <p>
 * An Accelerator is provided a {@link java.lang.invoke.MethodHandles.Lookup MethodHandles.Lookup} with visibility of
 * the compute to be performed.
 * <p>
 * In addition, the Accelerator is given either a {@link hat.backend.Backend Backend} directly
 * <pre>
 * Accelerator accelerator =
 *     new Accelerator(MethodHandles.lookup(),
 *         new JavaMultiThreadedBackend());
 * </pre>
 * or a {@code java.util.function.Predicate<Backend>}, which is used to select the required {@code Backend}
 * from those loaded via Java's ServiceLoader mechanism
 * <pre>
 * Accelerator accelerator =
 *     new Accelerator(MethodHandles.lookup(),
 *         be -> be.name().startsWith("OpenCL"));
 * </pre>
 *
 * @author Gary Frost
 */
79 public class Accelerator implements CommonCarrier, BufferTracker {
80
81 private MethodHandles.Lookup lookup;
82 @Override public MethodHandles.Lookup lookup(){return lookup;}
83 public final Backend backend;
84
85
86 private final Map<Method, hat.ComputeContext> cache = new HashMap<>();
87
88 public KernelContext range(NDRange<?,?> ndRange) {
89 return new KernelContext(ndRange);
90 }
91
92 protected Accelerator(MethodHandles.Lookup lookup, ServiceLoader.Provider<Backend> provider) {
93 this(lookup, provider.get());
94 }
95 public Accelerator(MethodHandles.Lookup lookup) {
96 this(lookup, FIRST);
97 }
98
99 /**
100 * @param lookup
101 * @param backend
102 */
103 public Accelerator(MethodHandles.Lookup lookup, Backend backend) {
104 this.lookup = lookup;
105 this.backend = backend;
106 }
107
108 /**
109 * @param lookup
110 * @param backendPredicate
111 */
112 public Accelerator(MethodHandles.Lookup lookup, Predicate<Backend> backendPredicate) {
113 this(lookup, Backend.getBackend(backendPredicate));
114 }
115
116 @Override
117 public void preMutate(MappableIface b) {
118 if (backend instanceof BufferTracker) {
119 ((BufferTracker) backend).preMutate(b);
120 }
121 }
122
123 @Override
124 public void postMutate(MappableIface b) {
125 if (backend instanceof BufferTracker) {
126 ((BufferTracker) backend).postMutate(b);
127 }
128 }
129
130 @Override
131 public void preAccess(MappableIface b) {
132 if (backend instanceof BufferTracker) {
133 ((BufferTracker) backend).preAccess(b);
134 }
135 }
136
137 @Override
138 public void postAccess(MappableIface b) {
139 if (backend instanceof BufferTracker) {
140 ((BufferTracker) backend).postAccess(b);
141 }
142 }
143
144 @Override
145 public Arena arena() {
146 return backend.arena();
147 }
148
149 /**
150 * An interface used for wrapping the compute entrypoint of work to be performed by the Accelerator.
151 * <p/>
152 * So given a ComputeClass such as...
153 * <pre>
154 * public class MyComputeClass {
155 * @ Reflect
156 * public static void addDeltaKernel(KernelContext kc, S32Array arrayOfInt, int delta) {
157 * arrayOfInt.array(kc.x, arrayOfInt.array(kc.x)+delta);
158 * }
159 *
160 * @ Reflect
161 * static public void doSomeWork(final ComputeContext cc, S32Array arrayOfInt) {
162 * }
163 * }
164 * </pre>
165 * The accelerator will be passed the doSomeWork entrypoint, wrapped in a {@code Compute}
166 * <pre>
167 * accelerator.compute(cc ->
168 * MyCompute.doSomeWork(cc, arrayOfInt)
169 * );
170 * </pre>
171 */
172 @Reflect
173 @FunctionalInterface
174 public interface Compute extends Consumer<ComputeContext> {
175 }
176
177 // convenience
178 public Config config(){
179 return backend.config();
180 }
181
182 /**
183 * This method provides the Accelerator with the {@code Compute Entrypoint} from a Compute class.
184 * <p>
185 * The entrypoint is wrapped in a {@link Compute} lambda.
186 *
187 * <pre>
188 * accelerator.compute(cc ->
189 * MyCompute.doSomeWork(cc, intArray)
190 * )
191 * </pre>
192 */
193 public void compute(Compute compute) {
194 Quoted quoted = Op.ofLambda(compute).orElseThrow();
195 JavaOp.LambdaOp lambda = (JavaOp.LambdaOp) quoted.op();
196 Method method = getTargetInvoke(this.lookup,lambda, ComputeContext.class).resolveMethodOrThrow();
197 // Create (or get cached) a compute context which closes over compute entrypoint and reachable kernels.
198 // The models of all compute and kernel methods are passed to the backend during creation
199 // The backend may well mutate the models.
200 // It will also use this opportunity to generate ISA specific code for the kernels.
201 ComputeContext computeContext = cache.computeIfAbsent(method, (_) -> new ComputeContext(this, method));
202 // Here we get the captured values from the lambda
203 Object[] args = lambda(lookup,lambda).getQuotedCapturedValues( quoted, method);
204 args[0] = computeContext;
205 // now ask the backend to execute
206 backend.dispatchCompute(computeContext, args);
207 }
208 }