1 /*
2 * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25
26 package experiments;
27
28 import hat.Accelerator;
29 import hat.Accelerator.Compute;
30 import hat.ComputeContext;
31 import hat.NDRange;
32 import hat.KernelContext;
33 import hat.backend.Backend;
34 import optkl.ifacemapper.BoundSchema;
35 import optkl.ifacemapper.Buffer;
36 import hat.buffer.F32Array;
37 import optkl.ifacemapper.MappableIface.RO;
38 import optkl.ifacemapper.MappableIface.RW;
39 import optkl.ifacemapper.Schema;
40 import jdk.incubator.code.Reflect;
41
42 import java.lang.invoke.MethodHandles;
43
44 /**
45 * Example of how to declare and use a custom data type in a method kernel on the GPU.
46 * This is just a proof of concept.
47 * <p>
48 * How to run?
49 * <code>
50 * HAT=SHOW_CODE java -cp job.jar hat.java exp ffi-opencl LocalArray
51 * HAT=SHOW_CODE java -cp job.jar hat.java exp ffi-cuda LocalArray
52 * </code>
53 * </p>
54 */
55 public class LocalArray {
56
57 private interface MyArray extends Buffer {
58 void array(long index, float value);
59 float array(long index);
60
61 Schema<MyArray> schema = Schema.of(MyArray.class,
62 myPrivateArray -> myPrivateArray
63 .array("array", 16));
64
65 static MyArray create(Accelerator accelerator) {
66 return BoundSchema.of(accelerator ,schema, 1).allocate();
67 }
68
69 static MyArray createLocal() {
70 return create(new Accelerator(MethodHandles.lookup(), Backend.FIRST));
71 }
72 }
73
74
75 @Reflect
76 private static void compute(@RO KernelContext kernelContext, @RW F32Array data) {
77 MyArray mySharedArray = MyArray.createLocal();
78 int lix = kernelContext.lix;
79 int blockId = kernelContext.bix;
80 int blockSize = kernelContext.lsx;
81 mySharedArray.array(lix, lix);
82 kernelContext.barrier();
83 data.array(lix + (long) blockId * blockSize, mySharedArray.array(lix));
84 }
85
86 @Reflect
87 private static void myCompute(@RO ComputeContext computeContext, @RW F32Array data) {
88 computeContext.dispatchKernel(NDRange.of1D(32,16),
89 kernelContext -> compute(kernelContext, data)
90 );
91 }
92
93 static void main(String[] args) {
94 System.out.println("Testing Shared Data Structures Mapping");
95 System.out.println("Schema description");
96 MyArray.schema.toText(System.out::print);
97 System.out.println(" ==================");
98
99 Accelerator accelerator = new Accelerator(MethodHandles.lookup(), Backend.FIRST);
100 F32Array data = F32Array.create(accelerator, 32);
101 accelerator.compute((@Reflect Compute) computeContext -> {
102 LocalArray.myCompute(computeContext, data);
103 });
104
105 // Check result
106 boolean isCorrect = true;
107 int jIndex = 0;
108 for (int i = 0; i < data.length(); i++) {
109 System.out.println(data.array(i));
110 if (data.array(i) != jIndex) {
111 isCorrect = false;
112 break;
113 }
114 jIndex++;
115 if (jIndex == 16) {
116 jIndex = 0;
117 }
118 }
119 if (isCorrect) {
120 System.out.println("Correct result");
121 } else {
122 System.out.println("Wrong result");
123 }
124 }
125
126 }