1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package wrap.opencl;
 26 
 27 
 28 import opencl.opencl_h;
 29 import optkl.ifacemapper.Buffer;
 30 import optkl.ifacemapper.BufferState;
 31 import optkl.ifacemapper.MappableIface;
 32 import wrap.ArenaHolder;
 33 import wrap.Wrap;
 34 
 35 import java.lang.foreign.Arena;
 36 import java.lang.foreign.MemorySegment;
 37 import java.util.ArrayList;
 38 import java.util.List;
 39 
 40 import static java.lang.foreign.MemorySegment.NULL;
 41 import static opencl.opencl_h.CL_DEVICE_BUILT_IN_KERNELS;
 42 import static opencl.opencl_h.CL_DEVICE_MAX_COMPUTE_UNITS;
 43 import static opencl.opencl_h.CL_DEVICE_NAME;
 44 import static opencl.opencl_h.CL_DEVICE_TYPE_ALL;
 45 import static opencl.opencl_h.CL_DEVICE_VENDOR;
 46 import static opencl.opencl_h.CL_MEM_READ_WRITE;
 47 import static opencl.opencl_h.CL_MEM_USE_HOST_PTR;
 48 import static opencl.opencl_h.CL_PROGRAM_BUILD_LOG;
 49 import static opencl.opencl_h.CL_QUEUE_PROFILING_ENABLE;
 50 import static opencl.opencl_h.CL_SUCCESS;
 51 
 52 // https://streamhpc.com/blog/2013-04-28/opencl-error-codes/
 53 public class CLPlatform implements ArenaHolder {
 54     public static List<CLPlatform> platforms(Arena arena) {
 55         var arenaWrapper = ArenaHolder.wrap(arena);
 56         List<CLPlatform> platforms = new ArrayList<>();
 57         var platformc = arenaWrapper.intPtr(0);
 58         if ((opencl_h.clGetPlatformIDs(0, NULL, platformc.ptr())) != CL_SUCCESS()) {
 59             System.out.println("Failed to get opencl platforms");
 60         } else {
 61             var platformIds = arenaWrapper.ptrArr(platformc.get());
 62             if ((opencl_h.clGetPlatformIDs(platformc.get(), platformIds.ptr(), NULL)) != CL_SUCCESS()) {
 63                 System.out.println("Failed getting platform ids");
 64             } else {
 65                 for (int i = 0; i < platformc.get(); i++) {
 66                     platforms.add(new CLPlatform(arena, platformIds.get(i)));
 67                 }
 68             }
 69         }
 70         return platforms;
 71     }
 72 
 73     public static class CLDevice implements ArenaHolder {
 74         final CLPlatform platform;
 75         final MemorySegment deviceId;
 76 
 77         @Override
 78         public Arena arena() {
 79             return platform.arena();
 80         }
 81 
 82         int intDeviceInfo(int query) {
 83             var value = intPtr(0);
 84             if ((opencl_h.clGetDeviceInfo(deviceId, query, value.sizeof(), value.ptr(), NULL)) != CL_SUCCESS()) {
 85                 throw new RuntimeException("Failed to get query " + query);
 86             }
 87             return value.get();
 88         }
 89 
 90         String strDeviceInfo(int query) {
 91             var value = cstr(2048);
 92             if ((opencl_h.clGetDeviceInfo(deviceId, query, value.len(), value.ptr(), NULL)) != CL_SUCCESS()) {
 93                 throw new RuntimeException("Failed to get query " + query);
 94             }
 95             return value.get();
 96         }
 97 
 98         public int computeUnits() {
 99             return intDeviceInfo(CL_DEVICE_MAX_COMPUTE_UNITS());
100         }
101 
102         public String deviceName() {
103             return strDeviceInfo(CL_DEVICE_NAME());
104         }
105 
106         public String deviceVendor() {
107             return strDeviceInfo(CL_DEVICE_VENDOR());
108         }
109 
110         public String builtInKernels() {
111             return strDeviceInfo(CL_DEVICE_BUILT_IN_KERNELS());
112         }
113 
114         CLDevice(CLPlatform platform, MemorySegment deviceId) {
115             this.platform = platform;
116             this.deviceId = deviceId;
117         }
118 
119         public static class CLContext implements ArenaHolder {
120             CLDevice device;
121             MemorySegment context;
122             MemorySegment queue;
123 
124             @Override
125             public Arena arena() {
126                 return device.arena();
127             }
128 
129             CLContext(CLDevice device, MemorySegment context) {
130                 this.device = device;
131                 this.context = context;
132                 var status = device.platform.status;
133 
134                 var queue_props = CL_QUEUE_PROFILING_ENABLE();
135                 if ((this.queue = opencl_h.clCreateCommandQueue(context, device.deviceId, queue_props, status.ptr())) == NULL) {
136                     opencl_h.clReleaseContext(context);
137                 } else {
138                     if (!status.isOK()) {
139                         opencl_h.clReleaseContext(context);
140                     }
141                 }
142             }
143 
144             static public class CLProgram implements ArenaHolder {
145                 CLContext context;
146                 String source;
147                 MemorySegment program;
148                 String log;
149 
150                 @Override
151                 public Arena arena() {
152                     return context.arena();
153                 }
154 
155                 CLProgram(CLContext context, String source) {
156                     this.context = context;
157                     this.source = source;
158                     var sourceList = ptrArr(source);
159                     var status = context.device.platform.status;
160                     if ((program = opencl_h.clCreateProgramWithSource(context.context, 1, sourceList.ptr(), NULL, status.ptr())) == NULL) {
161                         if (!status.isOK()) {
162                             throw new RuntimeException("failed to createProgram " + status.get());
163                         }
164                         throw new RuntimeException("failed to createProgram");
165                     } else {
166                         var deviceIdList = ptrArr(context.device.deviceId);
167                         if ((status.set(opencl_h.clBuildProgram(program, 1, deviceIdList.ptr(), NULL, NULL, NULL))) != CL_SUCCESS()) {
168                             System.err.println("failed to build program " + status);
169                             System.err.println("Source "+source);
170                            // System.exit(1);
171                         }
172                         var logLen = longPtr(1L);
173                         if ((status.set(opencl_h.clGetProgramBuildInfo(program, context.device.deviceId, CL_PROGRAM_BUILD_LOG(), 0, NULL, logLen.ptr()))) != CL_SUCCESS()) {
174                             System.err.println("failed to get log build " + status.get());
175                             System.exit(1);
176                         } else {
177                             var logPtr = cstr(1 + logLen.get());
178                             if ((status.set(opencl_h.clGetProgramBuildInfo(program, context.device.deviceId, opencl_h.CL_PROGRAM_BUILD_LOG(), logLen.get(), logPtr.ptr(), NULL))) != opencl_h.CL_SUCCESS()) {
179                                 System.out.println("clGetBuildInfo (getting log) failed ");
180                             } else {
181                                 log = logPtr.get();
182                                 System.out.println("log\n" +log);
183                             }
184                         }
185                     }
186                 }
187 
188                 public static class CLKernel implements ArenaHolder {
189                     CLProgram program;
190                     MemorySegment kernel;
191                     String kernelName;
192 
193                     @Override
194                     public Arena arena() {
195                         return program.arena();
196                     }
197 
198                     public CLKernel(CLProgram program, String kernelName) {
199                         this.program = program;
200                         this.kernelName = kernelName;
201                         var kernelNameCStr = this.cstr(kernelName);
202                         var status = program.context.device.platform.status;
203                         kernel = opencl_h.clCreateKernel(program.program, kernelNameCStr.ptr(), status.ptr());
204                         if (!status.isOK()) {
205                             System.out.println("failed to create kernel '"+kernelName+"'" + status);
206                         }
207                     }
208 
209 
210                     public void run(CLWrapComputeContext clWrapComputeContext, int range, Object... args) {
211                         var status = CLStatusPtr.of(arena());
212                         for (int i = 0; i < args.length; i++) {
213                             if (args[i] instanceof CLWrapComputeContext.MemorySegmentState memorySegmentState) {
214                                 if (memorySegmentState.clMemPtr == null) {
215                                     memorySegmentState.clMemPtr = CLWrapComputeContext.ClMemPtr.of(arena(), opencl_h.clCreateBuffer(program.context.context,
216                                             CL_MEM_USE_HOST_PTR() | CL_MEM_READ_WRITE(),
217                                             memorySegmentState.memorySegment.byteSize(),
218                                             memorySegmentState.memorySegment,
219                                             status.ptr()));
220                                     if (!status.isOK()) {
221                                         throw new RuntimeException("failed to create memory buffer for arg["+i+" " + status.get());
222                                     }
223                                 }
224                                 if (memorySegmentState.copyToDevice) {
225                                     status.set(opencl_h.clEnqueueWriteBuffer(program.context.queue,
226                                             memorySegmentState.clMemPtr.get(),
227                                             clWrapComputeContext.blockInt(),
228                                             0,
229                                             memorySegmentState.memorySegment.byteSize(),
230                                             memorySegmentState.memorySegment,
231                                             clWrapComputeContext.eventc(),
232                                             clWrapComputeContext.eventsPtr(),
233                                             clWrapComputeContext.nextEventPtrSlot()
234                                     ));
235                                     if (!status.isOK()) {
236                                         System.err.println("failed to enqueue write for arg["+i+" " + status);
237                                         System.exit(1);
238                                     }
239                                 }
240 
241                                 status.set(opencl_h.clSetKernelArg(kernel, i, memorySegmentState.clMemPtr.sizeof(), memorySegmentState.clMemPtr.ptr()));
242                                 if (!status.isOK()) {
243                                     System.err.println("failed to set arg["+i+" " + status);
244                                     System.exit(1);
245                                 }
246                             } else if (args[i] instanceof Buffer buffer) {
247                                 //  System.out.println("Arg "+i+" is a buffer so checking if we need to write");
248                                 BufferState bufferState = BufferState.of(buffer);
249 
250                                 //System.out.println("Before possible write"+ bufferState);
251                                 MemorySegment memorySegment = MappableIface.getMemorySegment(buffer);
252 
253                                 CLWrapComputeContext.ClMemPtr clmem = clWrapComputeContext.clMemMap.computeIfAbsent(memorySegment, k ->
254                                         CLWrapComputeContext.ClMemPtr.of(arena(), opencl_h.clCreateBuffer(program.context.context,
255                                                 CL_MEM_USE_HOST_PTR() | CL_MEM_READ_WRITE(),
256                                                 memorySegment.byteSize(),
257                                                 memorySegment,
258                                                 status.ptr()))
259                                 );
260                                 if (bufferState.getState()==BufferState.HOST_OWNED) {
261 
262                                     //System.out.println("arg " + args[i] + " isHostDirty copying in");
263                                     status.set(opencl_h.clEnqueueWriteBuffer(program.context.queue,
264                                             clmem.get(),
265                                             clWrapComputeContext.blockInt(),
266                                             0,
267                                             memorySegment.byteSize(),
268                                             memorySegment,
269                                             clWrapComputeContext.eventc(),
270                                             clWrapComputeContext.eventsPtr(),
271                                             clWrapComputeContext.nextEventPtrSlot()
272                                     ));
273                                     if (!status.isOK()) {
274                                         System.err.println("failed to enqueue write for arg["+i+" " + status);
275                                         System.exit(1);
276                                     }
277                                 } else {
278 
279                                     //  System.out.println("arg "+args[i]+" is not HostDirty not copying in");
280                                 }
281                                 //     System.out.println("After possible write "+ bufferState);
282                                 status.set(opencl_h.clSetKernelArg(kernel, i, clmem.sizeof(), clmem.ptr()));
283                                 if (!status.isOK()) {
284                                     System.err.println("failed to set arg["+i+"]" + status);
285                                     System.exit(1);
286                                 }
287 
288                             } else {
289                                 Wrap.Ptr ptr = switch (args[i]) {
290                                     case Integer intArg -> intPtr(intArg);
291                                     case Float floatArg -> floatPtr(floatArg);
292                                     case Double doubleArg -> doublePtr(doubleArg);
293                                     case Long longArg -> longPtr(longArg);
294                                     case Short shortArg -> shortPtr(shortArg);
295                                     default -> throw new IllegalStateException("Unexpected value: " + args[i]);
296                                 };
297                                 status.set(opencl_h.clSetKernelArg(kernel, i, ptr.sizeof(), ptr.ptr()));
298                                 if (!status.isOK()) {
299                                     System.err.println("failed to set arg["+i+"] " + status);
300 
301                                     System.exit(1);
302                                 }
303 
304                             }
305                         }
306 
307                         // We need to store x,y,z sizes so this is a kind of int3
308                         var globalSize = this.ofInts(range, 0, 0);
309                         status.set(opencl_h.clEnqueueNDRangeKernel(
310                                         program.context.queue,
311                                         kernel,
312                                         1, // this must match the # of dims we are using in this case 1 of 3
313                                         NULL,
314                                         globalSize.ptr(),
315                                         NULL,
316                                         clWrapComputeContext.eventc(),
317                                         clWrapComputeContext.eventsPtr(),
318                                         clWrapComputeContext.nextEventPtrSlot()
319                                 )
320                         );
321                         if (!status.isOK()) {
322                             System.out.println("failed to enqueue NDRange " + status);
323                         }
324 
325                         if (clWrapComputeContext.alwaysBlock) {
326                             opencl_h.clFlush(program.context.queue);
327                         }
328 
329                         for (int i = 0; i < args.length; i++) {
330                             if (args[i] instanceof CLWrapComputeContext.MemorySegmentState memorySegmentState) {
331                                 if (memorySegmentState.copyFromDevice) {
332                                     status.set(opencl_h.clEnqueueReadBuffer(program.context.queue,
333                                             memorySegmentState.clMemPtr.get(),
334                                             clWrapComputeContext.blockInt(),
335                                             0,
336                                             memorySegmentState.memorySegment.byteSize(),
337                                             memorySegmentState.memorySegment,
338                                             clWrapComputeContext.eventc(),
339                                             clWrapComputeContext.eventsPtr(),
340                                             clWrapComputeContext.nextEventPtrSlot()
341                                     ));
342                                     if (!status.isOK()) {
343                                         System.out.println("failed to enqueue read " + status);
344                                     }
345                                 }
346                             } else if (args[i] instanceof Buffer buffer) {
347                                 //   System.out.println("Arg "+i+" is a buffer so checking if we need to read");
348                                 BufferState bufferState = BufferState.of(buffer);
349                                 MemorySegment memorySegment = MappableIface.getMemorySegment(buffer);
350                                 CLWrapComputeContext.ClMemPtr clmem = clWrapComputeContext.clMemMap.get(memorySegment);
351                                 // System.out.println("Before possible read "+ bufferState);
352                                 if (bufferState.getState() == BufferState.HOST_OWNED) {
353                                   //  System.out.println("arg " + args[i] + " isDeviceDirty copying out");
354                                     status.set(opencl_h.clEnqueueReadBuffer(program.context.queue,
355                                             clmem.get(),
356                                             clWrapComputeContext.blockInt(),
357                                             0,
358                                             memorySegment.byteSize(),
359                                             memorySegment,
360                                             clWrapComputeContext.eventc(),
361                                             clWrapComputeContext.eventsPtr(),
362                                             clWrapComputeContext.nextEventPtrSlot()
363                                     ));
364                                     if (!status.isOK()) {
365                                         System.out.println("failed to enqueue read " + status);
366                                     }
367                                 } else {
368                                     //   System.out.println("arg "+args[i]+" isnot DeviceDirty not copying out");
369                                 }
370 
371                             }
372                         }
373                         // if (!computeContext.alwaysBlock) {
374                         clWrapComputeContext.waitForEvents();
375                         //  }
376                     }
377                 }
378 
379                 public CLKernel getKernel(String kernelName) {
380                     return new CLKernel(this, kernelName);
381                 }
382             }
383 
384             public CLProgram buildProgram(String source) {
385                 var program = new CLProgram(this, source);
386                 return program;
387             }
388         }
389 
390         public CLContext createContext() {
391             var status = platform.status;
392             MemorySegment context;
393             var deviceIds = ptrArr(this.deviceId);
394             if ((context = opencl_h.clCreateContext(NULL, 1, deviceIds.ptr(), NULL, NULL, status.ptr())) == NULL) {
395                 System.out.println("Failed to get context  ");
396                 return null;
397             } else {
398                 if (!status.isOK()) {
399                     System.out.println("failed to get context  " + status);
400                 }
401                 return new CLContext(this, context);
402             }
403         }
404     }
405 
406     int intPlatformInfo(int query) {
407         var value = intPtr(0);
408         if ((opencl_h.clGetPlatformInfo(platformId, query, value.sizeof(), value.ptr(), NULL)) != opencl_h.CL_SUCCESS()) {
409             throw new RuntimeException("Failed to get query " + query);
410         }
411         return value.get();
412     }
413 
414     String strPlatformInfo(int query) {
415 
416         var value = cstr(2048);
417         int status;
418         if ((status = opencl_h.clGetPlatformInfo(platformId, query, value.len(), value.ptr(), NULL)) != opencl_h.CL_SUCCESS()) {
419             throw new RuntimeException("Failed to get query " + query);
420         }
421         return value.get();
422     }
423 
424     private Arena secretarena;
425     MemorySegment platformId;
426     public List<CLDevice> devices = new ArrayList<>();
427     final CLStatusPtr status;
428 
429     public String platformName() {
430         return strPlatformInfo(opencl_h.CL_PLATFORM_NAME());
431     }
432 
433     String vendorName() {
434         return strPlatformInfo(opencl_h.CL_PLATFORM_VENDOR());
435     }
436 
437     String version() {
438         return strPlatformInfo(opencl_h.CL_PLATFORM_VERSION());
439     }
440 
441     @Override
442     public Arena arena() {
443         return secretarena;
444     }
445 
446     public CLPlatform(Arena arena, MemorySegment platformId) {
447         this.secretarena = arena;
448         this.platformId = platformId;
449         this.status = CLStatusPtr.of(arena());
450         var devicec = intPtr(0);
451         if ((status.set(opencl_h.clGetDeviceIDs(platformId, CL_DEVICE_TYPE_ALL(), 0, NULL, devicec.ptr()))) != opencl_h.CL_SUCCESS()) {
452             System.err.println("Failed getting devicec for platform 0 ");
453         } else {
454             //  System.out.println("platform 0 has " + devicec + " device" + ((devicec > 1) ? "s" : ""));
455             var deviceIdList = ptrArr(devicec.get());
456             if ((status.set(opencl_h.clGetDeviceIDs(platformId, CL_DEVICE_TYPE_ALL(), devicec.get(), deviceIdList.ptr(), devicec.ptr()))) != opencl_h.CL_SUCCESS()) {
457                 System.err.println("Failed getting deviceids  for platform 0 ");
458             } else {
459                 // System.out.println("We have "+devicec+" device ids");
460                 for (int i = 0; i < devicec.get(); i++) {
461                     devices.add(new CLDevice(this, deviceIdList.get(i)));
462                 }
463             }
464         }
465     }
466 }