1 /*
2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25 package wrap.opencl;
26
27
28 import opencl.opencl_h;
29 import optkl.ifacemapper.Buffer;
30 import optkl.ifacemapper.BufferState;
31 import optkl.ifacemapper.MappableIface;
32 import wrap.ArenaHolder;
33 import wrap.Wrap;
34
35 import java.lang.foreign.Arena;
36 import java.lang.foreign.MemorySegment;
37 import java.util.ArrayList;
38 import java.util.List;
39
40 import static java.lang.foreign.MemorySegment.NULL;
41 import static opencl.opencl_h.CL_DEVICE_BUILT_IN_KERNELS;
42 import static opencl.opencl_h.CL_DEVICE_MAX_COMPUTE_UNITS;
43 import static opencl.opencl_h.CL_DEVICE_NAME;
44 import static opencl.opencl_h.CL_DEVICE_TYPE_ALL;
45 import static opencl.opencl_h.CL_DEVICE_VENDOR;
46 import static opencl.opencl_h.CL_MEM_READ_WRITE;
47 import static opencl.opencl_h.CL_MEM_USE_HOST_PTR;
48 import static opencl.opencl_h.CL_PROGRAM_BUILD_LOG;
49 import static opencl.opencl_h.CL_QUEUE_PROFILING_ENABLE;
50 import static opencl.opencl_h.CL_SUCCESS;
51
52 // https://streamhpc.com/blog/2013-04-28/opencl-error-codes/
53 public class CLPlatform implements ArenaHolder {
54 public static List<CLPlatform> platforms(Arena arena) {
55 var arenaWrapper = ArenaHolder.wrap(arena);
56 List<CLPlatform> platforms = new ArrayList<>();
57 var platformc = arenaWrapper.intPtr(0);
58 if ((opencl_h.clGetPlatformIDs(0, NULL, platformc.ptr())) != CL_SUCCESS()) {
59 System.out.println("Failed to get opencl platforms");
60 } else {
61 var platformIds = arenaWrapper.ptrArr(platformc.get());
62 if ((opencl_h.clGetPlatformIDs(platformc.get(), platformIds.ptr(), NULL)) != CL_SUCCESS()) {
63 System.out.println("Failed getting platform ids");
64 } else {
65 for (int i = 0; i < platformc.get(); i++) {
66 platforms.add(new CLPlatform(arena, platformIds.get(i)));
67 }
68 }
69 }
70 return platforms;
71 }
72
73 public static class CLDevice implements ArenaHolder {
74 final CLPlatform platform;
75 final MemorySegment deviceId;
76
77 @Override
78 public Arena arena() {
79 return platform.arena();
80 }
81
82 int intDeviceInfo(int query) {
83 var value = intPtr(0);
84 if ((opencl_h.clGetDeviceInfo(deviceId, query, value.sizeof(), value.ptr(), NULL)) != CL_SUCCESS()) {
85 throw new RuntimeException("Failed to get query " + query);
86 }
87 return value.get();
88 }
89
90 String strDeviceInfo(int query) {
91 var value = cstr(2048);
92 if ((opencl_h.clGetDeviceInfo(deviceId, query, value.len(), value.ptr(), NULL)) != CL_SUCCESS()) {
93 throw new RuntimeException("Failed to get query " + query);
94 }
95 return value.get();
96 }
97
98 public int computeUnits() {
99 return intDeviceInfo(CL_DEVICE_MAX_COMPUTE_UNITS());
100 }
101
102 public String deviceName() {
103 return strDeviceInfo(CL_DEVICE_NAME());
104 }
105
106 public String deviceVendor() {
107 return strDeviceInfo(CL_DEVICE_VENDOR());
108 }
109
110 public String builtInKernels() {
111 return strDeviceInfo(CL_DEVICE_BUILT_IN_KERNELS());
112 }
113
114 CLDevice(CLPlatform platform, MemorySegment deviceId) {
115 this.platform = platform;
116 this.deviceId = deviceId;
117 }
118
119 public static class CLContext implements ArenaHolder {
120 CLDevice device;
121 MemorySegment context;
122 MemorySegment queue;
123
124 @Override
125 public Arena arena() {
126 return device.arena();
127 }
128
129 CLContext(CLDevice device, MemorySegment context) {
130 this.device = device;
131 this.context = context;
132 var status = device.platform.status;
133
134 var queue_props = CL_QUEUE_PROFILING_ENABLE();
135 if ((this.queue = opencl_h.clCreateCommandQueue(context, device.deviceId, queue_props, status.ptr())) == NULL) {
136 opencl_h.clReleaseContext(context);
137 } else {
138 if (!status.isOK()) {
139 opencl_h.clReleaseContext(context);
140 }
141 }
142 }
143
144 static public class CLProgram implements ArenaHolder {
145 CLContext context;
146 String source;
147 MemorySegment program;
148 String log;
149
150 @Override
151 public Arena arena() {
152 return context.arena();
153 }
154
155 CLProgram(CLContext context, String source) {
156 this.context = context;
157 this.source = source;
158 var sourceList = ptrArr(source);
159 var status = context.device.platform.status;
160 if ((program = opencl_h.clCreateProgramWithSource(context.context, 1, sourceList.ptr(), NULL, status.ptr())) == NULL) {
161 if (!status.isOK()) {
162 throw new RuntimeException("failed to createProgram " + status.get());
163 }
164 throw new RuntimeException("failed to createProgram");
165 } else {
166 var deviceIdList = ptrArr(context.device.deviceId);
167 if ((status.set(opencl_h.clBuildProgram(program, 1, deviceIdList.ptr(), NULL, NULL, NULL))) != CL_SUCCESS()) {
168 System.err.println("failed to build program " + status);
169 System.err.println("Source "+source);
170 // System.exit(1);
171 }
172 var logLen = longPtr(1L);
173 if ((status.set(opencl_h.clGetProgramBuildInfo(program, context.device.deviceId, CL_PROGRAM_BUILD_LOG(), 0, NULL, logLen.ptr()))) != CL_SUCCESS()) {
174 System.err.println("failed to get log build " + status.get());
175 System.exit(1);
176 } else {
177 var logPtr = cstr(1 + logLen.get());
178 if ((status.set(opencl_h.clGetProgramBuildInfo(program, context.device.deviceId, opencl_h.CL_PROGRAM_BUILD_LOG(), logLen.get(), logPtr.ptr(), NULL))) != opencl_h.CL_SUCCESS()) {
179 System.out.println("clGetBuildInfo (getting log) failed ");
180 } else {
181 log = logPtr.get();
182 System.out.println("log\n" +log);
183 }
184 }
185 }
186 }
187
188 public static class CLKernel implements ArenaHolder {
189 CLProgram program;
190 MemorySegment kernel;
191 String kernelName;
192
193 @Override
194 public Arena arena() {
195 return program.arena();
196 }
197
198 public CLKernel(CLProgram program, String kernelName) {
199 this.program = program;
200 this.kernelName = kernelName;
201 var kernelNameCStr = this.cstr(kernelName);
202 var status = program.context.device.platform.status;
203 kernel = opencl_h.clCreateKernel(program.program, kernelNameCStr.ptr(), status.ptr());
204 if (!status.isOK()) {
205 System.out.println("failed to create kernel '"+kernelName+"'" + status);
206 }
207 }
208
209
210 public void run(CLWrapComputeContext clWrapComputeContext, int range, Object... args) {
211 var status = CLStatusPtr.of(arena());
212 for (int i = 0; i < args.length; i++) {
213 if (args[i] instanceof CLWrapComputeContext.MemorySegmentState memorySegmentState) {
214 if (memorySegmentState.clMemPtr == null) {
215 memorySegmentState.clMemPtr = CLWrapComputeContext.ClMemPtr.of(arena(), opencl_h.clCreateBuffer(program.context.context,
216 CL_MEM_USE_HOST_PTR() | CL_MEM_READ_WRITE(),
217 memorySegmentState.memorySegment.byteSize(),
218 memorySegmentState.memorySegment,
219 status.ptr()));
220 if (!status.isOK()) {
221 throw new RuntimeException("failed to create memory buffer for arg["+i+" " + status.get());
222 }
223 }
224 if (memorySegmentState.copyToDevice) {
225 status.set(opencl_h.clEnqueueWriteBuffer(program.context.queue,
226 memorySegmentState.clMemPtr.get(),
227 clWrapComputeContext.blockInt(),
228 0,
229 memorySegmentState.memorySegment.byteSize(),
230 memorySegmentState.memorySegment,
231 clWrapComputeContext.eventc(),
232 clWrapComputeContext.eventsPtr(),
233 clWrapComputeContext.nextEventPtrSlot()
234 ));
235 if (!status.isOK()) {
236 System.err.println("failed to enqueue write for arg["+i+" " + status);
237 System.exit(1);
238 }
239 }
240
241 status.set(opencl_h.clSetKernelArg(kernel, i, memorySegmentState.clMemPtr.sizeof(), memorySegmentState.clMemPtr.ptr()));
242 if (!status.isOK()) {
243 System.err.println("failed to set arg["+i+" " + status);
244 System.exit(1);
245 }
246 } else if (args[i] instanceof Buffer buffer) {
247 // System.out.println("Arg "+i+" is a buffer so checking if we need to write");
248 BufferState bufferState = BufferState.of(buffer);
249
250 //System.out.println("Before possible write"+ bufferState);
251 MemorySegment memorySegment = MappableIface.getMemorySegment(buffer);
252
253 CLWrapComputeContext.ClMemPtr clmem = clWrapComputeContext.clMemMap.computeIfAbsent(memorySegment, k ->
254 CLWrapComputeContext.ClMemPtr.of(arena(), opencl_h.clCreateBuffer(program.context.context,
255 CL_MEM_USE_HOST_PTR() | CL_MEM_READ_WRITE(),
256 memorySegment.byteSize(),
257 memorySegment,
258 status.ptr()))
259 );
260 if (bufferState.getState()==BufferState.HOST_OWNED) {
261
262 //System.out.println("arg " + args[i] + " isHostDirty copying in");
263 status.set(opencl_h.clEnqueueWriteBuffer(program.context.queue,
264 clmem.get(),
265 clWrapComputeContext.blockInt(),
266 0,
267 memorySegment.byteSize(),
268 memorySegment,
269 clWrapComputeContext.eventc(),
270 clWrapComputeContext.eventsPtr(),
271 clWrapComputeContext.nextEventPtrSlot()
272 ));
273 if (!status.isOK()) {
274 System.err.println("failed to enqueue write for arg["+i+" " + status);
275 System.exit(1);
276 }
277 } else {
278
279 // System.out.println("arg "+args[i]+" is not HostDirty not copying in");
280 }
281 // System.out.println("After possible write "+ bufferState);
282 status.set(opencl_h.clSetKernelArg(kernel, i, clmem.sizeof(), clmem.ptr()));
283 if (!status.isOK()) {
284 System.err.println("failed to set arg["+i+"]" + status);
285 System.exit(1);
286 }
287
288 } else {
289 Wrap.Ptr ptr = switch (args[i]) {
290 case Integer intArg -> intPtr(intArg);
291 case Float floatArg -> floatPtr(floatArg);
292 case Double doubleArg -> doublePtr(doubleArg);
293 case Long longArg -> longPtr(longArg);
294 case Short shortArg -> shortPtr(shortArg);
295 default -> throw new IllegalStateException("Unexpected value: " + args[i]);
296 };
297 status.set(opencl_h.clSetKernelArg(kernel, i, ptr.sizeof(), ptr.ptr()));
298 if (!status.isOK()) {
299 System.err.println("failed to set arg["+i+"] " + status);
300
301 System.exit(1);
302 }
303
304 }
305 }
306
307 // We need to store x,y,z sizes so this is a kind of int3
308 var globalSize = this.ofInts(range, 0, 0);
309 status.set(opencl_h.clEnqueueNDRangeKernel(
310 program.context.queue,
311 kernel,
312 1, // this must match the # of dims we are using in this case 1 of 3
313 NULL,
314 globalSize.ptr(),
315 NULL,
316 clWrapComputeContext.eventc(),
317 clWrapComputeContext.eventsPtr(),
318 clWrapComputeContext.nextEventPtrSlot()
319 )
320 );
321 if (!status.isOK()) {
322 System.out.println("failed to enqueue NDRange " + status);
323 }
324
325 if (clWrapComputeContext.alwaysBlock) {
326 opencl_h.clFlush(program.context.queue);
327 }
328
329 for (int i = 0; i < args.length; i++) {
330 if (args[i] instanceof CLWrapComputeContext.MemorySegmentState memorySegmentState) {
331 if (memorySegmentState.copyFromDevice) {
332 status.set(opencl_h.clEnqueueReadBuffer(program.context.queue,
333 memorySegmentState.clMemPtr.get(),
334 clWrapComputeContext.blockInt(),
335 0,
336 memorySegmentState.memorySegment.byteSize(),
337 memorySegmentState.memorySegment,
338 clWrapComputeContext.eventc(),
339 clWrapComputeContext.eventsPtr(),
340 clWrapComputeContext.nextEventPtrSlot()
341 ));
342 if (!status.isOK()) {
343 System.out.println("failed to enqueue read " + status);
344 }
345 }
346 } else if (args[i] instanceof Buffer buffer) {
347 // System.out.println("Arg "+i+" is a buffer so checking if we need to read");
348 BufferState bufferState = BufferState.of(buffer);
349 MemorySegment memorySegment = MappableIface.getMemorySegment(buffer);
350 CLWrapComputeContext.ClMemPtr clmem = clWrapComputeContext.clMemMap.get(memorySegment);
351 // System.out.println("Before possible read "+ bufferState);
352 if (bufferState.getState() == BufferState.HOST_OWNED) {
353 // System.out.println("arg " + args[i] + " isDeviceDirty copying out");
354 status.set(opencl_h.clEnqueueReadBuffer(program.context.queue,
355 clmem.get(),
356 clWrapComputeContext.blockInt(),
357 0,
358 memorySegment.byteSize(),
359 memorySegment,
360 clWrapComputeContext.eventc(),
361 clWrapComputeContext.eventsPtr(),
362 clWrapComputeContext.nextEventPtrSlot()
363 ));
364 if (!status.isOK()) {
365 System.out.println("failed to enqueue read " + status);
366 }
367 } else {
368 // System.out.println("arg "+args[i]+" isnot DeviceDirty not copying out");
369 }
370
371 }
372 }
373 // if (!computeContext.alwaysBlock) {
374 clWrapComputeContext.waitForEvents();
375 // }
376 }
377 }
378
379 public CLKernel getKernel(String kernelName) {
380 return new CLKernel(this, kernelName);
381 }
382 }
383
384 public CLProgram buildProgram(String source) {
385 var program = new CLProgram(this, source);
386 return program;
387 }
388 }
389
390 public CLContext createContext() {
391 var status = platform.status;
392 MemorySegment context;
393 var deviceIds = ptrArr(this.deviceId);
394 if ((context = opencl_h.clCreateContext(NULL, 1, deviceIds.ptr(), NULL, NULL, status.ptr())) == NULL) {
395 System.out.println("Failed to get context ");
396 return null;
397 } else {
398 if (!status.isOK()) {
399 System.out.println("failed to get context " + status);
400 }
401 return new CLContext(this, context);
402 }
403 }
404 }
405
406 int intPlatformInfo(int query) {
407 var value = intPtr(0);
408 if ((opencl_h.clGetPlatformInfo(platformId, query, value.sizeof(), value.ptr(), NULL)) != opencl_h.CL_SUCCESS()) {
409 throw new RuntimeException("Failed to get query " + query);
410 }
411 return value.get();
412 }
413
414 String strPlatformInfo(int query) {
415
416 var value = cstr(2048);
417 int status;
418 if ((status = opencl_h.clGetPlatformInfo(platformId, query, value.len(), value.ptr(), NULL)) != opencl_h.CL_SUCCESS()) {
419 throw new RuntimeException("Failed to get query " + query);
420 }
421 return value.get();
422 }
423
424 private Arena secretarena;
425 MemorySegment platformId;
426 public List<CLDevice> devices = new ArrayList<>();
427 final CLStatusPtr status;
428
429 public String platformName() {
430 return strPlatformInfo(opencl_h.CL_PLATFORM_NAME());
431 }
432
433 String vendorName() {
434 return strPlatformInfo(opencl_h.CL_PLATFORM_VENDOR());
435 }
436
437 String version() {
438 return strPlatformInfo(opencl_h.CL_PLATFORM_VERSION());
439 }
440
441 @Override
442 public Arena arena() {
443 return secretarena;
444 }
445
446 public CLPlatform(Arena arena, MemorySegment platformId) {
447 this.secretarena = arena;
448 this.platformId = platformId;
449 this.status = CLStatusPtr.of(arena());
450 var devicec = intPtr(0);
451 if ((status.set(opencl_h.clGetDeviceIDs(platformId, CL_DEVICE_TYPE_ALL(), 0, NULL, devicec.ptr()))) != opencl_h.CL_SUCCESS()) {
452 System.err.println("Failed getting devicec for platform 0 ");
453 } else {
454 // System.out.println("platform 0 has " + devicec + " device" + ((devicec > 1) ? "s" : ""));
455 var deviceIdList = ptrArr(devicec.get());
456 if ((status.set(opencl_h.clGetDeviceIDs(platformId, CL_DEVICE_TYPE_ALL(), devicec.get(), deviceIdList.ptr(), devicec.ptr()))) != opencl_h.CL_SUCCESS()) {
457 System.err.println("Failed getting deviceids for platform 0 ");
458 } else {
459 // System.out.println("We have "+devicec+" device ids");
460 for (int i = 0; i < devicec.get(); i++) {
461 devices.add(new CLDevice(this, deviceIdList.get(i)));
462 }
463 }
464 }
465 }
466 }