1 /*
2 * Copyright (c) 2024, 2026, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25 package hat;
26
import hat.backend.Backend;

import optkl.codebuilders.BabylonOpDispatcher;
import optkl.ifacemapper.BufferTracker;
import optkl.ifacemapper.MappableIface;
import optkl.util.carriers.ArenaAndLookupCarrier;

import java.lang.foreign.Arena;
import java.lang.invoke.MethodHandles;
import java.lang.reflect.Method;
import java.util.HashMap;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.ServiceLoader;
import java.util.function.Consumer;
import java.util.function.Predicate;

import jdk.incubator.code.Op;
import jdk.incubator.code.Quoted;
import jdk.incubator.code.Reflect;
import jdk.incubator.code.dialect.java.JavaOp;

import static hat.backend.Backend.FIRST;
import static optkl.OpHelper.Invoke.getTargetInvoke;
import static optkl.OpHelper.Lambda.lambda;
53
54
/**
 * This class provides the developer-facing view of HAT, and wraps a {@link Backend} capable of
 * <b>NDRange</b>-style execution.
 * <p>
 * An Accelerator is provided a {@link MethodHandles.Lookup} with visibility to the
 * compute to be performed.
 * <p>
 * The Accelerator is also given either a {@link Backend} directly
 * <pre>
 * Accelerator accelerator =
 *     new Accelerator(MethodHandles.lookup(),
 *         new JavaMultiThreadedBackend());
 * </pre>
 * or a {@code java.util.function.Predicate<Backend>}, which is used to select the required {@code Backend}
 * from those loaded via Java's ServiceLoader mechanism
 * <pre>
 * Accelerator accelerator =
 *     new Accelerator(MethodHandles.lookup(),
 *         be -> be.name().startsWith("OpenCL"));
 * </pre>
 *
 * @author Gary Frost
 */
79 public class Accelerator implements ArenaAndLookupCarrier, BufferTracker {
80
81 private final MethodHandles.Lookup lookup;
82 @Override public MethodHandles.Lookup lookup(){return lookup;}
83 public final Backend backend;
84
85 private final Map<Method, hat.ComputeContext> cache = new HashMap<>();
86
87 public KernelContext range(NDRange ndRange) {
88 return new KernelContext(ndRange);
89 }
90
91 protected Accelerator(MethodHandles.Lookup lookup, ServiceLoader.Provider<Backend> provider) {
92 this(lookup, provider.get());
93 }
94 public Accelerator(MethodHandles.Lookup lookup) {
95 this(lookup, FIRST);
96 }
97
98 /**
99 * @param lookup
100 * @param backend
101 */
102 public Accelerator(MethodHandles.Lookup lookup, Backend backend) {
103 this.lookup = lookup;
104 this.backend = backend;
105 }
106
107 /**
108 * @param lookup
109 * @param backendPredicate
110 */
111 public Accelerator(MethodHandles.Lookup lookup, Predicate<Backend> backendPredicate) {
112 this(lookup, Backend.getBackend(backendPredicate));
113 }
114
115 @Override
116 public void preMutate(MappableIface mappableIface) {
117 if (backend instanceof BufferTracker bufferTracker) {
118 bufferTracker.preMutate(mappableIface);
119 }
120 }
121
122 @Override
123 public void postMutate(MappableIface mappableIface) {
124 if (backend instanceof BufferTracker bufferTracker) {
125 bufferTracker.postMutate(mappableIface);
126 }
127 }
128
129 @Override
130 public void preAccess(MappableIface mappableIface) {
131 if (backend instanceof BufferTracker bufferTracker) {
132 bufferTracker.preAccess(mappableIface);
133 }
134 }
135
136 @Override
137 public void postAccess(MappableIface mappableIface) {
138 if (backend instanceof BufferTracker bufferTracker) {
139 bufferTracker.postAccess(mappableIface);
140 }
141 }
142
143 @Override
144 public Arena arena() {
145 return backend.arena();
146 }
147
148 /**
149 * An interface used for wrapping the compute entrypoint of work to be performed by the Accelerator.
150 * <p/>
151 * So given a ComputeClass such as...
152 * <pre>
153 * public class MyComputeClass {
154 * @ Reflect
155 * public static void addDeltaKernel(KernelContext kc, S32Array arrayOfInt, int delta) {
156 * arrayOfInt.array(kc.x, arrayOfInt.array(kc.x)+delta);
157 * }
158 *
159 * @ Reflect
160 * static public void doSomeWork(final ComputeContext cc, S32Array arrayOfInt) {
161 * }
162 * }
163 * </pre>
164 * The accelerator will be passed the doSomeWork entrypoint, wrapped in a {@code Compute}
165 * <pre>
166 * accelerator.compute(cc ->
167 * MyCompute.doSomeWork(cc, arrayOfInt)
168 * );
169 * </pre>
170 */
171 @Reflect
172 @FunctionalInterface
173 public interface Compute extends Consumer<ComputeContext> {
174 }
175
176 // convenience
177 public Config config(){
178 return backend.config();
179 }
180
181 /**
182 * This method provides the Accelerator with the {@code Compute Entrypoint} from a Compute class.
183 * <p>
184 * The entrypoint is wrapped in a {@link Compute} lambda.
185 *
186 * <pre>
187 * accelerator.compute(cc ->
188 * MyCompute.doSomeWork(cc, intArray)
189 * )
190 * </pre>
191 */
192 public void compute(Compute compute) {
193 Quoted<JavaOp.LambdaOp> quoted = Op.ofLambda(compute).orElseThrow();
194 JavaOp.LambdaOp lambda = quoted.op();
195 Method method = getTargetInvoke(this.lookup,lambda, ComputeContext.class).resolveMethodOrThrow();
196 // Create (or get cached) a compute context which closes over compute entrypoint and reachable kernels.
197 // The models of all compute and kernel methods are passed to the backend during creation
198 // The backend may well mutate the models.
199 // It will also use this opportunity to generate ISA specific code for the kernels.
200 ComputeContext computeContext = cache.computeIfAbsent(method, _ -> new ComputeContext(this, method));
201 // Here we get the captured values from the lambda
202 Object[] args = lambda(lookup,lambda).getQuotedCapturedValues( quoted, method);
203 args[0] = computeContext;
204 // now ask the backend to execute
205 backend.dispatchCompute(computeContext, args);
206 }
207 }