1 import java.io.File;
2 import java.io.FileInputStream;
3 import java.util.List;
4 import java.util.Map;
5 import java.util.HashMap;
6 import java.util.ArrayList;
7 import java.util.Deque;
8 import java.util.ArrayDeque;
9 import java.util.Set;
10 import java.util.HashSet;
11 import java.util.Iterator;
12 import java.util.Arrays;
13 import java.util.Queue;
14 import java.util.Optional;
15 import java.util.stream.Stream;
16 import java.util.function.Consumer;
17 import java.util.function.Function;
18 import java.util.function.Supplier;
19 import java.time.*;
20 import java.nio.file.Files;
21 import java.nio.file.Paths;
22 import java.io.IOException;
23 import java.io.BufferedReader;
24 import java.io.FileWriter;
25 import java.io.FileReader;
26 import java.lang.reflect.Method;
27 import java.lang.ref.Cleaner;
28 import java.lang.foreign.MemorySegment;
29 import java.lang.foreign.ValueLayout;
30 import java.lang.foreign.AddressLayout;
31 import java.lang.foreign.Arena;
32 import java.lang.foreign.MemorySegment.Scope;
33 import static java.lang.foreign.ValueLayout.*;
34 import jdk.incubator.vector.VectorSpecies;
35 import jdk.incubator.vector.FloatVector;
36 import static oneapi.levelzero.ze_api_h.*;
37 import oneapi.levelzero.ze_api_h;
38 import oneapi.levelzero.ze_context_desc_t;
39 import oneapi.levelzero.ze_kernel_desc_t;
40 import oneapi.levelzero.ze_command_queue_desc_t;
41 import oneapi.levelzero.ze_command_list_desc_t;
42 import oneapi.levelzero.ze_command_queue_group_properties_t;
43 import oneapi.levelzero.ze_event_pool_desc_t;
44 import oneapi.levelzero.ze_event_desc_t;
45 import oneapi.levelzero.ze_fence_desc_t;
46 import oneapi.levelzero.ze_module_desc_t;
47 import oneapi.levelzero.ze_group_count_t;
48 import oneapi.levelzero.ze_host_mem_alloc_desc_t;
49 import oneapi.levelzero.ze_device_mem_alloc_desc_t;
50 import oneapi.levelzero.ze_device_properties_t;
51 import oneapi.levelzero.ze_device_compute_properties_t;
52 import oneapi.levelzero.ze_driver_properties_t;
53 import oneapi.levelzero.ze_driver_extension_properties_t;
54 import org.json.JSONArray;
55 import org.json.JSONObject;
56
57 import java.util.Random;
58
59 public class LevelZero {
/** Layout of a {@code ze_driver_handle_t} (an opaque pointer); used to size the driver-handle array. */
public static final AddressLayout driver_handle_t = AddressLayout.ADDRESS;

// Arena holding this object's long-lived native buffers (descriptors, handle slots, SPIR-V copies).
private final Arena arena;

// Level Zero objects selected or created by the constructor.
private final MemorySegment driverHandle;
private final MemorySegment contextHandle;
private final MemorySegment deviceHandle;
private final MemorySegment queueHandle;

// Event-pool descriptor prepared by the constructor; no pool is created from it in this part of the file.
private final MemorySegment eventPoolDescription;

// Triton kernel cache: each <hash> directory holds "<kernel>.json" metadata and "<kernel>.spv" SPIR-V.
private final String homeDir = System.getProperty("user.home");
private final String cacheDir = homeDir + "/.triton/cache/";
private final String addKernelCache = "7961f2e8b433c656051d8638d6a3bb65f43f6cb885525c05d611100dd905aa31";
private final String softmaxKernelCache = "f0c32acd1173759227ef8e0e8d197c94493b90ebf8d1fc254399ffac6b527d6a";
private final String matmulKernelCache = "07e17c2833c9c9efea8ccd782af1c3ee05dcac3efb2cb75f2f8a6eecffe381ef";

private static final VectorSpecies<Float> SPECIES = FloatVector.SPECIES_256;

// Accumulated wall-clock milliseconds of the first and the second submission of each command list.
private Double timeElapsedForRun;
private Double timeElapsedForRerun;

static {
    // The ze_api_h bindings resolve their symbols from the Level Zero loader library.
    System.loadLibrary("ze_loader");
}
78
79 static void debug(String format, Object... args) {
80 System.out.printf(format + "%n", args);
81 }
82
83 private static void check(int result) {
84 if (result != ZE_RESULT_SUCCESS()) {
85 throw new RuntimeException(String.format("Call failed: 0x%x (%d)", result, result));
86 }
87 }
88
89 MemorySegment contextHandle() {
90 return contextHandle;
91 }
92
93 MemorySegment deviceHandle() {
94 return deviceHandle;
95 }
96
97 public LevelZero() {
98 arena = Arena.ofShared();
99
100 // get driver
101 check(zeInit(ZE_INIT_FLAG_GPU_ONLY()));
102 MemorySegment driverCount = arena.allocate(Integer.BYTES);
103 check(zeDriverGet(driverCount, MemorySegment.NULL));
104 debug("driverCount = %d", driverCount.get(JAVA_INT, 0));
105 MemorySegment driverHandles = arena.allocate(driverCount.get(JAVA_INT, 0) * driver_handle_t.byteSize(), 8);
106 check(zeDriverGet(driverCount, driverHandles));
107 driverHandle = driverHandles.get(ADDRESS, 0);
108
109 // create context
110 MemorySegment pContextDesc = arena.allocate(ze_context_desc_t.layout());
111 ze_context_desc_t.stype(pContextDesc, ZE_STRUCTURE_TYPE_CONTEXT_DESC());
112 MemorySegment pContextHandle = arena.allocate(ze_context_handle_t);
113 check(zeContextCreate(driverHandle, pContextDesc, pContextHandle));
114 contextHandle = pContextHandle.get(ADDRESS, 0);
115
116 // get device
117 MemorySegment pDeviceCount = arena.allocate(Integer.BYTES);
118 check(zeDeviceGet(driverHandle, pDeviceCount, MemorySegment.NULL));
119 int deviceCount = pDeviceCount.get(JAVA_INT, 0);
120 assert deviceCount > 0;
121 debug("deviceCount = %d", deviceCount);
122 MemorySegment deviceHandles = arena.allocate(deviceCount * ze_device_handle_t.byteSize(), 8);
123 check(zeDeviceGet(driverHandle, pDeviceCount, deviceHandles));
124 for (int i = 0; i < deviceCount; i++) {
125 debug("device #%d: %s", i, deviceHandles.get(ze_device_handle_t, i * ze_device_handle_t.byteSize()));
126 }
127 deviceHandle = deviceHandles.get(ze_device_handle_t, 0 * ze_device_handle_t.byteSize());
128 MemorySegment pDeviceProperties = arena.allocate(ze_device_properties_t.layout());
129 ze_device_properties_t.stype(pDeviceProperties, ZE_STRUCTURE_TYPE_DEVICE_PROPERTIES());
130 check(zeDeviceGetProperties(deviceHandle, pDeviceProperties));
131 debug("deviceProperties:\n\ttype = %d\n\tvendorId = %d\n\tmaxMemAllocSize = %d\n\tdeviceId = %d\n\tcoreClockRate = %d",
132 ze_device_properties_t.type(pDeviceProperties),
133 ze_device_properties_t.vendorId(pDeviceProperties),
134 ze_device_properties_t.maxMemAllocSize(pDeviceProperties),
135 ze_device_properties_t.deviceId(pDeviceProperties),
136 ze_device_properties_t.coreClockRate(pDeviceProperties));
137
138 MemorySegment pDeviceComputeProperties = arena.allocate(ze_device_compute_properties_t.layout());
139 ze_device_compute_properties_t.stype(pDeviceComputeProperties, ZE_STRUCTURE_TYPE_DEVICE_COMPUTE_PROPERTIES());
140 check(zeDeviceGetComputeProperties(deviceHandle, pDeviceComputeProperties));
141 debug("deviceProperties:\n\tshared = %d\n\tmaxTotalGroupSize = %d",
142 ze_device_compute_properties_t.maxSharedLocalMemory(pDeviceComputeProperties),
143 ze_device_compute_properties_t.maxTotalGroupSize(pDeviceComputeProperties));
144
145 // create queue
146 MemorySegment pNumQueueGroups = arena.allocate(JAVA_INT, 1);
147 check(zeDeviceGetCommandQueueGroupProperties(deviceHandle, pNumQueueGroups, MemorySegment.NULL));
148 debug("#Queue Groups: %d", pNumQueueGroups.get(JAVA_INT, 0));
149 MemorySegment pGroupProperties = arena.allocate(ze_command_queue_group_properties_t.layout(), pNumQueueGroups.get(JAVA_INT, 0));
150 check(zeDeviceGetCommandQueueGroupProperties(deviceHandle, pNumQueueGroups, pGroupProperties));
151
152 MemorySegment pQueueDesc = arena.allocate(ze_command_queue_desc_t.layout());
153 ze_command_queue_desc_t.stype(pQueueDesc, ZE_STRUCTURE_TYPE_COMMAND_QUEUE_DESC());
154 ze_command_queue_desc_t.index(pQueueDesc, 0);
155 ze_command_queue_desc_t.mode(pQueueDesc, ZE_COMMAND_QUEUE_MODE_SYNCHRONOUS());
156 ze_command_queue_desc_t.ordinal(pQueueDesc, 0);
157 MemorySegment pQueueHandle = arena.allocate(ze_command_queue_handle_t);
158 check(zeCommandQueueCreate(contextHandle, deviceHandle, pQueueDesc, pQueueHandle));
159 queueHandle = pQueueHandle.get(ADDRESS, 0);
160
161 eventPoolDescription = arena.allocate(ze_event_pool_desc_t.layout());
162 ze_event_pool_desc_t.stype(eventPoolDescription, ZE_STRUCTURE_TYPE_EVENT_POOL_DESC());
163 ze_event_pool_desc_t.count(eventPoolDescription, 20);
164 ze_event_pool_desc_t.flags(eventPoolDescription, ZE_EVENT_POOL_FLAG_HOST_VISIBLE());
165
166 timeElapsedForRun = timeElapsedForRerun = 0.0;
167 }
168
169 public void clear() {
170 check(zeCommandQueueDestroy(queueHandle));
171 check(zeContextDestroy(contextHandle));
172 }
173
174 public void test(String testName) {
175 Object[] args = {};
176 Random rand = new Random();
177 if (testName.equals("add")) {
178 String jsonFileName = cacheDir + addKernelCache + "/add_kernel.json";
179 String moduleName = cacheDir + addKernelCache + "/add_kernel.spv";
180
181 int BLOCK_SIZE = 64;
182 int elementSize = 4096;
183 int gridSize = (elementSize + BLOCK_SIZE - 1) / BLOCK_SIZE;
184
185 JSONObject jsonObject = loadJson(jsonFileName);
186 String kernelName = jsonObject.getString("name");
187 int threads_per_warp = jsonObject.getInt("threads_per_warp");
188 int num_warps = jsonObject.getInt("num_warps");
189 int shared = jsonObject.getInt("shared");
190
191 float[] input1 = new float[elementSize];
192 float[] input2 = new float[elementSize];
193 float[] output = new float[elementSize];
194 for (int i = 0; i < elementSize; i++) {
195 input1[i] = rand.nextFloat();
196 input2[i] = rand.nextFloat();
197 }
198 args = new Object[] {input1, input2, output, elementSize};
199 run(kernelName, moduleName, args, threads_per_warp, num_warps, shared, gridSize);
200
201 float[] expected = Test.add(input1, input2, elementSize);
202 Test.check(expected, output);
203 } else if (testName.equals("softmax")) {
204 String jsonFileName = cacheDir + softmaxKernelCache + "/softmax_kernel.json";
205 String moduleName = cacheDir + softmaxKernelCache + "/softmax_kernel.spv";
206
207 JSONObject jsonObject = loadJson(jsonFileName);
208 String kernelName = jsonObject.getString("name");
209 int threads_per_warp = jsonObject.getInt("threads_per_warp");
210 int num_warps = jsonObject.getInt("num_warps");
211 int shared = jsonObject.getInt("shared");
212
213 int elementSizeX = 4096, elementSizeY = 64;
214 int gridSize = elementSizeX;
215 float[] input = new float[elementSizeX * elementSizeY];
216 float[] output = new float[elementSizeX * elementSizeY];
217 byte[] sharedMemory = new byte[shared]; // use for storing temporary value of max element and sum of exp
218 for (int i = 0; i < elementSizeX * elementSizeY; i++) {
219 input[i] = rand.nextFloat();
220 }
221 args = new Object[] {output, input, elementSizeY, elementSizeY, elementSizeY, sharedMemory};
222 run(kernelName, moduleName, args, threads_per_warp, num_warps, shared, gridSize);
223
224 float[] expected = Test.softmax(input, elementSizeX, elementSizeY);
225 Test.check(expected, output);
226 } else if (testName.equals("matmul")) {
227 String jsonFileName = cacheDir + matmulKernelCache + "/matmul_kernel.json";
228 String moduleName = cacheDir + matmulKernelCache + "/matmul_kernel.spv";
229
230 JSONObject jsonObject = loadJson(jsonFileName);
231 String kernelName = jsonObject.getString("name");
232 int threads_per_warp = jsonObject.getInt("threads_per_warp");
233 int num_warps = jsonObject.getInt("num_warps");
234 int shared = jsonObject.getInt("shared");
235
236 int M = 1024, N = 1024, K = 1024;
237 int BLOCK_SIZE_M = 32, BLOCK_SIZE_N = 64;
238 int gridSize = ((M + BLOCK_SIZE_M - 1) / BLOCK_SIZE_M) * ((N + BLOCK_SIZE_N - 1) / BLOCK_SIZE_N);
239 float[] a = new float[M * K];
240 float[] b = new float[K * N];
241 float[] c = new float[M * N];
242 byte[] sharedMemory = new byte[shared];
243
244 for (int i = 0; i < M * K; i++) {
245 a[i] = rand.nextFloat();
246 }
247 for (int i = 0; i < K * N; i++) {
248 b[i] = rand.nextFloat();
249 }
250 args = new Object[] {a, b, c, M, N, K, K, N, N, sharedMemory};
251 run(kernelName, moduleName, args, threads_per_warp, num_warps, shared, gridSize);
252
253 float[] expected = Test.matmul(a, b, M, N, K);
254 Test.check(expected, c);
255 } else {
256 throw new RuntimeException("Unsupported test: " + testName);
257 }
258 }
259
260 public void run(String kernelName, String fileName, Object[] args, int threads_per_warp, int num_warps, int shared, int gridSize) {
261 debug("=========== run %s ===========", kernelName);
262 MemorySegment spirvBinary = loadModule(fileName);
263 List<Arg> kernelArgs = collectArgs(args);
264 int[] globalSizes = new int[] {gridSize * threads_per_warp * num_warps, 1, 1};
265 int[] localSizes = new int[] {threads_per_warp * num_warps, 1, 1};
266 KernelGeometry geometry = new KernelGeometry(globalSizes, localSizes);
267 debug("geometry = %s", geometry);
268 MemorySegment commandListHandle = createCommandList(spirvBinary, kernelName, geometry, kernelArgs, shared, false);
269 executeCommandList(commandListHandle);
270 check(zeCommandQueueSynchronize(queueHandle, -1L));
271 for (int i = 0; i < kernelArgs.size(); i++) {
272 copyArgToHost(kernelArgs.get(i), contextHandle);
273 }
274
275 for (int i = 0; i < kernelArgs.size(); i++) {
276 Arg arg = kernelArgs.get(i);
277 MemorySegment dataSegment = arg.dataSegment();
278 if (dataSegment != null) {
279 check(zeMemFree(contextHandle, dataSegment));
280 }
281 }
282 check(zeCommandListDestroy(commandListHandle));
283 }
284
285 public void runRefMatmul(String kernelName, String fileName, Object[] args, int size) {
286 debug("=========== run %s ===========", kernelName);
287 MemorySegment spirvBinary = loadModule(fileName);
288 List<Arg> kernelArgs = collectArgs(args);
289 int[] globalSizes = new int[] {size, size, 1};
290 int[] localSizes = new int[] {512, 1, 1};
291 KernelGeometry geometry = new KernelGeometry(globalSizes, localSizes);
292 debug("geometry = %s", geometry);
293 MemorySegment commandListHandle = createCommandList(spirvBinary, kernelName, geometry, kernelArgs, 0, true);
294 executeCommandList(commandListHandle);
295 check(zeCommandQueueSynchronize(queueHandle, -1L));
296 for (int i = 0; i < kernelArgs.size(); i++) {
297 copyArgToHost(kernelArgs.get(i), contextHandle);
298 }
299
300 for (int i = 0; i < kernelArgs.size(); i++) {
301 Arg arg = kernelArgs.get(i);
302 MemorySegment dataSegment = arg.dataSegment();
303 if (dataSegment != null) {
304 check(zeMemFree(contextHandle, dataSegment));
305 }
306 }
307 check(zeCommandListDestroy(commandListHandle));
308 }
309
310 private List<Arg> collectArgs(Object[] values) {
311 List<Arg> args = new ArrayList<>();
312 for (int i = 0; i < values.length; i++) {
313 args.add(Arg.createArg(this, "arg" + i, values[i]));
314 }
315 debug("args = %s", args);
316 return args;
317 }
318
319 MemorySegment loadModule(String fileName) {
320 byte[] data = readBytes(fileName);
321 MemorySegment segment = arena.allocate(data.length);
322 segment.copyFrom(MemorySegment.ofArray(data));
323 return segment;
324 }
325
326 byte[] readBytes(String filename) {
327 File file = new File(filename);
328 try (FileInputStream fis = new FileInputStream(file)) {
329 byte[] data = new byte[(int) file.length()];
330 fis.read(data);
331 return data;
332 } catch (IOException e) {
333 throw new RuntimeException(e);
334 }
335 }
336
337 void provisionArg(Arg arg) {
338 if (arg.cls() == byte[].class) {
339 byte[] array = (byte[])arg.value();
340 int segmentSize = array.length;
341 arg.setDataSegment(allocateSharedSegment(segmentSize));
342 arg.dataSegment().copyFrom(MemorySegment.ofArray(array));
343 arg.setSize(8);
344 arg.setNeedsCleanup(true);
345 }
346 else if (arg.cls() == short[].class) {
347 short[] array = (short[])arg.value();
348 int segmentSize = array.length * Short.BYTES;
349 arg.setDataSegment(allocateSharedSegment(segmentSize));
350 arg.dataSegment().copyFrom(MemorySegment.ofArray(array));
351 arg.setSize(8);
352 arg.setNeedsCleanup(true);
353 }
354 else if (arg.cls() == int[].class) {
355 int[] array = (int[])arg.value();
356 int segmentSize = array.length * Integer.BYTES;
357 arg.setDataSegment(allocateSharedSegment(segmentSize));
358 arg.dataSegment().copyFrom(MemorySegment.ofArray(array));
359 arg.setSize(8);
360 arg.setNeedsCleanup(true);
361 }
362 else if (arg.cls() == float[].class) {
363 float[] array = (float[])arg.value();
364 int segmentSize = array.length * Float.BYTES;
365 arg.setDataSegment(allocateSharedSegment(segmentSize));
366 arg.dataSegment().copyFrom(MemorySegment.ofArray(array));
367 arg.setSize(8);
368 arg.setNeedsCleanup(true);
369 }
370 else if (VectorSpecies.class.isAssignableFrom(arg.cls())) {
371 arg.setSize(4);
372 }
373 else if (arg.cls() == Short.class) {
374 arg.setSize(2);
375 }
376 else if (arg.cls() == Integer.class || arg.cls() == Float.class || arg.cls() == Boolean.class) {
377 arg.setSize(4);
378 }
379 else if (arg.cls() == Long.class) {
380 arg.setSize(8);
381 }
382 else if (arg.cls() == GPU.Index.class) {
383 MemorySegment pBuffer = arena.allocate(ADDRESS);
384 arg.setDataSegment(allocateSharedSegment(24));
385 arg.setSize(24);
386 }
387 else throw new RuntimeException("unsupported type: " + arg.cls());
388 }
389
390 void copyArgToHost(Arg arg, MemorySegment contextHandle) {
391 if (arg.cls() == short[].class) {
392 short[] array = (short[])arg.value();
393 MemorySegment arraySegment = MemorySegment.ofArray(array);
394 arraySegment.copyFrom(arg.dataSegment());
395 }
396 else if (arg.cls() == int[].class) {
397 int[] array = (int[])arg.value();
398 MemorySegment arraySegment = MemorySegment.ofArray(array);
399 arraySegment.copyFrom(arg.dataSegment());
400 }
401 else if (arg.cls() == float[].class) {
402 float[] array = (float[])arg.value();
403 MemorySegment arraySegment = MemorySegment.ofArray(array);
404 arraySegment.copyFrom(arg.dataSegment());
405 }
406 // else nothing to do
407 }
408
409 private MemorySegment createCommandList(MemorySegment spirvModule, String kernelName, KernelGeometry geometry, List<Arg> args, int shared, boolean suggested) {
410 Arena arena = Arena.ofShared();
411 MemorySegment pCommandListHandle = arena.allocate(ze_command_list_handle_t);
412 MemorySegment commandListDesc = arena.allocate(ze_command_list_desc_t.layout());
413 ze_command_list_desc_t.stype(eventPoolDescription, ZE_STRUCTURE_TYPE_COMMAND_LIST_DESC());
414 ze_command_list_desc_t.commandQueueGroupOrdinal(commandListDesc, 0);
415 MemorySegment moduleHandle = createModule(kernelName, spirvModule);
416 check(zeCommandListCreate(contextHandle, deviceHandle, commandListDesc, pCommandListHandle));
417 MemorySegment commandListHandle = pCommandListHandle.get(ADDRESS, 0);
418 MemorySegment kernelHandle = createKernel(moduleHandle, kernelName, geometry, suggested);
419 for (int i = 0; i < args.size(); i++) {
420 Arg arg = args.get(i);
421 setKernelArg(arg, i, commandListHandle, kernelHandle, (shared != 0) && (i == args.size() - 1));
422 }
423 MemorySegment groupCount = arena.allocate(ze_group_count_t.layout());
424 ze_group_count_t.groupCountX(groupCount, (geometry.globalSizes()[0] + geometry.localSizes()[0] - 1) / geometry.localSizes()[0]);
425 ze_group_count_t.groupCountY(groupCount, (geometry.globalSizes()[1] + geometry.localSizes()[1] - 1) / geometry.localSizes()[1]);
426 ze_group_count_t.groupCountZ(groupCount, (geometry.globalSizes()[2] + geometry.localSizes()[2] - 1) / geometry.localSizes()[2]);
427 MemorySegment pKernelWaitHandles = MemorySegment.NULL;
428 check(zeCommandListAppendLaunchKernel(commandListHandle, kernelHandle, groupCount, MemorySegment.NULL, 0, pKernelWaitHandles));
429 check(zeCommandListClose(commandListHandle));
430 return commandListHandle;
431 }
432
433 private MemorySegment executeCommandList(MemorySegment commandListHandle) {
434 MemorySegment fenceDesc = arena.allocate(ze_fence_desc_t.layout());
435 ze_module_desc_t.stype(fenceDesc, ZE_STRUCTURE_TYPE_FENCE_DESC());
436 ze_fence_desc_t.flags(fenceDesc, ZE_FENCE_FLAG_SIGNALED());
437 MemorySegment pFenceHandle = arena.allocate(ze_fence_handle_t);
438 check(zeFenceCreate(queueHandle, fenceDesc, pFenceHandle));
439 MemorySegment fenceHandle = pFenceHandle.get(ADDRESS, 0);
440 MemorySegment pCommandListHandle = arena.allocate(ze_command_list_handle_t);
441 pCommandListHandle.set(ADDRESS, 0, commandListHandle);
442 Instant start = Instant.now();
443 check(zeCommandQueueExecuteCommandLists(queueHandle, 1, pCommandListHandle, fenceHandle));
444 check(zeCommandQueueSynchronize(queueHandle, -1L));
445 Instant finish = Instant.now();
446 Double timeElapsed = Duration.between(start, finish).toNanos() * 1e-6;
447 timeElapsedForRun += timeElapsed;
448 debug("time: %f %f\n", timeElapsed, timeElapsedForRun);
449
450 start = Instant.now();
451 check(zeCommandQueueExecuteCommandLists(queueHandle, 1, pCommandListHandle, fenceHandle));
452 check(zeCommandQueueSynchronize(queueHandle, -1L));
453 finish = Instant.now();
454 timeElapsed = Duration.between(start, finish).toNanos() * 1e-6;
455 timeElapsedForRerun += timeElapsed;
456 debug("time for rerun: %f %f\n", timeElapsed, timeElapsedForRerun);
457 return fenceHandle;
458 }
459
460 private MemorySegment createKernel(MemorySegment moduleHandle, String kernelNameString, KernelGeometry geometry, boolean suggested) {
461 MemorySegment kernelDesc = arena.allocate(ze_kernel_desc_t.layout());
462 MemorySegment kernelName = arena.allocateFrom(kernelNameString);
463 ze_kernel_desc_t.stype(kernelDesc, ZE_STRUCTURE_TYPE_KERNEL_DESC());
464 ze_kernel_desc_t.pKernelName(kernelDesc, kernelName);
465 debug("name = %s", kernelNameString);
466 MemorySegment pKernelHandle = arena.allocate(ze_kernel_handle_t);
467 check(zeKernelCreate(moduleHandle, kernelDesc, pKernelHandle));
468 int[] globalSizes = geometry.globalSizes();
469 int[] localSizes = geometry.localSizes();
470 MemorySegment kernelHandle = pKernelHandle.get(ADDRESS, 0);
471 if (suggested) {
472 MemorySegment pGroupSizeX = arena.allocate(JAVA_INT, localSizes[0]);
473 MemorySegment pGroupSizeY = arena.allocate(JAVA_INT, localSizes[1]);
474 MemorySegment pGroupSizeZ = arena.allocate(JAVA_INT, localSizes[2]);
475 check(zeKernelSuggestGroupSize(kernelHandle, globalSizes[0], globalSizes[1], globalSizes[2], pGroupSizeX, pGroupSizeY, pGroupSizeZ));
476 geometry.localSizes()[0] = pGroupSizeX.get(JAVA_INT, 0);
477 geometry.localSizes()[1] = pGroupSizeY.get(JAVA_INT, 0);
478 geometry.localSizes()[2] = pGroupSizeZ.get(JAVA_INT, 0);
479 debug("use suggested group size", geometry.toString());
480 check(zeKernelSetGroupSize(kernelHandle, pGroupSizeX.get(JAVA_INT, 0), pGroupSizeY.get(JAVA_INT, 0), pGroupSizeZ.get(JAVA_INT, 0)));
481 } else {
482 debug("use localSizes", geometry.toString());
483 check(zeKernelSetGroupSize(kernelHandle, localSizes[0], localSizes[1], localSizes[2]));
484 }
485 return kernelHandle;
486 }
487
488 private void setKernelArg(Arg arg, int ordinal, MemorySegment commandListHandle, MemorySegment kernelHandle, boolean shared) {
489 MemorySegment dataSegment = arg.dataSegment();
490 Class<?> cls = arg.cls();
491 debug("ordinal = %d, cls = %s, data = %s", ordinal, cls.getSimpleName(), dataSegment);
492 if (shared) { // shared memory
493 check(zeKernelSetArgumentValue(kernelHandle, ordinal, dataSegment.byteSize(), dataSegment));
494 }
495 else if (cls == byte[].class || cls == short[].class || cls == int[].class || cls == float[].class || cls.getSimpleName().equals("NativeMemorySegmentImpl")) {
496 check(zeCommandListAppendMemoryPrefetch(commandListHandle, dataSegment, dataSegment.byteSize()));
497 check(zeCommandListAppendMemAdvise(commandListHandle, deviceHandle, dataSegment, dataSegment.byteSize(), ZE_MEMORY_ADVICE_SET_PREFERRED_LOCATION()));
498 MemorySegment pDataSegment = arena.allocateFrom(ADDRESS, dataSegment);
499 check(zeKernelSetArgumentValue(kernelHandle, ordinal, ADDRESS.byteSize(), pDataSegment));
500 }
501 else if (cls == Short.class) {
502 MemorySegment pArgValue = arena.allocateFrom(JAVA_SHORT, (short)arg.value());
503 check(zeKernelSetArgumentValue(kernelHandle, ordinal, Short.BYTES, pArgValue));
504 }
505 else if (VectorSpecies.class.isAssignableFrom(cls)) {
506 MemorySegment pArgValue = arena.allocateFrom(JAVA_INT, FloatVector.SPECIES_256.length());
507 check(zeKernelSetArgumentValue(kernelHandle, ordinal, Integer.BYTES, pArgValue));
508 }
509 else if (cls == Integer.class || cls == Boolean.class) {
510 MemorySegment pArgValue = arena.allocateFrom(JAVA_INT, (int)arg.value());
511 check(zeKernelSetArgumentValue(kernelHandle, ordinal, Integer.BYTES, pArgValue));
512 }
513 else if (cls == Long.class) {
514 MemorySegment pArgValue = arena.allocateFrom(JAVA_LONG, (long)arg.value());
515 check(zeKernelSetArgumentValue(kernelHandle, ordinal, Long.BYTES, pArgValue));
516 }
517 else if (cls == Float.class) {
518 MemorySegment pArgValue = arena.allocateFrom(JAVA_LONG, Float.floatToIntBits((float)arg.value()));
519 check(zeKernelSetArgumentValue(kernelHandle, ordinal, Float.BYTES, pArgValue));
520 }
521 else if (cls == GPU.Index.class) {
522 MemorySegment pDataSegment = arena.allocateFrom(ADDRESS, dataSegment);
523 check(zeKernelSetArgumentValue(kernelHandle, ordinal, 24, pDataSegment));
524 }
525 else throw new RuntimeException("unsupported type: " + cls);
526 }
527
528 private MemorySegment createModule(String moduleName, MemorySegment spirvCode) {
529 MemorySegment pModuleHandle = arena.allocate(ze_module_handle_t);
530 MemorySegment moduleDesc = arena.allocate(ze_module_desc_t.layout());
531 ze_module_desc_t.stype(moduleDesc, ZE_STRUCTURE_TYPE_MODULE_DESC());
532 ze_module_desc_t.format(moduleDesc, ZE_MODULE_FORMAT_IL_SPIRV());
533 ze_module_desc_t.pInputModule(moduleDesc, spirvCode);
534 ze_module_desc_t.inputSize(moduleDesc, spirvCode.byteSize());
535 ze_module_desc_t.pBuildFlags(moduleDesc, arena.allocateFrom(""));
536 MemorySegment buildLogHandle = arena.allocate(ze_module_build_log_handle_t);
537 check(zeModuleCreate(contextHandle, deviceHandle, moduleDesc, pModuleHandle, buildLogHandle));
538 MemorySegment moduleHandle = pModuleHandle.get(ADDRESS, 0);
539 return moduleHandle;
540 }
541
542 public MemorySegment allocateSharedSegment(long byteSize) {
543 return allocateSharedSegment(contextHandle(), deviceHandle(), byteSize, Arena.global());
544 }
545
546 public static MemorySegment allocateSharedSegment(MemorySegment contextHandle, MemorySegment deviceHandle, long byteSize, Arena arena) {
547 MemorySegment pDeviceMemAllocDesc = arena.allocate(ze_device_mem_alloc_desc_t.layout());
548 ze_device_mem_alloc_desc_t.stype(pDeviceMemAllocDesc, ZE_STRUCTURE_TYPE_DEVICE_MEM_ALLOC_DESC());
549 ze_device_mem_alloc_desc_t.ordinal(pDeviceMemAllocDesc, 0);
550 MemorySegment pHostMemAllocDesc = arena.allocate(ze_host_mem_alloc_desc_t.layout());
551 ze_host_mem_alloc_desc_t.stype(pHostMemAllocDesc, ZE_STRUCTURE_TYPE_HOST_MEM_ALLOC_DESC());
552 MemorySegment pBuffer = arena.allocate(ADDRESS);
553 check(zeMemAllocShared(contextHandle, pDeviceMemAllocDesc, pHostMemAllocDesc, byteSize, 1, deviceHandle, pBuffer));
554 long address = pBuffer.get(JAVA_LONG, 0);
555 return MemorySegment.ofAddress(address).reinterpret(byteSize);
556 }
557
558 private static record KernelGeometry(int[] globalSizes, int[] localSizes) {
559 public KernelGeometry() {
560 this(new int[3], new int[] {512, 1, 1});
561 }
562
563 @Override
564 public String toString() {
565 return String.format("global: %s, local: %s", Arrays.toString(globalSizes), Arrays.toString(localSizes));
566 }
567 }
568
569 public void benchAddKernel() {
570 String jsonFileName = cacheDir + addKernelCache + "/add_kernel.json";
571 String moduleName = cacheDir + addKernelCache + "/add_kernel.spv";
572 JSONObject jsonObject = loadJson(jsonFileName);
573 String kernelName = jsonObject.getString("name");
574 int threads_per_warp = jsonObject.getInt("threads_per_warp");
575 int num_warps = jsonObject.getInt("num_warps");
576 int shared = jsonObject.getInt("shared");
577 Random rand = new Random();
578
579 Writer writer = new Writer("benchmark/vector_add_benchmark.txt");
580 writer.write("elementSize timeElapsed timeElapsedForRerun RTT gb/s \n");
581
582 for (int elementSize = (1 << 12); elementSize <= (1 << 28); elementSize <<= 1) {
583 int BLOCK_SIZE = 1024;
584 int gridSize = (elementSize + BLOCK_SIZE - 1) / BLOCK_SIZE;
585
586 float[] input1 = new float[elementSize];
587 float[] input2 = new float[elementSize];
588 float[] output = new float[elementSize];
589 for (int i = 0; i < elementSize; i++) {
590 input1[i] = rand.nextFloat();
591 input2[i] = rand.nextFloat();
592 }
593 Object[] args = new Object[] {input1, input2, output, elementSize};
594
595 // warmup
596 run(kernelName, moduleName, args, threads_per_warp, num_warps, shared, gridSize);
597 this.timeElapsedForRun = this.timeElapsedForRerun = (double) 0;
598
599 int nTimes = 10;
600 Instant start = Instant.now();
601 for (int i = 0; i < nTimes; ++i)
602 run(kernelName, moduleName, args, threads_per_warp, num_warps, shared, gridSize);
603 Instant finish = Instant.now();
604 Double RTT = Duration.between(start, finish).toNanos() * 1e-6 / nTimes;
605 Double timeElapsedForRun = this.timeElapsedForRun / nTimes;
606 Double timeElapsedForRerun = this.timeElapsedForRerun / nTimes;
607 writer.write(String.format("%d %.4f %.4f %.4f %.4f\n", elementSize, timeElapsedForRun, timeElapsedForRerun, RTT, (4f * 3f * elementSize / timeElapsedForRerun * 1e-6)));
608 }
609 writer.close();
610 }
611
612 public void benchSoftmaxKernel() {
613 String jsonFileName = cacheDir + softmaxKernelCache + "/softmax_kernel.json";
614 String moduleName = cacheDir + softmaxKernelCache + "/softmax_kernel.spv";
615 JSONObject jsonObject = loadJson(jsonFileName);
616 String kernelName = jsonObject.getString("name");
617 int threads_per_warp = jsonObject.getInt("threads_per_warp");
618 int num_warps = jsonObject.getInt("num_warps");
619 int shared = jsonObject.getInt("shared");
620 Random rand = new Random();
621
622 Writer writer = new Writer("benchmark/softmax_benchmark.txt");
623 writer.write("elementSizeX elementSizeY timeElapsed timeElapsedForRerun RTT gb/s \n");
624
625 for (int i = 2; i < 50; i++) {
626 int elementSizeX = 4096;
627 int elementSizeY = 128 * i;
628 int gridSize = elementSizeX;
629 float[] input = new float[elementSizeX * elementSizeY];
630 float[] output = new float[elementSizeX * elementSizeY];
631 byte[] sharedMemory = new byte[shared];
632 for (int j = 0; j < elementSizeX * elementSizeY; j++) {
633 input[j] = rand.nextFloat();
634 }
635 Object[] args = new Object[] {output, input, elementSizeY, elementSizeY, elementSizeY, sharedMemory};
636
637 // warmup
638 run(kernelName, moduleName, args, threads_per_warp, num_warps, shared, gridSize);
639 this.timeElapsedForRun = this.timeElapsedForRerun = (double) 0;
640
641 int nTimes = 10;
642 Instant start = Instant.now();
643 for (int j = 0; j < nTimes; ++j)
644 run(kernelName, moduleName, args, threads_per_warp, num_warps, shared, gridSize);
645 Instant finish = Instant.now();
646 Double RTT = Duration.between(start, finish).toNanos() * 1e-6 / nTimes;
647 Double timeElapsedForRun = this.timeElapsedForRun / nTimes;
648 Double timeElapsedForRerun = this.timeElapsedForRerun / nTimes;
649 writer.write(String.format("%d %d %.4f %.4f %.4f %.4f\n", elementSizeX, elementSizeY, timeElapsedForRun, timeElapsedForRerun, RTT, (4 * 2 * 1e-9 * elementSizeX * elementSizeY / (timeElapsedForRerun * 1e-3))));
650 }
651 writer.close();
652 }
653
654 public void benchMatmulKernel() {
655 String jsonFileName = cacheDir + matmulKernelCache + "/matmul_kernel.json";
656 String moduleName = cacheDir + matmulKernelCache + "/matmul_kernel.spv";
657
658 JSONObject jsonObject = loadJson(jsonFileName);
659 String kernelName = jsonObject.getString("name");
660 int threads_per_warp = jsonObject.getInt("threads_per_warp");
661 int num_warps = jsonObject.getInt("num_warps");
662 int shared = jsonObject.getInt("shared");
663 Random rand = new Random();
664 int BLOCK_SIZE_M = 128, BLOCK_SIZE_N = 64;
665
666
667 Writer writer = new Writer("benchmark/matmul_benchmark.txt");
668 writer.write("M N K timeElapsed timeElapsedForRerun RTT TFLOPS \n");
669
670 for (int i = 2; i <= 64; i++) {
671 int M = 128 * i;
672 int N = 128 * i;
673 int K = 128 * i;
674 int gridSize = ((M + BLOCK_SIZE_M - 1) / BLOCK_SIZE_M) * ((N + BLOCK_SIZE_N - 1) / BLOCK_SIZE_N);
675
676 float[] a = new float[M * K];
677 float[] b = new float[K * N];
678 float[] c = new float[M * N];
679 byte[] sharedMemory = new byte[shared];
680 for (int j = 0; j < M * K; j++) {
681 a[j] = rand.nextFloat();
682 }
683 for (int j = 0; j < K * N; j++) {
684 b[j] = rand.nextFloat();
685 }
686 Object[] args = new Object[] {a, b, c, M, N, K, K, N, N, sharedMemory};
687
688 // warmup
689 run(kernelName, moduleName, args, threads_per_warp, num_warps, shared, gridSize);
690 this.timeElapsedForRun = this.timeElapsedForRerun = (double) 0;
691
692 int nTimes = 10;
693 Instant start = Instant.now();
694 for (int j = 0; j < nTimes; ++j)
695 run(kernelName, moduleName, args, threads_per_warp, num_warps, shared, gridSize);
696 Instant finish = Instant.now();
697 Double RTT = Duration.between(start, finish).toNanos() * 1e-6 / nTimes;
698 Double timeElapsedForRun = this.timeElapsedForRun / nTimes;
699 Double timeElapsedForRerun = this.timeElapsedForRerun / nTimes;
700 writer.write(String.format("%d %d %d %.4f %.4f %.4f %.4f\n", M, N, K, timeElapsedForRun, timeElapsedForRerun, RTT, (2 * 1e-12 * M * N * K / (timeElapsedForRerun * 1e-3))));
701 }
702 writer.close();
703 }
704
705 public static void main(String[] args) {
706 LevelZero lz = new LevelZero();
707 lz.test("add");
708 lz.test("softmax");
709 lz.test("matmul");
710 lz.benchAddKernel();
711 lz.benchSoftmaxKernel();
712 lz.benchMatmulKernel();
713 lz.clear();
714 }
715
716
717 public static class Arg {
718 private final String name;
719 private final Object value;
720 private final Class<?> cls;
721 private int size;
722 private boolean needsCleanup;
723 private MemorySegment dataSegment;
724
725 public static Arg createArg(LevelZero lz, String name, Object value) {
726 Arg arg = new Arg(name, value);
727 lz.provisionArg(arg);
728 return arg;
729 }
730
731 private Arg(String name, Object value) {
732 this.name = name;
733 this.cls = value.getClass();
734 this.value = value;
735 }
736
737 public String name() {
738 return name;
739 }
740
741 public Object value() {
742 return value;
743 }
744
745 public Class<?> cls() {
746 return cls;
747 }
748
749 public void setSize(int size) {
750 this.size = size;
751 }
752
753 public int size() {
754 return size;
755 }
756
757 public void setDataSegment(MemorySegment segment) {
758 dataSegment = segment;
759 }
760
761 public MemorySegment dataSegment() {
762 return dataSegment;
763 }
764
765 public void setNeedsCleanup(boolean needsCleanup) {
766 this.needsCleanup = needsCleanup;
767 }
768
769 public boolean needsCleanup() {
770 return needsCleanup;
771 }
772
773 public String toString() {
774 return String.format("name = %s, cls = %s", name, cls);
775 }
776 }
777
778 private JSONObject loadJson(String fileName) {
779 StringBuilder jsonString = new StringBuilder();
780 try (BufferedReader br = new BufferedReader(new FileReader(fileName)))
781 {
782 String line;
783 while ((line = br.readLine()) != null) {
784 jsonString.append(line);
785 }
786 } catch (IOException e) {
787 e.printStackTrace();
788 }
789 JSONObject jsonObject = new JSONObject(jsonString.toString());
790 return jsonObject;
791 }
792
793 private class Test {
794 public static float[] add(float[] a, float[] b, int SIZE) {
795 float[] output = new float[SIZE];
796 for (int i = 0; i < SIZE; ++i)
797 output[i] = a[i] + b[i];
798 return output;
799 }
800 public static float[] softmax(float[] a, int X, int Y) {
801 float[] output = new float[X * Y];
802 for (int i = 0; i < X; ++i) {
803 float max = Float.MIN_VALUE;
804 for (int j = 0; j < Y; ++j) {
805 max = Math.max(max, a[i * Y + j]);
806 }
807 float sum = 0;
808 for (int j = 0; j < Y; ++j) {
809 output[i * Y + j] = (float)Math.exp(a[i * Y + j] - max);
810 sum += output[i * Y + j];
811 }
812 for (int j = 0; j < Y; ++j) {
813 output[i * Y + j] /= sum;
814 }
815 }
816 return output;
817 }
818 public static float[] matmul(float[] a, float[] b, int M, int N, int K) {
819 float[] output = new float[M * N];
820 for (int i = 0; i < M; i++) {
821 for (int j = 0; j < N; j++) {
822 float tmp = 0;
823 for (int k = 0; k < K; k++) {
824 tmp += a[i * K + k] * b[k * N + j];
825 }
826 output[i * N + j] = tmp;
827 }
828 }
829 return output;
830 }
831 public static void check(float[] expected, float[] output) {
832 for (int i = 0; i < expected.length; i++) {
833 if (Math.abs(expected[i] - output[i]) > 1e-2) {
834 System.out.printf("Mismatch at %d: %f != %f%n", i, expected[i], output[i]);
835 throw new RuntimeException("Mismatch");
836 }
837 }
838 System.out.println("Test passed");
839 }
840 }
841
842 private class Writer {
843 private final String fileName;
844 private final FileWriter writer;
845
846 public Writer(String fileName) {
847 this.fileName = fileName;
848 try {
849 writer = new FileWriter(fileName, false);
850 } catch (IOException e) {
851 throw new RuntimeException(e);
852 }
853 }
854
855 public void write(String line) {
856 try {
857 writer.write(line);
858 } catch (IOException e) {
859 throw new RuntimeException(e);
860 }
861 }
862
863 public void close() {
864 try {
865 writer.close();
866 } catch (IOException e) {
867 throw new RuntimeException(e);
868 }
869 }
870 }
871 }