1 /*
  2  * Copyright (c) 2026, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package hat.test;
 26 
 27 import hat.Accelerator;
 28 import hat.Accelerator.Compute;
 29 import hat.ComputeContext;
 30 import hat.HATMath;
 31 import hat.KernelContext;
 32 import hat.NDRange;
 33 import hat.backend.Backend;
 34 import hat.device.DeviceSchema;
 35 import hat.device.NonMappableIface;
 36 import hat.test.annotation.HatTest;
 37 import hat.test.exceptions.HATAsserts;
 38 import jdk.incubator.code.Reflect;
 39 import optkl.ifacemapper.BoundSchema;
 40 import optkl.ifacemapper.Buffer;
 41 import optkl.ifacemapper.MappableIface.RO;
 42 import optkl.ifacemapper.MappableIface.RW;
 43 import optkl.ifacemapper.MappableIface.WO;
 44 import optkl.ifacemapper.Schema;
 45 
 46 import java.lang.invoke.MethodHandles;
 47 
 48 import static hat.test.TestDFT.ArrayComplex.Complex;
 49 import static hat.test.TestDFT.ArrayComplex.create;
 50 
 51 public class TestDFT {
 52 
 53     public interface ArrayComplex extends Buffer {
 54         int length();
 55 
 56         interface Complex extends Struct {
 57             float real();
 58             float imag();
 59             void real(float real);
 60             void imag(float imag);
 61         }
 62 
 63         Complex complex(long index);
 64 
 65         Schema<ArrayComplex> schema = Schema.of(ArrayComplex.class,
 66                 complex ->
 67                         complex.arrayLen("length")
 68                                 .array("complex",
 69                                         array -> array.fields("real", "imag")));
 70 
 71         static ArrayComplex create(Accelerator accelerator, int length) {
 72             return BoundSchema.of(accelerator, schema, length).allocate();
 73         }
 74     }
 75 
 76     @Reflect
 77     public static void dftKernel(KernelContext kc,
 78                                   ArrayComplex input,
 79                                   ArrayComplex output) {
 80         int size = input.length();
 81         int idx = kc.gix;
 82         if (idx < kc.gsx) {
 83             float sumReal = 0.0f;
 84             float sumImag = 0.0f;
 85             for (int k = 0; k < size; k++) {
 86                 float angle = -2 * HATMath.PI * ((k * idx) % size) / size;
 87                 Complex complexInput = input.complex(k);
 88                 float cReal = HATMath.cosf(angle);
 89                 float cImag = HATMath.sinf(angle);
 90                 sumReal += (complexInput.real() * cReal) - (complexInput.imag() * cImag);
 91                 sumImag += (complexInput.real() * cImag) + (complexInput.imag() * cReal);
 92             }
 93             Complex complexOutput = output.complex(idx);
 94             complexOutput.real(sumReal);
 95             complexOutput.imag(sumImag);
 96         }
 97     }
 98 
 99     private static void dftJava(ArrayComplex input, ArrayComplex output) {
100         int size = input.length();
101         for (int k = 0; k < size; k++) {
102             Complex complexOutput = output.complex(k);
103             complexOutput.real(0.0f);
104             complexOutput.imag(0.0f);
105             float sumReal = 0.0f;
106             float sumImag = 0.0f;
107             for (int j = 0; j < size; j++) {
108                 float angle = -2 * HATMath.PI * ((j * k) % size) / size;
109                 Complex complexInput = input.complex(j);
110                 float cReal = HATMath.cosf(angle);
111                 float cImag = HATMath.sinf(angle);
112                 sumReal += (complexInput.real() * cReal) - (complexInput.imag() * cImag);
113                 sumImag += (complexInput.real() * cImag) + (complexInput.imag() * cReal);
114             }
115             complexOutput.real(sumReal);
116             complexOutput.imag(sumImag);
117         }
118     }
119 
120     @Reflect
121     private static void dftCompute(@RO ComputeContext cc,
122                                    @RO ArrayComplex input,
123                                    @WO ArrayComplex output) {
124         var range = NDRange.of1D(input.length(), 128);
125         cc.dispatchKernel(range, kernelContext -> dftKernel(kernelContext, input, output));
126     }
127 
128     @HatTest
129     public void testDFTWithOwnDS() {
130         var lookup = MethodHandles.lookup();
131         var accelerator = new Accelerator(lookup, Backend.FIRST);
132         final int size = 8192;
133         ArrayComplex input = create(accelerator, size);
134         ArrayComplex outputSeq = create(accelerator, size);
135         ArrayComplex outputHAT = create(accelerator, size);
136         accelerator.compute((@Reflect Compute) computeContext -> dftCompute(computeContext, input, outputHAT));
137         dftJava(input, outputSeq);
138         for (int i = 0; i < outputSeq.length(); i++) {
139             HATAsserts.assertEquals(outputSeq.complex(i).real(), outputHAT.complex(i).real(), 0.001f);
140             HATAsserts.assertEquals(outputSeq.complex(i).imag(), outputHAT.complex(i).imag(), 0.001f);
141         }
142     }
143 
144     // Simple test to check User Data Structures in Private Memory
145     public interface ArrayComplexPrivate extends NonMappableIface {
146         interface PrivateComplex extends NonMappableIface {
147             DeviceSchema<PrivateComplex> deviceSchema =
148                     DeviceSchema.of(PrivateComplex.class,
149                             complex -> complex.fields("real", "imag"));
150             float real();
151             float imag();
152             void real(float real);
153             void imag(float imag);
154         }
155         PrivateComplex complex(long index);
156         DeviceSchema<ArrayComplexPrivate> deviceSchema =
157                 DeviceSchema.of(ArrayComplexPrivate.class,
158                         complex -> complex.array("complex", 128, c2 -> c2.fields("real", "imag")));
159         static ArrayComplexPrivate createPrivate() {
160             return null;
161         }
162     }
163 
164     @Reflect
165     public static void testPrivateDS(KernelContext kc,
166                                       ArrayComplex input,
167                                       ArrayComplex output) {
168         int idx = kc.gix;
169         ArrayComplexPrivate priv = ArrayComplexPrivate.createPrivate();
170         ArrayComplexPrivate.PrivateComplex complex = priv.complex(0);
171         complex.real(1.0f);
172         complex.imag(2.0f);
173         Complex complexOutput = output.complex(idx);
174         complexOutput.real(complex.real());
175         complexOutput.imag(complex.imag());
176     }
177 
178 
179     @Reflect
180     private static void complexNumbersInPrivate(@RW ComputeContext cc,
181                                    @RO ArrayComplex input,
182                                    @WO ArrayComplex output) {
183         var range = NDRange.of1D(input.length(), 128);
184         cc.dispatchKernel(range, kernelContext -> testPrivateDS(kernelContext, input, output));
185     }
186 
187     @HatTest
188     public void complexNumbersInPrivate() {
189         var lookup = MethodHandles.lookup();
190         var accelerator = new Accelerator(lookup, Backend.FIRST);
191         final int size = 8192;
192         ArrayComplex input = create(accelerator, size);
193         ArrayComplex outputSeq = create(accelerator, size);
194         ArrayComplex outputHAT = create(accelerator, size);
195         accelerator.compute((@Reflect Compute) computeContext -> complexNumbersInPrivate(computeContext, input, outputHAT));
196         for (int i = 0; i < outputSeq.length(); i++) {
197             HATAsserts.assertEquals(1.0f, outputHAT.complex(i).real(), 0.001f);
198             HATAsserts.assertEquals(2.0f, outputHAT.complex(i).imag(), 0.001f);
199         }
200     }
201 
202 }