1 package hat.backend.jextracted;
2 /*
3 * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions
7 * are met:
8 *
9 * - Redistributions of source code must retain the above copyright
10 * notice, this list of conditions and the following disclaimer.
11 *
12 * - Redistributions in binary form must reproduce the above copyright
13 * notice, this list of conditions and the following disclaimer in the
14 * documentation and/or other materials provided with the distribution.
15 *
16 * - Neither the name of Oracle nor the names of its
17 * contributors may be used to endorse or promote products derived
18 * from this software without specific prior written permission.
19 *
20 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
21 * IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
22 * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
23 * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
24 * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
25 * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
26 * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
27 * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
28 * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
29 * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
30 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31 */
32
33 import opencl.opencl_h;
34
35 import java.io.IOException;
36 import java.lang.foreign.Arena;
37 import java.lang.foreign.MemorySegment;
38 import java.lang.foreign.ValueLayout;
39 import java.util.ArrayList;
40 import java.util.List;
41
42 //import static java.lang.foreign.ValueLayout.JAVA_INT;
43 import static opencl.opencl_h.CL_DEVICE_TYPE_ALL;
44 import static opencl.opencl_h.CL_MEM_READ_WRITE;
45 import static opencl.opencl_h.CL_MEM_USE_HOST_PTR;
46 import static opencl.opencl_h.CL_QUEUE_PROFILING_ENABLE;
47
48 public class CLWrap {
49 public static MemorySegment NULL = MemorySegment.NULL;
50
51 // https://streamhpc.com/blog/2013-04-28/opencl-error-codes/
52 static class Platform {
53 static class Device {
54 final Platform platform;
55 final MemorySegment deviceId;
56
57 int intDeviceInfo(int query) {
58 var value = 0;
59 if ((opencl_h.clGetDeviceInfo(deviceId, query, opencl_h.C_INT.byteSize(), platform.intValuePtr, NULL)) != opencl_h.CL_SUCCESS()) {
60 System.out.println("Failed to get query " + query);
61 } else {
62 value = platform.intValuePtr.get(opencl_h.C_INT, 0);
63 }
64 return value;
65 }
66
67 String strDeviceInfo(int query) {
68 String value = null;
69 if ((opencl_h.clGetDeviceInfo(deviceId, query, 2048, platform.byte2048ValuePtr, platform.intValuePtr)) != opencl_h.CL_SUCCESS()) {
70 System.out.println("Failed to get query " + query);
71 } else {
72 int len = platform.intValuePtr.get(opencl_h.C_INT, 0);
73 byte[] bytes = platform.byte2048ValuePtr.toArray(ValueLayout.JAVA_BYTE);
74 value = new String(bytes).substring(0, len - 1);
75 }
76 return value;
77 }
78
79 int computeUnits() {
80 return intDeviceInfo(opencl_h.CL_DEVICE_MAX_COMPUTE_UNITS());
81 }
82
83 String deviceName() {
84 return strDeviceInfo(opencl_h.CL_DEVICE_NAME());
85 }
86
87 String builtInKernels() {
88 return strDeviceInfo(opencl_h.CL_DEVICE_BUILT_IN_KERNELS());
89 }
90
/**
 * Binds a device handle to the platform it was enumerated from.
 *
 * @param platform the owning platform; its scratch segments and arena are used by the info queries
 * @param deviceId the native device handle
 */
Device(Platform platform, MemorySegment deviceId) {
    this.deviceId = deviceId;
    this.platform = platform;
}
95
96 public static class Context {
97 Device device;
98 MemorySegment context;
99 MemorySegment queue;
100
101 Context(Device device, MemorySegment context) {
102 this.device = device;
103 this.context = context;
104 var statusPtr = device.platform.openCL.arena.allocateFrom(opencl_h.C_INT, 1);
105
106 var queue_props = CL_QUEUE_PROFILING_ENABLE();
107 if ((this.queue = opencl_h.clCreateCommandQueue(context, device.deviceId, queue_props, statusPtr)) == NULL) {
108 int status = statusPtr.get(opencl_h.C_INT, 0);
109 opencl_h.clReleaseContext(context);
110 // delete[] platforms;
111 // delete[] device_ids;
112 return;
113 }
114
115 }
116
117 static public class Program {
118 Context context;
119 String source;
120 MemorySegment program;
121 String log;
122
123 Program(Context context, String source) {
124 this.context = context;
125 this.source = source;
126 MemorySegment sourcePtr = context.device.platform.openCL.arena.allocateFrom(source);
127 var sourcePtrPtr = context.device.platform.openCL.arena.allocateFrom(opencl_h.C_POINTER, sourcePtr);
128 // sourcePtrPtr.set(opencl_h.C_POINTER, 0, sourcePtr);
129 var sourceLenPtr = context.device.platform.openCL.arena.allocateFrom(opencl_h.C_LONG, source.length());
130 // sourceLenPtr.set(opencl_h.C_LONG, 0, source.length());
131 var statusPtr = context.device.platform.openCL.arena.allocateFrom(opencl_h.C_INT, 0);
132 if ((program = opencl_h.clCreateProgramWithSource(context.context, 1, sourcePtrPtr, sourceLenPtr, statusPtr)) == NULL) {
133 int status = statusPtr.get(opencl_h.C_INT, 0);
134 if (status != opencl_h.CL_SUCCESS()) {
135 System.out.println("failed to createProgram " + status);
136 }
137 System.out.println("failed to createProgram");
138 } else {
139 int status = statusPtr.get(opencl_h.C_INT, 0);
140 if (status != opencl_h.CL_SUCCESS()) {
141 System.out.println("failed to create program " + status);
142 }
143 var deviceIdPtr = context.device.platform.openCL.arena.allocateFrom(opencl_h.C_POINTER, context.device.deviceId);
144 // deviceIdPtr.set(opencl_h.C_POINTER, 0, context.device.deviceId);
145 if ((status = opencl_h.clBuildProgram(program, 1, deviceIdPtr, NULL, NULL, NULL)) != opencl_h.CL_SUCCESS()) {
146 System.out.println("failed to build" + status);
147 // dont return we may still be able to get log!
148 }
149
150 var logLenPtr = context.device.platform.openCL.arena.allocate(opencl_h.C_LONG, 1);
151
152 if ((status = opencl_h.clGetProgramBuildInfo(program, context.device.deviceId, opencl_h.CL_PROGRAM_BUILD_LOG(), 0, NULL, logLenPtr)) != opencl_h.CL_SUCCESS()) {
153 System.out.println("failed to get log build " + status);
154 } else {
155 long logLen = logLenPtr.get(opencl_h.C_LONG, 0);
156 var logPtr = context.device.platform.openCL.arena.allocate(opencl_h.C_CHAR, 1 + logLen);
157 if ((status = opencl_h.clGetProgramBuildInfo(program, context.device.deviceId, opencl_h.CL_PROGRAM_BUILD_LOG(), logLen, logPtr, logLenPtr)) != opencl_h.CL_SUCCESS()) {
158 System.out.println("clGetBuildInfo (getting log) failed");
159 } else {
160 byte[] bytes = logPtr.toArray(ValueLayout.JAVA_BYTE);
161 log = new String(bytes).substring(0, (int) logLen);
162 }
163 }
164 }
165 }
166
167 public static class Kernel {
168 Program program;
169 MemorySegment kernel;
170 String kernelName;
171
172 public Kernel(Program program, String kernelName) {
173 this.program = program;
174 this.kernelName = kernelName;
175 var statusPtr = program.context.device.platform.openCL.arena.allocateFrom(opencl_h.C_INT, opencl_h.CL_SUCCESS());
176 MemorySegment kernelNamePtr = program.context.device.platform.openCL.arena.allocateFrom(kernelName);
177 kernel = opencl_h.clCreateKernel(program.program, kernelNamePtr, statusPtr);
178 int status = statusPtr.get(opencl_h.C_INT, 0);
179 if (status != opencl_h.CL_SUCCESS()) {
180 System.out.println("failed to create kernel " + status);
181 }
182 }
183
184 public void run(int range, Object... args) {
185 var bufPtr = program.context.device.platform.openCL.arena.allocate(opencl_h.cl_mem, args.length);
186 var statusPtr = program.context.device.platform.openCL.arena.allocateFrom(opencl_h.C_INT, opencl_h.CL_SUCCESS());
187 int status;
188 var eventMax = args.length * 4 + 1;
189 int eventc = 0;
190 var eventsPtr = program.context.device.platform.openCL.arena.allocate(opencl_h.cl_event, eventMax);
191 boolean block = false;// true;
192 for (int i = 0; i < args.length; i++) {
193 if (args[i] instanceof MemorySegment memorySegment) {
194 MemorySegment clMem = opencl_h.clCreateBuffer(program.context.context,
195 CL_MEM_USE_HOST_PTR() | CL_MEM_READ_WRITE(),
196 memorySegment.byteSize(),
197 memorySegment,
198 statusPtr);
199 status = statusPtr.get(opencl_h.C_INT, 0);
200 if (status != opencl_h.CL_SUCCESS()) {
201 System.out.println("failed to create memory buffer " + status);
202 }
203 bufPtr.set(opencl_h.cl_mem, i * opencl_h.cl_mem.byteSize(), clMem);
204 status = opencl_h.clEnqueueWriteBuffer(program.context.queue,
205 clMem,
206 block ? opencl_h.CL_TRUE() : opencl_h.CL_FALSE(), //block?
207 0,
208 memorySegment.byteSize(),
209 memorySegment,
210 block ? 0 : eventc,
211 block ? NULL : ((eventc == 0) ? NULL : eventsPtr),
212 block ? NULL : eventsPtr.asSlice(eventc * opencl_h.cl_event.byteSize(), opencl_h.cl_event)
213 );
214 if (status != opencl_h.CL_SUCCESS()) {
215 System.out.println("failed to enqueue write " + status);
216 }
217 if (!block) {
218 eventc++;
219 }
220 var clMemPtr = program.context.device.platform.openCL.arena.allocateFrom(opencl_h.C_POINTER, clMem);
221
222 status = opencl_h.clSetKernelArg(kernel, i, opencl_h.C_POINTER.byteSize(), clMemPtr);
223 if (status != opencl_h.CL_SUCCESS()) {
224 System.out.println("failed to set arg " + status);
225 }
226 } else {
227 bufPtr.set(opencl_h.cl_mem, i * opencl_h.cl_mem.byteSize(), NULL);
228 switch (args[i]){
229 case Integer intArg->{
230 var intPtr = program.context.device.platform.openCL.arena.allocateFrom(opencl_h.C_INT, intArg);
231 status = opencl_h.clSetKernelArg(kernel, i, opencl_h.C_INT.byteSize(), intPtr);
232 if (status != opencl_h.CL_SUCCESS()) {
233 System.out.println("failed to set arg " + status);
234 }
235 }
236 case Float floatArg->{
237 var floatPtr = program.context.device.platform.openCL.arena.allocateFrom(opencl_h.C_FLOAT, floatArg);
238 status = opencl_h.clSetKernelArg(kernel, i, opencl_h.C_FLOAT.byteSize(), floatPtr);
239 if (status != opencl_h.CL_SUCCESS()) {
240 System.out.println("failed to set arg " + status);
241 }
242 }
243 default -> throw new IllegalStateException("Unexpected value: " + args[i]);
244 }
245 }
246 }
247
248 // We need to store x,y,z sizes so this is a kind of int3
249 var globalSizePtr = program.context.device.platform.openCL.arena.allocate(opencl_h.C_INT, 3);
250 globalSizePtr.set(opencl_h.C_INT, 0, range);
251 globalSizePtr.set(opencl_h.C_INT, 1*opencl_h.C_INT.byteSize(), 0);
252 globalSizePtr.set(opencl_h.C_INT, 2*opencl_h.C_INT.byteSize(), 0);
253 status = opencl_h.clEnqueueNDRangeKernel(
254 program.context.queue,
255 kernel,
256 1, // this must match the # of dims we are using in this case 1 of 3
257 NULL,
258 globalSizePtr,
259 NULL,
260 block ? 0 : eventc,
261 block ? NULL : ((eventc == 0) ? NULL : eventsPtr),
262 block ? NULL : eventsPtr.asSlice(eventc * opencl_h.cl_event.byteSize(), opencl_h.cl_event
263 )
264 );
265 if (status != opencl_h.CL_SUCCESS()) {
266 System.out.println("failed to enqueue NDRange " + status);
267 }
268
269 if (block) {
270 opencl_h.clFlush(program.context.queue);
271 } else {
272 eventc++;
273 status = opencl_h.clWaitForEvents(eventc, eventsPtr);
274 if (status != opencl_h.CL_SUCCESS()) {
275 System.out.println("failed to wait for ndrange events " + status);
276 }
277 }
278
279 for (int i = 0; i < args.length; i++) {
280 if (args[i] instanceof MemorySegment memorySegment) {
281 MemorySegment clMem = bufPtr.get(opencl_h.cl_mem, (long) i * opencl_h.cl_mem.byteSize());
282 status = opencl_h.clEnqueueReadBuffer(program.context.queue,
283 clMem,
284 block ? opencl_h.CL_TRUE() : opencl_h.CL_FALSE(),
285 0,
286 memorySegment.byteSize(),
287 memorySegment,
288 block ? 0 : eventc,
289 block ? NULL : ((eventc == 0) ? NULL : eventsPtr),
290 block ? NULL : eventsPtr.asSlice(eventc * opencl_h.cl_event.byteSize(), opencl_h.cl_event)// block?NULL:readEventPtr
291 );
292 if (status != opencl_h.CL_SUCCESS()) {
293 System.out.println("failed to enqueue read " + status);
294 }
295 if (!block) {
296 eventc++;
297 }
298 }
299 }
300 if (!block) {
301 status = opencl_h.clWaitForEvents(eventc, eventsPtr);
302 if (status != opencl_h.CL_SUCCESS()) {
303 System.out.println("failed to wait for events " + status);
304 }
305 }
306 for (int i = 0; i < args.length; i++) {
307 if (args[i] instanceof MemorySegment memorySegment) {
308 MemorySegment clMem = bufPtr.get(opencl_h.cl_mem, (long) i * opencl_h.cl_mem.byteSize());
309 status = opencl_h.clReleaseMemObject(clMem);
310 if (status != opencl_h.CL_SUCCESS()) {
311 System.out.println("failed to release memObject " + status);
312 }
313 }
314 }
315 }
316 }
317
318 public Kernel getKernel(String kernelName) {
319 return new Kernel(this, kernelName);
320 }
321 }
322
323 public Program buildProgram(String source) {
324 var program = new Program(this, source);
325 return program;
326 }
327 }
328
329 public Context createContext() {
330
331 var statusPtr = platform.openCL.arena.allocateFrom(opencl_h.C_INT, 0);
332 MemorySegment context;
333 var deviceIds = platform.openCL.arena.allocateFrom(opencl_h.C_POINTER, this.deviceId);
334 if ((context = opencl_h.clCreateContext(NULL, 1, deviceIds, NULL, NULL, statusPtr)) == NULL) {
335 int status = statusPtr.get(opencl_h.C_INT, 0);
336 System.out.println("Failed to get context ");
337 return null;
338 } else {
339 int status = statusPtr.get(opencl_h.C_INT, 0);
340 if (status != opencl_h.CL_SUCCESS()) {
341 System.out.println("failed to get context " + status);
342 }
343 return new Context(this, context);
344 }
345 }
346 }
347
348 int intPlatformInfo(int query) {
349 var value = 0;
350 if ((opencl_h.clGetPlatformInfo(platformId, query, opencl_h.C_INT.byteSize(), intValuePtr, NULL)) != opencl_h.CL_SUCCESS()) {
351 System.out.println("Failed to get query " + query);
352 } else {
353 value = intValuePtr.get(opencl_h.C_INT, 0);
354 }
355 return value;
356 }
357
358 String strPlatformInfo(int query) {
359 String value = null;
360 int status;
361 if ((status = opencl_h.clGetPlatformInfo(platformId, query, 2048, byte2048ValuePtr, intValuePtr)) != opencl_h.CL_SUCCESS()) {
362 System.err.println("Failed to get query " + query);
363 } else {
364 int len = intValuePtr.get(opencl_h.C_INT, 0);
365 byte[] bytes = byte2048ValuePtr.toArray(ValueLayout.JAVA_BYTE);
366 value = new String(bytes).substring(0, len - 1);
367 }
368 return value;
369 }
370
371 CLWrap openCL;
372 MemorySegment platformId;
373 List<Device> devices = new ArrayList<>();
374 final MemorySegment intValuePtr;
375 final MemorySegment byte2048ValuePtr;
376
377 String platformName() {
378 return strPlatformInfo(opencl_h.CL_PLATFORM_NAME());
379 }
380
381 String vendorName() {
382 return strPlatformInfo(opencl_h.CL_PLATFORM_VENDOR());
383 }
384
385 String version() {
386 return strPlatformInfo(opencl_h.CL_PLATFORM_VERSION());
387 }
388
389 public Platform(CLWrap openCL, MemorySegment platformId) {
390 this.openCL = openCL;
391 this.platformId = platformId;
392 this.intValuePtr = openCL.arena.allocateFrom(opencl_h.C_INT, 0);
393 this.byte2048ValuePtr = openCL.arena.allocate(opencl_h.C_CHAR, 2048);
394 var devicecPtr = openCL.arena.allocateFrom(opencl_h.C_INT, 0);
395 int status;
396 if ((status = opencl_h.clGetDeviceIDs(platformId, CL_DEVICE_TYPE_ALL(), 0, NULL, devicecPtr)) != opencl_h.CL_SUCCESS()) {
397 System.err.println("Failed getting devicec for platform 0 ");
398 } else {
399 int devicec = devicecPtr.get(opencl_h.C_INT, 0);
400 // System.out.println("platform 0 has " + devicec + " device" + ((devicec > 1) ? "s" : ""));
401 var deviceIdsPtr = openCL.arena.allocate(opencl_h.C_POINTER, devicec);
402 if ((status = opencl_h.clGetDeviceIDs(platformId, CL_DEVICE_TYPE_ALL(), devicec, deviceIdsPtr, devicecPtr)) != opencl_h.CL_SUCCESS()) {
403 System.err.println("Failed getting deviceids for platform 0 ");
404 } else {
405 // System.out.println("We have "+devicec+" device ids");
406 for (int i = 0; i < devicec; i++) {
407 devices.add(new Device(this, deviceIdsPtr.get(opencl_h.C_POINTER, i * opencl_h.C_POINTER.byteSize())));
408 }
409 }
410 }
411 }
412 }
413
414 List<Platform> platforms = new ArrayList<>();
415
416 Arena arena;
417
418 CLWrap(Arena arena) {
419 this.arena = arena;
420 var platformcPtr = arena.allocateFrom(opencl_h.C_INT, 0);
421
422 if ((opencl_h.clGetPlatformIDs(0, NULL, platformcPtr)) != opencl_h.CL_SUCCESS()) {
423 System.out.println("Failed to get opencl platforms");
424 } else {
425 int platformc = platformcPtr.get(opencl_h.C_INT, 0);
426 // System.out.println("There are "+platformc+" platforms");
427 var platformIdsPtr = arena.allocate(opencl_h.C_POINTER, platformc);
428 if ((opencl_h.clGetPlatformIDs(platformc, platformIdsPtr, platformcPtr)) != opencl_h.CL_SUCCESS()) {
429 System.out.println("Failed getting platform ids");
430 } else {
431 for (int i = 0; i < platformc; i++) {
432 // System.out.println("We should have the ids");
433 platforms.add(new Platform(this, platformIdsPtr.get(opencl_h.C_POINTER, i)));
434 }
435 }
436 }
437 }
438
439
440 public static void main(String[] args) throws IOException {
441 try (var arena = Arena.ofConfined()) {
442 CLWrap openCL = new CLWrap(arena);
443
444 Platform.Device[] selectedDevice = new Platform.Device[1];
445 openCL.platforms.forEach(platform -> {
446 System.out.println("Platform Name " + platform.platformName());
447 platform.devices.forEach(device -> {
448 System.out.println(" Compute Units " + device.computeUnits());
449 System.out.println(" Device Name " + device.deviceName());
450 System.out.println(" Built In Kernels " + device.builtInKernels());
451 selectedDevice[0] = device;
452 });
453 });
454 var context = selectedDevice[0].createContext();
455 var program = context.buildProgram("""
456 __kernel void squares(__global int* in,__global int* out ){
457 int gid = get_global_id(0);
458 out[gid] = in[gid]*in[gid];
459 }
460 """);
461 var kernel = program.getKernel("squares");
462 var in = arena.allocate(opencl_h.C_INT, 512);
463 var out = arena.allocate(opencl_h.C_INT, 512);
464 for (int i = 0; i < 512; i++) {
465 in.set(opencl_h.C_INT, (int) i * opencl_h.C_INT.byteSize(), i);
466 }
467 kernel.run(512, in, out);
468 for (int i = 0; i < 512; i++) {
469 System.out.println(i + " " + out.get(opencl_h.C_INT, (int) i * opencl_h.C_INT.byteSize()));
470 }
471 }
472 }
473 }