1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package experiments.spirv;
 26 
 27 import hat.Accelerator;
 28 import hat.Accelerator.Compute;
 29 import hat.ComputeContext;
 30 import hat.NDRange;
 31 import hat.KernelContext;
 32 import hat.backend.Backend;
 33 import hat.buffer.F32Array;
 34 
 35 import java.lang.invoke.MethodHandles;
 36 import jdk.incubator.code.Reflect;
 37 
 38 public class GetBackend {
 39 
 40     public static void getSpirvBakend() {
 41         Backend spirvBackend = Backend.getBackend((backend) -> {
 42             return backend.getClass().getSimpleName().equals("SpirvBackend");
 43         });
 44     }
 45 
 46     public static void getSpirvAccelerator() {
 47         Accelerator accelerator = new Accelerator(MethodHandles.lookup(), (backend) -> {
 48             return backend.getClass().getSimpleName().equals("SpirvBackend");
 49         });
 50     }
 51 
 52     public static class MatrixMultiply {
 53 
 54         /*
 55           Original loop was
 56           for (int i = 0; i < size; i++) {
 57                 for (int j = 0; j < size; j++) {
 58                     float sum = 0f;
 59                     for (int k = 0; k < size; k++) {
 60                         sum += a[i * size + k] * b[k * size + j];
 61                         sum += a[i * size + k] * b[k * size + j];
 62                     }
 63                     c[i * size + j] = sum;
 64                 }
 65             }
 66 
 67            Converted to hat kernel
 68 
 69                 for (int j = 0; j < kid.max; j++) {
 70                     float sum = 0f;
 71                     for (int k = 0; k < kid.max; k++) {
 72                         sum += a[kid.x * kid.max + k] * b[k * kid.max + j];
 73                         sum += a[kid.x * kid.max + k] * b[k * kid.max + j];
 74                     }
 75                     c[kid.x * kid.max + j] = sum;
 76                 }
 77 
 78            We don't allow heap array access. So we use F32Array iface mapped segment
 79 
 80             Converted to hat kernel
 81 
 82                 for (int j = 0; j < kid.max; j++) {
 83                     float sum = 0f;
 84                     for (int k = 0; k < kid.max; k++) {
 85                         //sum += a[kid.x * kid.max + k] * b[k * kid.max + j];
 86                         sum += a.array(kid.x * kid.max + k)*b.array(k * kid.max + j]);
 87                         //sum += a[kid.x * kid.max + k] * b[k * kid.max + j];
 88                         sum += a.array(kid.x * kid.max + k) * b.array(k * kid.max + j);
 89                     }
 90                     //c[kid.x * kid.max + j] = sum;
 91                     c.array(kid.x * kid.max + j, sum);
 92                 }
 93 
 94          */
 95         @Reflect
 96         static void kernel(KernelContext kid, F32Array a, F32Array b, F32Array c) {
 97             for (int j = 0; j < kid.gsx; j++) {
 98                 float sum = 0f;
 99                 for (int k = 0; k < kid.gsx; k++) {
100                     //sum += a[kid.x * kid.max + k] * b[k * kid.max + j];
101                     sum += a.array(kid.gix * kid.gsx + k) * b.array(k * kid.gsx + j);
102                     //sum += a[kid.x * kid.max + k] * b[k * kid.max + j];
103                     sum += a.array(kid.gix * kid.gsx + k) * b.array(k * kid.gsx + j);
104                 }
105                 //c[kid.x * kid.max + j] = sum;
106                 c.array(kid.gix * kid.gsx + j, sum);
107             }
108         }
109 
110         @Reflect
111         static void compute(ComputeContext computeContext, F32Array a, F32Array b, F32Array c, int size) {
112             computeContext.dispatchKernel(NDRange.of1D(size * size), kc -> MatrixMultiply.kernel(kc, a, b, c));
113         }
114 
115     }
116 
117     public static void main(String[] args) {
118         Accelerator accelerator = new Accelerator(MethodHandles.lookup(), (backend) ->
119                 backend.getClass().getSimpleName().startsWith("Spirv")
120         );
121         var a = F32Array.create(accelerator, 100);
122         var b = F32Array.create(accelerator, 100);
123         var c = F32Array.create(accelerator, 100);
124         accelerator.compute((@Reflect Compute)
125                 cc -> MatrixMultiply.compute(cc, a, b, c, 100));
126     }
127 
128 }