1 package hat.backend.jextracted;
  2 /*
  3  * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved.
  4  *
  5  * Redistribution and use in source and binary forms, with or without
  6  * modification, are permitted provided that the following conditions
  7  * are met:
  8  *
  9  *   - Redistributions of source code must retain the above copyright
 10  *     notice, this list of conditions and the following disclaimer.
 11  *
 12  *   - Redistributions in binary form must reproduce the above copyright
 13  *     notice, this list of conditions and the following disclaimer in the
 14  *     documentation and/or other materials provided with the distribution.
 15  *
 16  *   - Neither the name of Oracle nor the names of its
 17  *     contributors may be used to endorse or promote products derived
 18  *     from this software without specific prior written permission.
 19  *
 20  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
 21  * IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
 22  * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
 23  * PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
 24  * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
 25  * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
 26  * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
 27  * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
 28  * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
 29  * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 30  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 31  */
 32 
 33 import opencl.opencl_h;
 34 
 35 import java.io.IOException;
 36 import java.lang.foreign.Arena;
 37 import java.lang.foreign.MemorySegment;
 38 import java.lang.foreign.ValueLayout;
 39 import java.util.ArrayList;
 40 import java.util.List;
 41 
 42 //import static java.lang.foreign.ValueLayout.JAVA_INT;
 43 import static opencl.opencl_h.CL_DEVICE_TYPE_ALL;
 44 import static opencl.opencl_h.CL_MEM_READ_WRITE;
 45 import static opencl.opencl_h.CL_MEM_USE_HOST_PTR;
 46 import static opencl.opencl_h.CL_QUEUE_PROFILING_ENABLE;
 47 
 48 public class CLWrap {
 49     public static MemorySegment NULL = MemorySegment.NULL;
 50 
 51     // https://streamhpc.com/blog/2013-04-28/opencl-error-codes/
 52     static class Platform {
 53         static class Device {
 54             final Platform platform;
 55             final MemorySegment deviceId;
 56 
 57             int intDeviceInfo(int query) {
 58                 var value = 0;
 59                 if ((opencl_h.clGetDeviceInfo(deviceId, query, opencl_h.C_INT.byteSize(), platform.intValuePtr, NULL)) != opencl_h.CL_SUCCESS()) {
 60                     System.out.println("Failed to get query " + query);
 61                 } else {
 62                     value = platform.intValuePtr.get(opencl_h.C_INT, 0);
 63                 }
 64                 return value;
 65             }
 66 
 67             String strDeviceInfo(int query) {
 68                 String value = null;
 69                 if ((opencl_h.clGetDeviceInfo(deviceId, query, 2048, platform.byte2048ValuePtr, platform.intValuePtr)) != opencl_h.CL_SUCCESS()) {
 70                     System.out.println("Failed to get query " + query);
 71                 } else {
 72                     int len = platform.intValuePtr.get(opencl_h.C_INT, 0);
 73                     byte[] bytes = platform.byte2048ValuePtr.toArray(ValueLayout.JAVA_BYTE);
 74                     value = new String(bytes).substring(0, len - 1);
 75                 }
 76                 return value;
 77             }
 78 
 79             int computeUnits() {
 80                 return intDeviceInfo(opencl_h.CL_DEVICE_MAX_COMPUTE_UNITS());
 81             }
 82 
 83             String deviceName() {
 84                 return strDeviceInfo(opencl_h.CL_DEVICE_NAME());
 85             }
 86 
 87             String builtInKernels() {
 88                 return strDeviceInfo(opencl_h.CL_DEVICE_BUILT_IN_KERNELS());
 89             }
 90 
 91             Device(Platform platform, MemorySegment deviceId) {
 92                 this.platform = platform;
 93                 this.deviceId = deviceId;
 94             }
 95 
 96             public static class Context {
 97                 Device device;
 98                 MemorySegment context;
 99                 MemorySegment queue;
100 
101                 Context(Device device, MemorySegment context) {
102                     this.device = device;
103                     this.context = context;
104                     var statusPtr = device.platform.openCL.arena.allocateFrom(opencl_h.C_INT, 1);
105 
106                     var queue_props = CL_QUEUE_PROFILING_ENABLE();
107                     if ((this.queue = opencl_h.clCreateCommandQueue(context, device.deviceId, queue_props, statusPtr)) == NULL) {
108                         int status = statusPtr.get(opencl_h.C_INT, 0);
109                         opencl_h.clReleaseContext(context);
110                         // delete[] platforms;
111                         // delete[] device_ids;
112                         return;
113                     }
114 
115                 }
116 
117                 static public class Program {
118                     Context context;
119                     String source;
120                     MemorySegment program;
121                     String log;
122 
123                     Program(Context context, String source) {
124                         this.context = context;
125                         this.source = source;
126                         MemorySegment sourcePtr = context.device.platform.openCL.arena.allocateFrom(source);
127                         var sourcePtrPtr = context.device.platform.openCL.arena.allocateFrom(opencl_h.C_POINTER, sourcePtr);
128                     //    sourcePtrPtr.set(opencl_h.C_POINTER, 0, sourcePtr);
129                         var sourceLenPtr = context.device.platform.openCL.arena.allocateFrom(opencl_h.C_LONG,  source.length());
130                     //    sourceLenPtr.set(opencl_h.C_LONG, 0, source.length());
131                         var statusPtr = context.device.platform.openCL.arena.allocateFrom(opencl_h.C_INT, 0);
132                         if ((program = opencl_h.clCreateProgramWithSource(context.context, 1, sourcePtrPtr, sourceLenPtr, statusPtr)) == NULL) {
133                             int status = statusPtr.get(opencl_h.C_INT, 0);
134                             if (status != opencl_h.CL_SUCCESS()) {
135                                 System.out.println("failed to createProgram " + status);
136                             }
137                             System.out.println("failed to createProgram");
138                         } else {
139                             int status = statusPtr.get(opencl_h.C_INT, 0);
140                             if (status != opencl_h.CL_SUCCESS()) {
141                                 System.out.println("failed to create program " + status);
142                             }
143                             var deviceIdPtr = context.device.platform.openCL.arena.allocateFrom(opencl_h.C_POINTER, context.device.deviceId);
144                           //  deviceIdPtr.set(opencl_h.C_POINTER, 0, context.device.deviceId);
145                             if ((status = opencl_h.clBuildProgram(program, 1, deviceIdPtr, NULL, NULL, NULL)) != opencl_h.CL_SUCCESS()) {
146                                 System.out.println("failed to build" + status);
147                                 // dont return we may still be able to get log!
148                             }
149 
150                             var logLenPtr = context.device.platform.openCL.arena.allocate(opencl_h.C_LONG, 1);
151 
152                             if ((status = opencl_h.clGetProgramBuildInfo(program, context.device.deviceId, opencl_h.CL_PROGRAM_BUILD_LOG(), 0, NULL, logLenPtr)) != opencl_h.CL_SUCCESS()) {
153                                 System.out.println("failed to get log build " + status);
154                             } else {
155                                 long logLen = logLenPtr.get(opencl_h.C_LONG, 0);
156                                 var logPtr = context.device.platform.openCL.arena.allocate(opencl_h.C_CHAR, 1 + logLen);
157                                 if ((status = opencl_h.clGetProgramBuildInfo(program, context.device.deviceId, opencl_h.CL_PROGRAM_BUILD_LOG(), logLen, logPtr, logLenPtr)) != opencl_h.CL_SUCCESS()) {
158                                     System.out.println("clGetBuildInfo (getting log) failed");
159                                 } else {
160                                     byte[] bytes = logPtr.toArray(ValueLayout.JAVA_BYTE);
161                                     log = new String(bytes).substring(0, (int) logLen);
162                                 }
163                             }
164                         }
165                     }
166 
167                     public static class Kernel {
168                         Program program;
169                         MemorySegment kernel;
170                         String kernelName;
171 
172                         public Kernel(Program program, String kernelName) {
173                             this.program = program;
174                             this.kernelName = kernelName;
175                             var statusPtr = program.context.device.platform.openCL.arena.allocateFrom(opencl_h.C_INT, opencl_h.CL_SUCCESS());
176                             MemorySegment kernelNamePtr = program.context.device.platform.openCL.arena.allocateFrom(kernelName);
177                             kernel = opencl_h.clCreateKernel(program.program, kernelNamePtr, statusPtr);
178                             int status = statusPtr.get(opencl_h.C_INT, 0);
179                             if (status != opencl_h.CL_SUCCESS()) {
180                                 System.out.println("failed to create kernel " + status);
181                             }
182                         }
183 
184                         public void run(int range, Object... args) {
185                             var bufPtr = program.context.device.platform.openCL.arena.allocate(opencl_h.cl_mem, args.length);
186                             var statusPtr = program.context.device.platform.openCL.arena.allocateFrom(opencl_h.C_INT, opencl_h.CL_SUCCESS());
187                             int status;
188                             var eventMax = args.length * 4 + 1;
189                             int eventc = 0;
190                             var eventsPtr = program.context.device.platform.openCL.arena.allocate(opencl_h.cl_event, eventMax);
191                             boolean block = false;// true;
192                             for (int i = 0; i < args.length; i++) {
193                                 if (args[i] instanceof MemorySegment memorySegment) {
194                                     MemorySegment clMem = opencl_h.clCreateBuffer(program.context.context,
195                                             CL_MEM_USE_HOST_PTR() | CL_MEM_READ_WRITE(),
196                                             memorySegment.byteSize(),
197                                             memorySegment,
198                                             statusPtr);
199                                     status = statusPtr.get(opencl_h.C_INT, 0);
200                                     if (status != opencl_h.CL_SUCCESS()) {
201                                         System.out.println("failed to create memory buffer " + status);
202                                     }
203                                     bufPtr.set(opencl_h.cl_mem, i * opencl_h.cl_mem.byteSize(), clMem);
204                                     status = opencl_h.clEnqueueWriteBuffer(program.context.queue,
205                                             clMem,
206                                             block ? opencl_h.CL_TRUE() : opencl_h.CL_FALSE(), //block?
207                                             0,
208                                             memorySegment.byteSize(),
209                                             memorySegment,
210                                             block ? 0 : eventc,
211                                             block ? NULL : ((eventc == 0) ? NULL : eventsPtr),
212                                             block ? NULL : eventsPtr.asSlice(eventc * opencl_h.cl_event.byteSize(), opencl_h.cl_event)
213                                     );
214                                     if (status != opencl_h.CL_SUCCESS()) {
215                                         System.out.println("failed to enqueue write " + status);
216                                     }
217                                     if (!block) {
218                                         eventc++;
219                                     }
220                                     var clMemPtr = program.context.device.platform.openCL.arena.allocateFrom(opencl_h.C_POINTER, clMem);
221 
222                                     status = opencl_h.clSetKernelArg(kernel, i, opencl_h.C_POINTER.byteSize(), clMemPtr);
223                                     if (status != opencl_h.CL_SUCCESS()) {
224                                         System.out.println("failed to set arg " + status);
225                                     }
226                                 } else {
227                                     bufPtr.set(opencl_h.cl_mem, i * opencl_h.cl_mem.byteSize(), NULL);
228                                     switch (args[i]){
229                                         case Integer intArg->{
230                                             var intPtr = program.context.device.platform.openCL.arena.allocateFrom(opencl_h.C_INT, intArg);
231                                             status = opencl_h.clSetKernelArg(kernel, i, opencl_h.C_INT.byteSize(), intPtr);
232                                             if (status != opencl_h.CL_SUCCESS()) {
233                                                 System.out.println("failed to set arg " + status);
234                                             }
235                                         }
236                                         case Float floatArg->{
237                                             var floatPtr = program.context.device.platform.openCL.arena.allocateFrom(opencl_h.C_FLOAT, floatArg);
238                                             status = opencl_h.clSetKernelArg(kernel, i, opencl_h.C_FLOAT.byteSize(), floatPtr);
239                                             if (status != opencl_h.CL_SUCCESS()) {
240                                                 System.out.println("failed to set arg " + status);
241                                             }
242                                         }
243                                         default -> throw new IllegalStateException("Unexpected value: " + args[i]);
244                                     }
245                                 }
246                             }
247 
248                             // We need to store x,y,z sizes so this is a kind of int3
249                             var globalSizePtr = program.context.device.platform.openCL.arena.allocate(opencl_h.C_INT, 3);
250                             globalSizePtr.set(opencl_h.C_INT, 0, range);
251                             globalSizePtr.set(opencl_h.C_INT, 1*opencl_h.C_INT.byteSize(), 0);
252                             globalSizePtr.set(opencl_h.C_INT, 2*opencl_h.C_INT.byteSize(), 0);
253                             status = opencl_h.clEnqueueNDRangeKernel(
254                                     program.context.queue,
255                                     kernel,
256                                     1, // this must match the # of dims we are using in this case 1 of 3
257                                     NULL,
258                                     globalSizePtr,
259                                     NULL,
260                                     block ? 0 : eventc,
261                                     block ? NULL : ((eventc == 0) ? NULL : eventsPtr),
262                                     block ? NULL : eventsPtr.asSlice(eventc * opencl_h.cl_event.byteSize(), opencl_h.cl_event
263                                     )
264                             );
265                             if (status != opencl_h.CL_SUCCESS()) {
266                                 System.out.println("failed to enqueue NDRange " + status);
267                             }
268 
269                             if (block) {
270                                 opencl_h.clFlush(program.context.queue);
271                             } else {
272                                 eventc++;
273                                 status = opencl_h.clWaitForEvents(eventc, eventsPtr);
274                                 if (status != opencl_h.CL_SUCCESS()) {
275                                     System.out.println("failed to wait for ndrange events " + status);
276                                 }
277                             }
278 
279                             for (int i = 0; i < args.length; i++) {
280                                 if (args[i] instanceof MemorySegment memorySegment) {
281                                     MemorySegment clMem = bufPtr.get(opencl_h.cl_mem, (long) i * opencl_h.cl_mem.byteSize());
282                                     status = opencl_h.clEnqueueReadBuffer(program.context.queue,
283                                             clMem,
284                                             block ? opencl_h.CL_TRUE() : opencl_h.CL_FALSE(),
285                                             0,
286                                             memorySegment.byteSize(),
287                                             memorySegment,
288                                             block ? 0 : eventc,
289                                             block ? NULL : ((eventc == 0) ? NULL : eventsPtr),
290                                             block ? NULL : eventsPtr.asSlice(eventc * opencl_h.cl_event.byteSize(), opencl_h.cl_event)// block?NULL:readEventPtr
291                                     );
292                                     if (status != opencl_h.CL_SUCCESS()) {
293                                         System.out.println("failed to enqueue read " + status);
294                                     }
295                                     if (!block) {
296                                         eventc++;
297                                     }
298                                 }
299                             }
300                             if (!block) {
301                                 status = opencl_h.clWaitForEvents(eventc, eventsPtr);
302                                 if (status != opencl_h.CL_SUCCESS()) {
303                                     System.out.println("failed to wait for events " + status);
304                                 }
305                             }
306                             for (int i = 0; i < args.length; i++) {
307                                 if (args[i] instanceof MemorySegment memorySegment) {
308                                     MemorySegment clMem = bufPtr.get(opencl_h.cl_mem, (long) i * opencl_h.cl_mem.byteSize());
309                                     status = opencl_h.clReleaseMemObject(clMem);
310                                     if (status != opencl_h.CL_SUCCESS()) {
311                                         System.out.println("failed to release memObject " + status);
312                                     }
313                                 }
314                             }
315                         }
316                     }
317 
318                     public Kernel getKernel(String kernelName) {
319                         return new Kernel(this, kernelName);
320                     }
321                 }
322 
323                 public Program buildProgram(String source) {
324                     var program = new Program(this, source);
325                     return program;
326                 }
327             }
328 
329             public Context createContext() {
330 
331                 var statusPtr = platform.openCL.arena.allocateFrom(opencl_h.C_INT, 0);
332                 MemorySegment context;
333                 var deviceIds = platform.openCL.arena.allocateFrom(opencl_h.C_POINTER, this.deviceId);
334                 if ((context = opencl_h.clCreateContext(NULL, 1, deviceIds, NULL, NULL, statusPtr)) == NULL) {
335                     int status = statusPtr.get(opencl_h.C_INT, 0);
336                     System.out.println("Failed to get context  ");
337                     return null;
338                 } else {
339                     int status = statusPtr.get(opencl_h.C_INT, 0);
340                     if (status != opencl_h.CL_SUCCESS()) {
341                         System.out.println("failed to get context  " + status);
342                     }
343                     return new Context(this, context);
344                 }
345             }
346         }
347 
348         int intPlatformInfo(int query) {
349             var value = 0;
350             if ((opencl_h.clGetPlatformInfo(platformId, query, opencl_h.C_INT.byteSize(), intValuePtr, NULL)) != opencl_h.CL_SUCCESS()) {
351                 System.out.println("Failed to get query " + query);
352             } else {
353                 value = intValuePtr.get(opencl_h.C_INT, 0);
354             }
355             return value;
356         }
357 
358         String strPlatformInfo(int query) {
359             String value = null;
360             int status;
361             if ((status = opencl_h.clGetPlatformInfo(platformId, query, 2048, byte2048ValuePtr, intValuePtr)) != opencl_h.CL_SUCCESS()) {
362                 System.err.println("Failed to get query " + query);
363             } else {
364                 int len = intValuePtr.get(opencl_h.C_INT, 0);
365                 byte[] bytes = byte2048ValuePtr.toArray(ValueLayout.JAVA_BYTE);
366                 value = new String(bytes).substring(0, len - 1);
367             }
368             return value;
369         }
370 
371         CLWrap openCL;
372         MemorySegment platformId;
373         List<Device> devices = new ArrayList<>();
374         final MemorySegment intValuePtr;
375         final MemorySegment byte2048ValuePtr;
376 
377         String platformName() {
378             return strPlatformInfo(opencl_h.CL_PLATFORM_NAME());
379         }
380 
381         String vendorName() {
382             return strPlatformInfo(opencl_h.CL_PLATFORM_VENDOR());
383         }
384 
385         String version() {
386             return strPlatformInfo(opencl_h.CL_PLATFORM_VERSION());
387         }
388 
389         public Platform(CLWrap openCL, MemorySegment platformId) {
390             this.openCL = openCL;
391             this.platformId = platformId;
392             this.intValuePtr = openCL.arena.allocateFrom(opencl_h.C_INT, 0);
393             this.byte2048ValuePtr = openCL.arena.allocate(opencl_h.C_CHAR, 2048);
394             var devicecPtr = openCL.arena.allocateFrom(opencl_h.C_INT, 0);
395             int status;
396             if ((status = opencl_h.clGetDeviceIDs(platformId, CL_DEVICE_TYPE_ALL(), 0, NULL, devicecPtr)) != opencl_h.CL_SUCCESS()) {
397                 System.err.println("Failed getting devicec for platform 0 ");
398             } else {
399                 int devicec = devicecPtr.get(opencl_h.C_INT, 0);
400                 //  System.out.println("platform 0 has " + devicec + " device" + ((devicec > 1) ? "s" : ""));
401                 var deviceIdsPtr = openCL.arena.allocate(opencl_h.C_POINTER, devicec);
402                 if ((status = opencl_h.clGetDeviceIDs(platformId, CL_DEVICE_TYPE_ALL(), devicec, deviceIdsPtr, devicecPtr)) != opencl_h.CL_SUCCESS()) {
403                     System.err.println("Failed getting deviceids  for platform 0 ");
404                 } else {
405                     // System.out.println("We have "+devicec+" device ids");
406                     for (int i = 0; i < devicec; i++) {
407                         devices.add(new Device(this, deviceIdsPtr.get(opencl_h.C_POINTER, i * opencl_h.C_POINTER.byteSize())));
408                     }
409                 }
410             }
411         }
412     }
413 
414     List<Platform> platforms = new ArrayList<>();
415 
416     Arena arena;
417 
418     CLWrap(Arena arena) {
419         this.arena = arena;
420         var platformcPtr = arena.allocateFrom(opencl_h.C_INT, 0);
421 
422         if ((opencl_h.clGetPlatformIDs(0, NULL, platformcPtr)) != opencl_h.CL_SUCCESS()) {
423             System.out.println("Failed to get opencl platforms");
424         } else {
425             int platformc = platformcPtr.get(opencl_h.C_INT, 0);
426             // System.out.println("There are "+platformc+" platforms");
427             var platformIdsPtr = arena.allocate(opencl_h.C_POINTER, platformc);
428             if ((opencl_h.clGetPlatformIDs(platformc, platformIdsPtr, platformcPtr)) != opencl_h.CL_SUCCESS()) {
429                 System.out.println("Failed getting platform ids");
430             } else {
431                 for (int i = 0; i < platformc; i++) {
432                     // System.out.println("We should have the ids");
433                     platforms.add(new Platform(this, platformIdsPtr.get(opencl_h.C_POINTER, i)));
434                 }
435             }
436         }
437     }
438 
439 
440     public static void main(String[] args) throws IOException {
441         try (var arena = Arena.ofConfined()) {
442             CLWrap openCL = new CLWrap(arena);
443 
444             Platform.Device[] selectedDevice = new Platform.Device[1];
445             openCL.platforms.forEach(platform -> {
446                 System.out.println("Platform Name " + platform.platformName());
447                 platform.devices.forEach(device -> {
448                     System.out.println("   Compute Units     " + device.computeUnits());
449                     System.out.println("   Device Name       " + device.deviceName());
450                     System.out.println("   Built In Kernels  " + device.builtInKernels());
451                     selectedDevice[0] = device;
452                 });
453             });
454             var context = selectedDevice[0].createContext();
455             var program = context.buildProgram("""
456                     __kernel void squares(__global int* in,__global int* out ){
457                         int gid = get_global_id(0);
458                         out[gid] = in[gid]*in[gid];
459                     }
460                     """);
461             var kernel = program.getKernel("squares");
462             var in = arena.allocate(opencl_h.C_INT, 512);
463             var out = arena.allocate(opencl_h.C_INT, 512);
464             for (int i = 0; i < 512; i++) {
465                 in.set(opencl_h.C_INT, (int) i * opencl_h.C_INT.byteSize(), i);
466             }
467             kernel.run(512, in, out);
468             for (int i = 0; i < 512; i++) {
469                 System.out.println(i + " " + out.get(opencl_h.C_INT, (int) i * opencl_h.C_INT.byteSize()));
470             }
471         }
472     }
473 }