1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package experiments.spirv;
 26 
 27 import hat.Accelerator;
 28 import hat.ComputeContext;
 29 import hat.NDRange;
 30 import hat.KernelContext;
 31 import hat.backend.Backend;
 32 import hat.buffer.F32Array;
 33 
 34 import java.lang.invoke.MethodHandles;
 35 import jdk.incubator.code.CodeReflection;
 36 
 37 public class GetBackend {
 38 
 39     public static void getSpirvBakend() {
 40         Backend spirvBackend = Backend.getBackend((backend) -> {
 41             return backend.getClass().getSimpleName().equals("SpirvBackend");
 42         });
 43     }
 44 
 45     public static void getSpirvAccelerator() {
 46         Accelerator accelerator = new Accelerator(MethodHandles.lookup(), (backend) -> {
 47             return backend.getClass().getSimpleName().equals("SpirvBackend");
 48         });
 49     }
 50 
 51     public static class MatrixMultiply {
 52 
 53         /*
 54           Original loop was
 55           for (int i = 0; i < size; i++) {
 56                 for (int j = 0; j < size; j++) {
 57                     float sum = 0f;
 58                     for (int k = 0; k < size; k++) {
 59                         sum += a[i * size + k] * b[k * size + j];
 60                         sum += a[i * size + k] * b[k * size + j];
 61                     }
 62                     c[i * size + j] = sum;
 63                 }
 64             }
 65 
 66            Converted to hat kernel
 67 
 68                 for (int j = 0; j < kid.max; j++) {
 69                     float sum = 0f;
 70                     for (int k = 0; k < kid.max; k++) {
 71                         sum += a[kid.x * kid.max + k] * b[k * kid.max + j];
 72                         sum += a[kid.x * kid.max + k] * b[k * kid.max + j];
 73                     }
 74                     c[kid.x * kid.max + j] = sum;
 75                 }
 76 
 77            We don't allow heap array access. So we use F32Array iface mapped segment
 78 
 79             Converted to hat kernel
 80 
 81                 for (int j = 0; j < kid.max; j++) {
 82                     float sum = 0f;
 83                     for (int k = 0; k < kid.max; k++) {
 84                         //sum += a[kid.x * kid.max + k] * b[k * kid.max + j];
 85                         sum += a.array(kid.x * kid.max + k)*b.array(k * kid.max + j]);
 86                         //sum += a[kid.x * kid.max + k] * b[k * kid.max + j];
 87                         sum += a.array(kid.x * kid.max + k) * b.array(k * kid.max + j);
 88                     }
 89                     //c[kid.x * kid.max + j] = sum;
 90                     c.array(kid.x * kid.max + j, sum);
 91                 }
 92 
 93          */
 94         @CodeReflection
 95         static void kernel(KernelContext kid, F32Array a, F32Array b, F32Array c) {
 96             for (int j = 0; j < kid.gsx; j++) {
 97                 float sum = 0f;
 98                 for (int k = 0; k < kid.gsx; k++) {
 99                     //sum += a[kid.x * kid.max + k] * b[k * kid.max + j];
100                     sum += a.array(kid.gix * kid.gsx + k) * b.array(k * kid.gsx + j);
101                     //sum += a[kid.x * kid.max + k] * b[k * kid.max + j];
102                     sum += a.array(kid.gix * kid.gsx + k) * b.array(k * kid.gsx + j);
103                 }
104                 //c[kid.x * kid.max + j] = sum;
105                 c.array(kid.gix * kid.gsx + j, sum);
106             }
107         }
108 
109         @CodeReflection
110         static void compute(ComputeContext computeContext, F32Array a, F32Array b, F32Array c, int size) {
111             computeContext.dispatchKernel(NDRange.of(size * size), kc -> MatrixMultiply.kernel(kc, a, b, c));
112         }
113 
114     }
115 
116     public static void main(String[] args) {
117         Accelerator accelerator = new Accelerator(MethodHandles.lookup(), (backend) ->
118                 backend.getClass().getSimpleName().startsWith("Spirv")
119         );
120         var a = F32Array.create(accelerator, 100);
121         var b = F32Array.create(accelerator, 100);
122         var c = F32Array.create(accelerator, 100);
123         accelerator.compute(cc -> MatrixMultiply.compute(cc, a, b, c, 100));
124     }
125 
126 }