1 /*
2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25 package experiments.spirv;
26
27 import hat.Accelerator;
28 import hat.ComputeContext;
29 import hat.NDRange;
30 import hat.KernelContext;
31 import hat.backend.Backend;
32 import hat.buffer.F32Array;
33
34 import java.lang.invoke.MethodHandles;
35 import jdk.incubator.code.CodeReflection;
36
37 public class GetBackend {
38
39 public static void getSpirvBakend() {
40 Backend spirvBackend = Backend.getBackend((backend) -> {
41 return backend.getClass().getSimpleName().equals("SpirvBackend");
42 });
43 }
44
45 public static void getSpirvAccelerator() {
46 Accelerator accelerator = new Accelerator(MethodHandles.lookup(), (backend) -> {
47 return backend.getClass().getSimpleName().equals("SpirvBackend");
48 });
49 }
50
51 public static class MatrixMultiply {
52
53 /*
54 Original loop was
55 for (int i = 0; i < size; i++) {
56 for (int j = 0; j < size; j++) {
57 float sum = 0f;
58 for (int k = 0; k < size; k++) {
59 sum += a[i * size + k] * b[k * size + j];
60 sum += a[i * size + k] * b[k * size + j];
61 }
62 c[i * size + j] = sum;
63 }
64 }
65
66 Converted to hat kernel
67
68 for (int j = 0; j < kid.max; j++) {
69 float sum = 0f;
70 for (int k = 0; k < kid.max; k++) {
71 sum += a[kid.x * kid.max + k] * b[k * kid.max + j];
72 sum += a[kid.x * kid.max + k] * b[k * kid.max + j];
73 }
74 c[kid.x * kid.max + j] = sum;
75 }
76
77 We don't allow heap array access. So we use F32Array iface mapped segment
78
79 Converted to hat kernel
80
81 for (int j = 0; j < kid.max; j++) {
82 float sum = 0f;
83 for (int k = 0; k < kid.max; k++) {
84 //sum += a[kid.x * kid.max + k] * b[k * kid.max + j];
85 sum += a.array(kid.x * kid.max + k)*b.array(k * kid.max + j]);
86 //sum += a[kid.x * kid.max + k] * b[k * kid.max + j];
87 sum += a.array(kid.x * kid.max + k) * b.array(k * kid.max + j);
88 }
89 //c[kid.x * kid.max + j] = sum;
90 c.array(kid.x * kid.max + j, sum);
91 }
92
93 */
94 @CodeReflection
95 static void kernel(KernelContext kid, F32Array a, F32Array b, F32Array c) {
96 for (int j = 0; j < kid.gsx; j++) {
97 float sum = 0f;
98 for (int k = 0; k < kid.gsx; k++) {
99 //sum += a[kid.x * kid.max + k] * b[k * kid.max + j];
100 sum += a.array(kid.gix * kid.gsx + k) * b.array(k * kid.gsx + j);
101 //sum += a[kid.x * kid.max + k] * b[k * kid.max + j];
102 sum += a.array(kid.gix * kid.gsx + k) * b.array(k * kid.gsx + j);
103 }
104 //c[kid.x * kid.max + j] = sum;
105 c.array(kid.gix * kid.gsx + j, sum);
106 }
107 }
108
109 @CodeReflection
110 static void compute(ComputeContext computeContext, F32Array a, F32Array b, F32Array c, int size) {
111 computeContext.dispatchKernel(NDRange.of(size * size), kc -> MatrixMultiply.kernel(kc, a, b, c));
112 }
113
114 }
115
116 public static void main(String[] args) {
117 Accelerator accelerator = new Accelerator(MethodHandles.lookup(), (backend) ->
118 backend.getClass().getSimpleName().startsWith("Spirv")
119 );
120 var a = F32Array.create(accelerator, 100);
121 var b = F32Array.create(accelerator, 100);
122 var c = F32Array.create(accelerator, 100);
123 accelerator.compute(cc -> MatrixMultiply.compute(cc, a, b, c, 100));
124 }
125
126 }