1 /*
  2  * Copyright (c) 2024, 2026, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package hat;
 26 
 27 import hat.backend.Backend;
 28 
 29 import optkl.util.carriers.ArenaAndLookupCarrier;
 30 import optkl.ifacemapper.BufferTracker;
 31 import optkl.ifacemapper.MappableIface;
 32 
 33 
 34 import java.lang.foreign.Arena;
 35 import java.lang.invoke.MethodHandles;
 36 import java.lang.reflect.Method;
 37 
 38 import jdk.incubator.code.Reflect;
 39 import jdk.incubator.code.Op;
 40 import jdk.incubator.code.Quoted;
 41 import jdk.incubator.code.dialect.java.JavaOp;
 42 
import java.util.HashMap;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.ServiceLoader;
import java.util.function.Consumer;
import java.util.function.Predicate;
 48 
 49 import static hat.backend.Backend.FIRST;
 50 import static optkl.OpHelper.Invoke.getTargetInvoke;
 51 import static optkl.OpHelper.Lambda.lambda;
 52 
 53 
 54 /**
 * This class provides the developer-facing view of HAT and wraps a <a href="backend/Backend.html">Backend</a> capable of
 * <b>NDRange</b>-style execution.
 * <p>
 * An Accelerator is given a {@link MethodHandles.Lookup} with visibility to the
 * compute code to be performed.
 60  * <p/>
 * It is also given either a <a href="backend/Backend.html">Backend</a> directly
 * <pre>
 * Accelerator accelerator =
 *    new Accelerator(MethodHandles.lookup(),
 *       new JavaMultiThreadedBackend());
 * </pre>
 * or a {@code java.util.function.Predicate<Backend>} which is used to select the required {@code Backend}
 * from those loaded via Java's ServiceLoader mechanism
 * <pre>
 * Accelerator accelerator =
 *    new Accelerator(MethodHandles.lookup(),
 *        be -&gt; be.name().startsWith("OpenCL"));
 * </pre>
 75  *
 76  * @author Gary Frost
 77  */
 78 public class Accelerator implements ArenaAndLookupCarrier,  BufferTracker {
 79 
 80     private final MethodHandles.Lookup lookup;
 81     @Override public MethodHandles.Lookup lookup(){return lookup;}
 82     public final Backend backend;
 83 
 84     private final Map<Method, hat.ComputeContext> cache = new HashMap<>();
 85 
 86     public KernelContext range(NDRange ndRange) {
 87         return new KernelContext(ndRange);
 88     }
 89 
 90     protected Accelerator(MethodHandles.Lookup lookup, ServiceLoader.Provider<Backend> provider) {
 91         this(lookup, provider.get());
 92     }
 93     public Accelerator(MethodHandles.Lookup lookup) {
 94         this(lookup, FIRST);
 95     }
 96 
 97     /**
 98      * @param lookup
 99      * @param backend
100      */
101     public Accelerator(MethodHandles.Lookup lookup, Backend backend) {
102         this.lookup = lookup;
103         this.backend = backend;
104     }
105 
106     /**
107      * @param lookup
108      * @param backendPredicate
109      */
110     public Accelerator(MethodHandles.Lookup lookup, Predicate<Backend> backendPredicate) {
111         this(lookup, Backend.getBackend(backendPredicate));
112     }
113 
114     @Override
115     public void preMutate(MappableIface mappableIface) {
116         if (backend instanceof BufferTracker bufferTracker) {
117             bufferTracker.preMutate(mappableIface);
118         }
119     }
120 
121     @Override
122     public void postMutate(MappableIface mappableIface) {
123         if (backend instanceof BufferTracker bufferTracker) {
124             bufferTracker.postMutate(mappableIface);
125         }
126     }
127 
128     @Override
129     public void preAccess(MappableIface mappableIface) {
130         if (backend instanceof BufferTracker bufferTracker) {
131             bufferTracker.preAccess(mappableIface);
132         }
133     }
134 
135     @Override
136     public void postAccess(MappableIface mappableIface) {
137         if (backend instanceof BufferTracker bufferTracker) {
138             bufferTracker.postAccess(mappableIface);
139         }
140     }
141 
142     @Override
143     public Arena arena() {
144         return backend.arena();
145     }
146 
147     /**
148      * An interface used for wrapping the compute entrypoint of work to be performed by the Accelerator.
149      * <p/>
150      * So given a ComputeClass such as...
151      * <pre>
152      *  public class MyComputeClass {
153      *    @ Reflect
154      *    public static void addDeltaKernel(KernelContext kc, S32Array arrayOfInt, int delta) {
155      *        arrayOfInt.array(kc.x, arrayOfInt.array(kc.x)+delta);
156      *    }
157      *
158      *    @ Reflect
159      *    static public void doSomeWork(final ComputeContext cc, S32Array arrayOfInt) {
160      *    }
161      *  }
162      *  </pre>
163      * The accelerator will be passed the doSomeWork entrypoint, wrapped in a {@code Compute}
164      * <pre>
165      *  accelerator.compute(cc ->
166      *     MyCompute.doSomeWork(cc, arrayOfInt)
167      *  );
168      *  </pre>
169      */
170     @Reflect
171     @FunctionalInterface
172     public interface Compute extends Consumer<ComputeContext> {
173     }
174 
175     // convenience
176     public Config config(){
177         return backend.config();
178     }
179 
180     /**
181      * This method provides the Accelerator with the {@code Compute Entrypoint} from a Compute class.
182      * <p>
183      * The entrypoint is wrapped in a {@link Compute} lambda.
184      *
185      * <pre>
186      * accelerator.compute(cc -&gt;
187      *     MyCompute.doSomeWork(cc, intArray)
188      * )
189      * </pre>
190      */
191     public void compute(Compute compute) {
192         Quoted<JavaOp.LambdaOp> quoted = Op.ofLambda(compute).orElseThrow();
193         JavaOp.LambdaOp lambda = quoted.op();
194         Method method = getTargetInvoke(this.lookup,lambda, ComputeContext.class).resolveMethodOrThrow();
195         // Create (or get cached) a compute context which closes over compute entrypoint and reachable kernels.
196         // The models of all compute and kernel methods are passed to the backend during creation
197         // The backend may well mutate the models.
198         // It will also use this opportunity to generate ISA specific code for the kernels.
199         ComputeContext computeContext = cache.computeIfAbsent(method, _ -> new ComputeContext(this, method));
200         // Here we get the captured values from the lambda
201         Object[] args = lambda(lookup,lambda).getQuotedCapturedValues( quoted, method);
202         args[0] = computeContext;
203         // now ask the backend to execute
204         backend.dispatchCompute(computeContext, args);
205     }
206 }