1 /*
2 * Copyright (c) 2026, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25 package hat.test;
26
27 import hat.Accelerator;
28 import hat.Accelerator.Compute;
29 import hat.ComputeContext;
30 import hat.HATMath;
31 import hat.KernelContext;
32 import hat.NDRange;
33 import hat.backend.Backend;
34 import hat.device.DeviceSchema;
35 import hat.device.NonMappableIface;
36 import hat.test.annotation.HatTest;
37 import hat.test.exceptions.HATAsserts;
38 import jdk.incubator.code.Reflect;
39 import optkl.ifacemapper.BoundSchema;
40 import optkl.ifacemapper.Buffer;
41 import optkl.ifacemapper.MappableIface.RO;
42 import optkl.ifacemapper.MappableIface.RW;
43 import optkl.ifacemapper.MappableIface.WO;
44 import optkl.ifacemapper.Schema;
45
46 import java.lang.invoke.MethodHandles;
47
48 import static hat.test.TestDFT.ArrayComplex.Complex;
49 import static hat.test.TestDFT.ArrayComplex.create;
50
51 public class TestDFT {
52
53 public interface ArrayComplex extends Buffer {
54 int length();
55
56 interface Complex extends Struct {
57 float real();
58 float imag();
59 void real(float real);
60 void imag(float imag);
61 }
62
63 Complex complex(long index);
64
65 Schema<ArrayComplex> schema = Schema.of(ArrayComplex.class,
66 complex ->
67 complex.arrayLen("length")
68 .array("complex",
69 array -> array.fields("real", "imag")));
70
71 static ArrayComplex create(Accelerator accelerator, int length) {
72 return BoundSchema.of(accelerator, schema, length).allocate();
73 }
74 }
75
76 @Reflect
77 public static void dftKernel(KernelContext kc,
78 ArrayComplex input,
79 ArrayComplex output) {
80 int size = input.length();
81 int idx = kc.gix;
82 if (idx < kc.gsx) {
83 float sumReal = 0.0f;
84 float sumImag = 0.0f;
85 for (int k = 0; k < size; k++) {
86 float angle = -2 * HATMath.PI * ((k * idx) % size) / size;
87 Complex complexInput = input.complex(k);
88 float cReal = HATMath.cosf(angle);
89 float cImag = HATMath.sinf(angle);
90 sumReal += (complexInput.real() * cReal) - (complexInput.imag() * cImag);
91 sumImag += (complexInput.real() * cImag) + (complexInput.imag() * cReal);
92 }
93 Complex complexOutput = output.complex(idx);
94 complexOutput.real(sumReal);
95 complexOutput.imag(sumImag);
96 }
97 }
98
99 private static void dftJava(ArrayComplex input, ArrayComplex output) {
100 int size = input.length();
101 for (int k = 0; k < size; k++) {
102 Complex complexOutput = output.complex(k);
103 complexOutput.real(0.0f);
104 complexOutput.imag(0.0f);
105 float sumReal = 0.0f;
106 float sumImag = 0.0f;
107 for (int j = 0; j < size; j++) {
108 float angle = -2 * HATMath.PI * ((j * k) % size) / size;
109 Complex complexInput = input.complex(j);
110 float cReal = HATMath.cosf(angle);
111 float cImag = HATMath.sinf(angle);
112 sumReal += (complexInput.real() * cReal) - (complexInput.imag() * cImag);
113 sumImag += (complexInput.real() * cImag) + (complexInput.imag() * cReal);
114 }
115 complexOutput.real(sumReal);
116 complexOutput.imag(sumImag);
117 }
118 }
119
120 @Reflect
121 private static void dftCompute(@RO ComputeContext cc,
122 @RO ArrayComplex input,
123 @WO ArrayComplex output) {
124 var range = NDRange.of1D(input.length(), 128);
125 cc.dispatchKernel(range, kernelContext -> dftKernel(kernelContext, input, output));
126 }
127
128 @HatTest
129 public void testDFTWithOwnDS() {
130 var lookup = MethodHandles.lookup();
131 var accelerator = new Accelerator(lookup, Backend.FIRST);
132 final int size = 8192;
133 ArrayComplex input = create(accelerator, size);
134 ArrayComplex outputSeq = create(accelerator, size);
135 ArrayComplex outputHAT = create(accelerator, size);
136 accelerator.compute((@Reflect Compute) computeContext -> dftCompute(computeContext, input, outputHAT));
137 dftJava(input, outputSeq);
138 for (int i = 0; i < outputSeq.length(); i++) {
139 HATAsserts.assertEquals(outputSeq.complex(i).real(), outputHAT.complex(i).real(), 0.001f);
140 HATAsserts.assertEquals(outputSeq.complex(i).imag(), outputHAT.complex(i).imag(), 0.001f);
141 }
142 }
143
144 // Simple test to check User Data Structures in Private Memory
145 public interface ArrayComplexPrivate extends NonMappableIface {
146 interface PrivateComplex extends NonMappableIface {
147 DeviceSchema<PrivateComplex> deviceSchema =
148 DeviceSchema.of(PrivateComplex.class,
149 complex -> complex.fields("real", "imag"));
150 float real();
151 float imag();
152 void real(float real);
153 void imag(float imag);
154 }
155 PrivateComplex complex(long index);
156 DeviceSchema<ArrayComplexPrivate> deviceSchema =
157 DeviceSchema.of(ArrayComplexPrivate.class,
158 complex -> complex.array("complex", 128, c2 -> c2.fields("real", "imag")));
159 static ArrayComplexPrivate createPrivate() {
160 return null;
161 }
162 }
163
164 @Reflect
165 public static void testPrivateDS(KernelContext kc,
166 ArrayComplex input,
167 ArrayComplex output) {
168 int idx = kc.gix;
169 ArrayComplexPrivate priv = ArrayComplexPrivate.createPrivate();
170 ArrayComplexPrivate.PrivateComplex complex = priv.complex(0);
171 complex.real(1.0f);
172 complex.imag(2.0f);
173 Complex complexOutput = output.complex(idx);
174 complexOutput.real(complex.real());
175 complexOutput.imag(complex.imag());
176 }
177
178
179 @Reflect
180 private static void complexNumbersInPrivate(@RW ComputeContext cc,
181 @RO ArrayComplex input,
182 @WO ArrayComplex output) {
183 var range = NDRange.of1D(input.length(), 128);
184 cc.dispatchKernel(range, kernelContext -> testPrivateDS(kernelContext, input, output));
185 }
186
187 @HatTest
188 public void complexNumbersInPrivate() {
189 var lookup = MethodHandles.lookup();
190 var accelerator = new Accelerator(lookup, Backend.FIRST);
191 final int size = 8192;
192 ArrayComplex input = create(accelerator, size);
193 ArrayComplex outputSeq = create(accelerator, size);
194 ArrayComplex outputHAT = create(accelerator, size);
195 accelerator.compute((@Reflect Compute) computeContext -> complexNumbersInPrivate(computeContext, input, outputHAT));
196 for (int i = 0; i < outputSeq.length(); i++) {
197 HATAsserts.assertEquals(1.0f, outputHAT.complex(i).real(), 0.001f);
198 HATAsserts.assertEquals(2.0f, outputHAT.complex(i).imag(), 0.001f);
199 }
200 }
201
202 }