1 /*
  2  * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package hat.test;
 26 
 27 import hat.Accelerator;
 28 import hat.ComputeContext;
 29 import hat.NDRange;
 30 import hat.KernelContext;
 31 import hat.backend.Backend;
 32 import hat.buffer.F32Array;
 33 import hat.device.DeviceSchema;
 34 import hat.device.NonMappableIface;
 35 import optkl.ifacemapper.MappableIface;
 36 import jdk.incubator.code.Reflect;
 37 import hat.test.annotation.HatTest;
 38 import hat.test.exceptions.HATAsserts;
 39 
 40 import java.lang.invoke.MethodHandles;
 41 
 42 public class TestLocal {
 43 
 44     private interface MySharedArray extends NonMappableIface {
 45         void array(long index, float value);
 46         float array(long index);
 47 
 48         DeviceSchema<MySharedArray> schema = DeviceSchema.of(MySharedArray.class,
 49                 builder -> builder.withArray("array", 16));
 50 
 51         static MySharedArray create(Accelerator accelerator) {
 52             return null;
 53         }
 54 
 55         static MySharedArray createLocal() {
 56             return create(new Accelerator(MethodHandles.lookup(), Backend.FIRST));
 57         }
 58     }
 59 
 60     @Reflect
 61     private static void compute(@MappableIface.RO KernelContext kernelContext, @MappableIface.RW F32Array data) {
 62         MySharedArray mySharedArray = MySharedArray.createLocal();
 63         int lix = kernelContext.lix;
 64         int blockId = kernelContext.bix;
 65         int blockSize = kernelContext.lsx;
 66         mySharedArray.array(lix, lix);
 67         kernelContext.barrier();
 68         data.array(lix + (long) blockId * blockSize, mySharedArray.array(lix));
 69     }
 70 
 71     @Reflect
 72     private static void myCompute(@MappableIface.RO ComputeContext computeContext, @MappableIface.RW F32Array data) {
 73         computeContext.dispatchKernel(NDRange.of1D(32,16),
 74                 kernelContext -> compute(kernelContext, data)
 75         );
 76     }
 77 
 78     @HatTest
 79     @Reflect
 80     public void testLocal() {
 81         Accelerator accelerator = new Accelerator(MethodHandles.lookup(), Backend.FIRST);
 82         F32Array data = F32Array.create(accelerator, 32);
 83         accelerator.compute(computeContext -> {
 84             TestLocal.myCompute(computeContext, data);
 85         });
 86 
 87         // Check result
 88         boolean isCorrect = true;
 89         int jIndex = 0;
 90         for (int i = 0; i < data.length(); i++) {
 91             if (data.array(i) != jIndex) {
 92                 isCorrect = false;
 93                 break;
 94             }
 95             jIndex++;
 96             if (jIndex == 16) {
 97                 jIndex = 0;
 98             }
 99         }
100         HATAsserts.assertTrue(isCorrect);
101     }
102 
103 }