1 /*
2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25 package experiments.spirv;
26
27 import hat.Accelerator;
28 import hat.Accelerator.Compute;
29 import hat.ComputeContext;
30 import hat.NDRange;
31 import hat.KernelContext;
32 import hat.backend.Backend;
33 import hat.buffer.F32Array;
34
35 import java.lang.invoke.MethodHandles;
36 import jdk.incubator.code.Reflect;
37
38 public class GetBackend {
39
40 public static void getSpirvBakend() {
41 Backend spirvBackend = Backend.getBackend((backend) -> {
42 return backend.getClass().getSimpleName().equals("SpirvBackend");
43 });
44 }
45
46 public static void getSpirvAccelerator() {
47 Accelerator accelerator = new Accelerator(MethodHandles.lookup(), (backend) -> {
48 return backend.getClass().getSimpleName().equals("SpirvBackend");
49 });
50 }
51
52 public static class MatrixMultiply {
53
54 /*
55 Original loop was
56 for (int i = 0; i < size; i++) {
57 for (int j = 0; j < size; j++) {
58 float sum = 0f;
59 for (int k = 0; k < size; k++) {
60 sum += a[i * size + k] * b[k * size + j];
61 sum += a[i * size + k] * b[k * size + j];
62 }
63 c[i * size + j] = sum;
64 }
65 }
66
67 Converted to hat kernel
68
69 for (int j = 0; j < kid.max; j++) {
70 float sum = 0f;
71 for (int k = 0; k < kid.max; k++) {
72 sum += a[kid.x * kid.max + k] * b[k * kid.max + j];
73 sum += a[kid.x * kid.max + k] * b[k * kid.max + j];
74 }
75 c[kid.x * kid.max + j] = sum;
76 }
77
78 We don't allow heap array access. So we use F32Array iface mapped segment
79
80 Converted to hat kernel
81
82 for (int j = 0; j < kid.max; j++) {
83 float sum = 0f;
84 for (int k = 0; k < kid.max; k++) {
85 //sum += a[kid.x * kid.max + k] * b[k * kid.max + j];
86 sum += a.array(kid.x * kid.max + k)*b.array(k * kid.max + j]);
87 //sum += a[kid.x * kid.max + k] * b[k * kid.max + j];
88 sum += a.array(kid.x * kid.max + k) * b.array(k * kid.max + j);
89 }
90 //c[kid.x * kid.max + j] = sum;
91 c.array(kid.x * kid.max + j, sum);
92 }
93
94 */
95 @Reflect
96 static void kernel(KernelContext kid, F32Array a, F32Array b, F32Array c) {
97 for (int j = 0; j < kid.gsx; j++) {
98 float sum = 0f;
99 for (int k = 0; k < kid.gsx; k++) {
100 //sum += a[kid.x * kid.max + k] * b[k * kid.max + j];
101 sum += a.array(kid.gix * kid.gsx + k) * b.array(k * kid.gsx + j);
102 //sum += a[kid.x * kid.max + k] * b[k * kid.max + j];
103 sum += a.array(kid.gix * kid.gsx + k) * b.array(k * kid.gsx + j);
104 }
105 //c[kid.x * kid.max + j] = sum;
106 c.array(kid.gix * kid.gsx + j, sum);
107 }
108 }
109
110 @Reflect
111 static void compute(ComputeContext computeContext, F32Array a, F32Array b, F32Array c, int size) {
112 computeContext.dispatchKernel(NDRange.of1D(size * size), kc -> MatrixMultiply.kernel(kc, a, b, c));
113 }
114
115 }
116
117 public static void main(String[] args) {
118 Accelerator accelerator = new Accelerator(MethodHandles.lookup(), (backend) ->
119 backend.getClass().getSimpleName().startsWith("Spirv")
120 );
121 var a = F32Array.create(accelerator, 100);
122 var b = F32Array.create(accelerator, 100);
123 var c = F32Array.create(accelerator, 100);
124 accelerator.compute((@Reflect Compute)
125 cc -> MatrixMultiply.compute(cc, a, b, c, 100));
126 }
127
128 }