1 /*
2 * Copyright (c) 2024, 2026, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25 package hat;
26
27 import hat.backend.Backend;
28
29 import optkl.util.carriers.ArenaAndLookupCarrier;
30 import optkl.ifacemapper.BufferTracker;
31 import optkl.ifacemapper.MappableIface;
32
33
34 import java.lang.foreign.Arena;
35 import java.lang.invoke.MethodHandles;
36 import java.lang.reflect.Method;
37
38 import jdk.incubator.code.Reflect;
39 import jdk.incubator.code.Op;
40 import jdk.incubator.code.Quoted;
41 import jdk.incubator.code.dialect.java.JavaOp;
42
43 import java.util.HashMap;
44 import java.util.Map;
45 import java.util.ServiceLoader;
46 import java.util.function.Consumer;
47 import java.util.function.Predicate;
48
49 import static hat.backend.Backend.FIRST;
50 import static optkl.OpHelper.Invoke.getTargetInvoke;
51 import static optkl.OpHelper.Lambda.lambda;
52
53
54 /**
55 * This class provides the developer facing view of HAT, and wraps a <a href="backend/Backend.html">Backend</a> capable of
56 * executing <b>NDRange</b> style execution.
57 * <p/>
58 * An Accelerator is provided a <a href="java/lang/invoke/MethodHandles.Lookup.html">MethodHandles.Lookup</a> with visibility to the
59 * compute to be performed.
60 * <p/>
61 * As well we either a <a href="backend/Backend.html">Backend</a> directly
62 * <pre>
63 * Accelerator accelerator =
64 * new Accelerator(MethodHandles.lookup(),
65 * new JavaMultiThreadedBackend());
66 * </pre>
67 * or a {@code java.util.function.Predicate<Backend>} which can be used to select the required {@code Backend}
68 * loaded via Javas ServiceLoader mechanism
69 * {@code}
70 * <pre>
71 * Accelerator accelerator =
72 * new Accelerator(MethodHandles.lookup(),
73 * be -> be.name().startsWith("OpenCL));
74 * </pre>}
75 *
76 * @author Gary Frost
77 */
78 public class Accelerator implements ArenaAndLookupCarrier, BufferTracker {
79
80 private final MethodHandles.Lookup lookup;
81 @Override public MethodHandles.Lookup lookup(){return lookup;}
82 public final Backend backend;
83
84 private final Map<Method, hat.ComputeContext> cache = new HashMap<>();
85
86 public KernelContext range(NDRange ndRange) {
87 return new KernelContext(ndRange);
88 }
89
90 protected Accelerator(MethodHandles.Lookup lookup, ServiceLoader.Provider<Backend> provider) {
91 this(lookup, provider.get());
92 }
93 public Accelerator(MethodHandles.Lookup lookup) {
94 this(lookup, FIRST);
95 }
96
97 /**
98 * @param lookup
99 * @param backend
100 */
101 public Accelerator(MethodHandles.Lookup lookup, Backend backend) {
102 this.lookup = lookup;
103 this.backend = backend;
104 }
105
106 /**
107 * @param lookup
108 * @param backendPredicate
109 */
110 public Accelerator(MethodHandles.Lookup lookup, Predicate<Backend> backendPredicate) {
111 this(lookup, Backend.getBackend(backendPredicate));
112 }
113
114 @Override
115 public void preMutate(MappableIface mappableIface) {
116 if (backend instanceof BufferTracker bufferTracker) {
117 bufferTracker.preMutate(mappableIface);
118 }
119 }
120
121 @Override
122 public void postMutate(MappableIface mappableIface) {
123 if (backend instanceof BufferTracker bufferTracker) {
124 bufferTracker.postMutate(mappableIface);
125 }
126 }
127
128 @Override
129 public void preAccess(MappableIface mappableIface) {
130 if (backend instanceof BufferTracker bufferTracker) {
131 bufferTracker.preAccess(mappableIface);
132 }
133 }
134
135 @Override
136 public void postAccess(MappableIface mappableIface) {
137 if (backend instanceof BufferTracker bufferTracker) {
138 bufferTracker.postAccess(mappableIface);
139 }
140 }
141
142 @Override
143 public Arena arena() {
144 return backend.arena();
145 }
146
147 /**
148 * An interface used for wrapping the compute entrypoint of work to be performed by the Accelerator.
149 * <p/>
150 * So given a ComputeClass such as...
151 * <pre>
152 * public class MyComputeClass {
153 * @ Reflect
154 * public static void addDeltaKernel(KernelContext kc, S32Array arrayOfInt, int delta) {
155 * arrayOfInt.array(kc.x, arrayOfInt.array(kc.x)+delta);
156 * }
157 *
158 * @ Reflect
159 * static public void doSomeWork(final ComputeContext cc, S32Array arrayOfInt) {
160 * }
161 * }
162 * </pre>
163 * The accelerator will be passed the doSomeWork entrypoint, wrapped in a {@code Compute}
164 * <pre>
165 * accelerator.compute(cc ->
166 * MyCompute.doSomeWork(cc, arrayOfInt)
167 * );
168 * </pre>
169 */
170 @Reflect
171 @FunctionalInterface
172 public interface Compute extends Consumer<ComputeContext> {
173 }
174
175 // convenience
176 public Config config(){
177 return backend.config();
178 }
179
180 /**
181 * This method provides the Accelerator with the {@code Compute Entrypoint} from a Compute class.
182 * <p>
183 * The entrypoint is wrapped in a {@link Compute} lambda.
184 *
185 * <pre>
186 * accelerator.compute(cc ->
187 * MyCompute.doSomeWork(cc, intArray)
188 * )
189 * </pre>
190 */
191 public void compute(Compute compute) {
192 Quoted<JavaOp.LambdaOp> quoted = Op.ofLambda(compute).orElseThrow();
193 JavaOp.LambdaOp lambda = quoted.op();
194 Method method = getTargetInvoke(this.lookup,lambda, ComputeContext.class).resolveMethodOrThrow();
195 // Create (or get cached) a compute context which closes over compute entrypoint and reachable kernels.
196 // The models of all compute and kernel methods are passed to the backend during creation
197 // The backend may well mutate the models.
198 // It will also use this opportunity to generate ISA specific code for the kernels.
199 ComputeContext computeContext = cache.computeIfAbsent(method, _ -> new ComputeContext(this, method));
200 // Here we get the captured values from the lambda
201 Object[] args = lambda(lookup,lambda).getQuotedCapturedValues( quoted, method);
202 args[0] = computeContext;
203 // now ask the backend to execute
204 backend.dispatchCompute(computeContext, args);
205 }
206 }