1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package hat;
 26 
 27 
 28 import hat.backend.Backend;
 29 import hat.buffer.Buffer;
 30 import hat.buffer.BufferAllocator;
 31 import hat.buffer.BufferTracker;
 32 import hat.ifacemapper.BoundSchema;
 33 import hat.ifacemapper.SegmentMapper;
 34 import hat.optools.LambdaOpWrapper;
 35 import hat.optools.OpWrapper;
 36 
 37 import java.lang.invoke.MethodHandles;
 38 import java.lang.reflect.Method;
 39 
 40 import jdk.incubator.code.Op;
 41 import jdk.incubator.code.Quotable;
 42 import jdk.incubator.code.Quoted;
 43 import jdk.incubator.code.dialect.core.CoreOp;
 44 import jdk.incubator.code.dialect.java.JavaOp;
 45 
 46 import java.util.HashMap;
 47 import java.util.Map;
 48 import java.util.ServiceLoader;
 49 import java.util.function.Consumer;
 50 import java.util.function.Predicate;
 51 
 52 import static hat.backend.Backend.FIRST;
 53 
 54 /**
 55  * This class provides the developer facing view of HAT, and wraps a <a href="backend/Backend.html">Backend</a> capable of
 56  * executing <b>NDRange</b> style execution.
 57  * <p/>
 58  * An Accelerator is provided a <a href="java/lang/invoke/MethodHandles.Lookup.html">MethodHandles.Lookup</a> with visibility to the
 59  * compute to be performed.
 60  * <p/>
 61  * As well we either a <a href="backend/Backend.html">Backend</a> directly
 62  * <pre>
 63  * Accelerator accelerator =
 64  *    new Accelerator(MethodHandles.lookup(),
 65  *       new JavaMultiThreadedBackend());
 66  * </pre>
 67  * or a {@code java.util.function.Predicate<Backend>} which can be used to select the required {@code Backend}
 68  * loaded via Javas ServiceLoader mechanism
 69  * {@code}
 70  * <pre>
 71  * Accelerator accelerator =
 72  *    new Accelerator(MethodHandles.lookup(),
 73  *        be -> be.name().startsWith("OpenCL));
 74  * </pre>}
 75  *
 76  * @author Gary Frost
 77  */
 78 public class Accelerator implements BufferAllocator, BufferTracker {
 79     public MethodHandles.Lookup lookup;
 80     public final Backend backend;
 81 
 82     private final Map<Method, hat.ComputeContext> cache = new HashMap<>();
 83 
 84     public NDRange range(int max) {
 85         NDRange ndRange = new NDRange(this);
 86         ndRange.kid = new KernelContext(ndRange, max);
 87         return ndRange;
 88     }
 89 
 90     public NDRange range(int maxX, int maxY) {
 91         NDRange ndRange = new NDRange(this);
 92         ndRange.kid = new KernelContext(ndRange, maxX, maxY);
 93         return ndRange;
 94     }
 95 
 96     public NDRange range(int maxX, int maxY, int maxZ) {
 97         NDRange ndRange = new NDRange(this);
 98         ndRange.kid = new KernelContext(ndRange, maxX, maxY, maxZ);
 99         return ndRange;
100     }
101 
102     public NDRange range(ComputeRange computeRange) {
103         NDRange ndRange = new NDRange(this);
104         ndRange.kid = new KernelContext(ndRange, computeRange);
105         return ndRange;
106     }
107 
108     protected Accelerator(MethodHandles.Lookup lookup, ServiceLoader.Provider<Backend> provider) {
109         this(lookup, provider.get());
110     }
111     public Accelerator(MethodHandles.Lookup lookup) {
112         this(lookup, FIRST);
113     }
114 
115     /**
116      * @param lookup
117      * @param backend
118      */
119     public Accelerator(MethodHandles.Lookup lookup, Backend backend) {
120         this.lookup = lookup;
121         this.backend = backend;
122     }
123 
124     /**
125      * @param lookup
126      * @param backendPredicate
127      */
128     public Accelerator(MethodHandles.Lookup lookup, Predicate<Backend> backendPredicate) {
129         this(lookup, Backend.getBackend(backendPredicate));
130     }
131 
132     @Override
133     public <T extends Buffer> T allocate(SegmentMapper<T> segmentMapper, BoundSchema<T> boundShema) {
134         return backend.allocate(segmentMapper, boundShema);
135     }
136 
137     @Override
138     public void preMutate(Buffer b) {
139         if (backend instanceof BufferTracker) {
140             ((BufferTracker) backend).preMutate(b);
141         }
142     }
143 
144     @Override
145     public void postMutate(Buffer b) {
146         if (backend instanceof BufferTracker) {
147             ((BufferTracker) backend).postMutate(b);
148         }
149     }
150 
151     @Override
152     public void preAccess(Buffer b) {
153         if (backend instanceof BufferTracker) {
154             ((BufferTracker) backend).preAccess(b);
155         }
156     }
157 
158     @Override
159     public void postAccess(Buffer b) {
160         if (backend instanceof BufferTracker) {
161             ((BufferTracker) backend).postAccess(b);
162         }
163     }
164 /*
165     @Override
166     public void preEscape(Buffer b) {
167         if (backend instanceof BufferTracker) {
168             ((BufferTracker) backend).preEscape(b);
169         }
170     }
171 
172     @Override
173     public void postEscape(Buffer b) {
174         if (backend instanceof BufferTracker) {
175             ((BufferTracker) backend).postEscape(b);
176         }
177     } */
178 
179     /**
180      * An interface used for wrapping the compute entrypoint of work to be performed by the Accelerator.
181      * <p/>
182      * So given a ComputeClass such as...
183      * <pre>
184      *  public class MyComputeClass {
185      *    @ CodeReflection
186      *    public static void addDeltaKernel(KernelContext kc, S32Array arrayOfInt, int delta) {
187      *        arrayOfInt.array(kc.x, arrayOfInt.array(kc.x)+delta);
188      *    }
189      *
190      *    @ CodeReflection
191      *    static public void doSomeWork(final ComputeContext cc, S32Array arrayOfInt) {
192      *    }
193      *  }
194      *  </pre>
195      * The accelerator will be passed the doSomeWork entrypoint, wrapped in a {@code QuotableComputeContextConsumer}
196      * <pre>
197      *  accelerator.compute(cc ->
198      *     MyCompute.doSomeWork(cc, arrayOfInt)
199      *  );
200      *  </pre>
201      */
202     public interface QuotableComputeContextConsumer extends Quotable, Consumer<ComputeContext> {
203     }
204 
205     /**
206      * This method provides the Accelerator with the {@code Compute Entrypoint} from a Compute class.
207      * <p>
208      * The entrypoint is wrapped in a <a href="QuotableComputeContextConsumer.html">QuotableComputeContextConsumer</a> lambda.
209      *
210      * <pre>
211      * accelerator.compute(cc -&gt;
212      *     MyCompute.doSomeWork(cc, intArray)
213      * )
214      * </pre>
215      */
216     public void compute(QuotableComputeContextConsumer quotableComputeContextConsumer) {
217         Quoted quoted = Op.ofQuotable(quotableComputeContextConsumer).orElseThrow();
218         LambdaOpWrapper lambda = OpWrapper.wrap(lookup,(JavaOp.LambdaOp) quoted.op());
219         Method method = lambda.getQuotableTargetMethod();
220 
221         // Create (or get cached) a compute context which closes over compute entryppint and reachable kernels.
222         // The models of all compute and kernel methods are passed to the backend during creation
223         // The backend may well mutate the models.
224         // It will also use this opportunity to generate ISA specific code for the kernels.
225         ComputeContext computeContext = cache.computeIfAbsent(method, (_) ->
226                 new ComputeContext(this, method)
227         );
228         // Here we get the captured values  from the Quotable
229         Object[] args = lambda.getQuotableCapturedValues(quoted, method);
230         args[0] = computeContext;
231 
232         // now ask the backend to execute
233         backend.dispatchCompute(computeContext, args);
234     }
235 }