1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package hat;
 26 
 27 
 28 import hat.backend.Backend;
 29 
 30 import optkl.util.carriers.CommonCarrier;
 31 import optkl.ifacemapper.BufferTracker;
 32 import optkl.ifacemapper.MappableIface;
 33 
 34 
 35 import java.lang.foreign.Arena;
 36 import java.lang.invoke.MethodHandles;
 37 import java.lang.reflect.Method;
 38 
 39 import jdk.incubator.code.Reflect;
 40 import jdk.incubator.code.Op;
 41 import jdk.incubator.code.Quoted;
 42 import jdk.incubator.code.dialect.java.JavaOp;
 43 
import java.util.HashMap;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.ServiceLoader;
import java.util.function.Consumer;
import java.util.function.Predicate;
 49 
 50 import static hat.backend.Backend.FIRST;
 51 import static optkl.OpTkl.getQuotedCapturedValues;
 52 import static optkl.OpTkl.getTargetInvokeOp;
 53 import static optkl.OpTkl.methodOrThrow;
 54 
 55 
/**
 * This class provides the developer-facing view of HAT, and wraps a <a href="backend/Backend.html">Backend</a>
 * capable of <b>NDRange</b> style execution.
 * <p>
 * An Accelerator is provided a {@link java.lang.invoke.MethodHandles.Lookup} with visibility to the
 * compute to be performed.
 * <p>
 * In addition, it is given either a <a href="backend/Backend.html">Backend</a> directly
 * <pre>
 * Accelerator accelerator =
 *    new Accelerator(MethodHandles.lookup(),
 *       new JavaMultiThreadedBackend());
 * </pre>
 * or a {@code java.util.function.Predicate<Backend>}, which is used to select the required {@code Backend}
 * from those loaded via Java's {@link java.util.ServiceLoader} mechanism
 * <pre>
 * Accelerator accelerator =
 *    new Accelerator(MethodHandles.lookup(),
 *        be -&gt; be.name().startsWith("OpenCL"));
 * </pre>
 *
 * @author Gary Frost
 */
 80 public class Accelerator implements CommonCarrier,  BufferTracker {
 81 
 82     private MethodHandles.Lookup lookup;
 83     @Override public MethodHandles.Lookup lookup(){return lookup;}
 84     public final Backend backend;
 85 
 86 
 87     private final Map<Method, hat.ComputeContext> cache = new HashMap<>();
 88 
 89     public KernelContext range(NDRange<?,?> ndRange) {
 90         return new KernelContext(ndRange);
 91     }
 92 
 93     protected Accelerator(MethodHandles.Lookup lookup, ServiceLoader.Provider<Backend> provider) {
 94         this(lookup, provider.get());
 95     }
 96     public Accelerator(MethodHandles.Lookup lookup) {
 97         this(lookup, FIRST);
 98     }
 99 
100     /**
101      * @param lookup
102      * @param backend
103      */
104     public Accelerator(MethodHandles.Lookup lookup, Backend backend) {
105         this.lookup = lookup;
106         this.backend = backend;
107     }
108 
109     /**
110      * @param lookup
111      * @param backendPredicate
112      */
113     public Accelerator(MethodHandles.Lookup lookup, Predicate<Backend> backendPredicate) {
114         this(lookup, Backend.getBackend(backendPredicate));
115     }
116 
117     @Override
118     public void preMutate(MappableIface b) {
119         if (backend instanceof BufferTracker) {
120             ((BufferTracker) backend).preMutate(b);
121         }
122     }
123 
124     @Override
125     public void postMutate(MappableIface b) {
126         if (backend instanceof BufferTracker) {
127             ((BufferTracker) backend).postMutate(b);
128         }
129     }
130 
131     @Override
132     public void preAccess(MappableIface b) {
133         if (backend instanceof BufferTracker) {
134             ((BufferTracker) backend).preAccess(b);
135         }
136     }
137 
138     @Override
139     public void postAccess(MappableIface b) {
140         if (backend instanceof BufferTracker) {
141             ((BufferTracker) backend).postAccess(b);
142         }
143     }
144 
145     @Override
146     public Arena arena() {
147         return backend.arena();
148     }
149 
150     /**
151      * An interface used for wrapping the compute entrypoint of work to be performed by the Accelerator.
152      * <p/>
153      * So given a ComputeClass such as...
154      * <pre>
155      *  public class MyComputeClass {
156      *    @ Reflect
157      *    public static void addDeltaKernel(KernelContext kc, S32Array arrayOfInt, int delta) {
158      *        arrayOfInt.array(kc.x, arrayOfInt.array(kc.x)+delta);
159      *    }
160      *
161      *    @ Reflect
162      *    static public void doSomeWork(final ComputeContext cc, S32Array arrayOfInt) {
163      *    }
164      *  }
165      *  </pre>
166      * The accelerator will be passed the doSomeWork entrypoint, wrapped in a {@code Compute}
167      * <pre>
168      *  accelerator.compute(cc ->
169      *     MyCompute.doSomeWork(cc, arrayOfInt)
170      *  );
171      *  </pre>
172      */
173     @Reflect
174     @FunctionalInterface
175     public interface Compute extends Consumer<ComputeContext> {
176     }
177 
178     // convenience
179     public Config config(){
180         return backend.config();
181     }
182 
183     /**
184      * This method provides the Accelerator with the {@code Compute Entrypoint} from a Compute class.
185      * <p>
186      * The entrypoint is wrapped in a {@link Compute} lambda.
187      *
188      * <pre>
189      * accelerator.compute(cc -&gt;
190      *     MyCompute.doSomeWork(cc, intArray)
191      * )
192      * </pre>
193      */
194     public void compute(Compute compute) {
195         Quoted quoted = Op.ofQuotable(compute).orElseThrow();
196         JavaOp.LambdaOp lambda = (JavaOp.LambdaOp) quoted.op();
197         Method method = methodOrThrow(lookup,getTargetInvokeOp(this.lookup,lambda, ComputeContext.class));
198         // Create (or get cached) a compute context which closes over compute entrypoint and reachable kernels.
199         // The models of all compute and kernel methods are passed to the backend during creation
200         // The backend may well mutate the models.
201         // It will also use this opportunity to generate ISA specific code for the kernels.
202         ComputeContext computeContext = cache.computeIfAbsent(method, (_) -> new ComputeContext(this, method));
203         // Here we get the captured values from the lambda
204         Object[] args = getQuotedCapturedValues(lambda, quoted, method);
205         args[0] = computeContext;
206         // now ask the backend to execute
207         backend.dispatchCompute(computeContext, args);
208     }
209 }