1 /*
2 * Copyright (c) 2026, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25 package dft;
26
27 import dft.Main.ComplexArray.Complex;
28 import hat.Accelerator;
29 import hat.Accelerator.Compute;
30 import hat.ComputeContext;
31 import hat.HATMath;
32 import hat.KernelContext;
33 import hat.NDRange;
34 import hat.backend.Backend;
35 import hat.buffer.F32Array;
36 import hat.examples.common.ParseArgs;
37 import jdk.incubator.code.Reflect;
38 import optkl.ifacemapper.BoundSchema;
39 import optkl.ifacemapper.Buffer;
40 import optkl.ifacemapper.MappableIface.RO;
41 import optkl.ifacemapper.MappableIface.RW;
42 import optkl.ifacemapper.MappableIface.WO;
43 import optkl.ifacemapper.Schema;
44
45 import java.lang.invoke.MethodHandles;
46 import java.util.ArrayList;
47 import java.util.List;
48 import java.util.Random;
49 import java.util.stream.IntStream;
50
51 import static hat.examples.common.StatUtils.computeAverage;
52 import static hat.examples.common.StatUtils.computeSpeedup;
53 import static hat.examples.common.StatUtils.dumpStatsToCSVFile;
54 import static hat.examples.common.StatUtils.printCheckResult;
55
/**
 * How to run?
 *
 * <p>
 * With the OpenCL Backend:
 * <code>
 * java -cp hat/job.jar hat.java run ffi-opencl dft --size=&lt;size&gt; --iterations=&lt;iterations&gt; --verbose
 * </code>
 * </p>
 *
 * <p>
 * With the CUDA Backend:
 * <code>
 * java -cp hat/job.jar hat.java run ffi-cuda dft --size=&lt;size&gt; --iterations=&lt;iterations&gt; --verbose
 * </code>
 * </p>
 *
 * <p>
 * Link to DFT: <a href="https://en.wikipedia.org/wiki/Discrete_Fourier_transform">link</a>
 * </p>
 */
78 public class Main {
79
80 public static final float DELTA = 0.001f;
81
82 // Use a custom data structure for dealing with Array of Complex Numbers
83 public interface ComplexArray extends Buffer {
84 int length();
85
86 interface Complex extends Struct {
87 float real();
88 float imag();
89 void real(float real);
90 void imag(float imag);
91 }
92
93 Complex complex(long index);
94
95 Schema<ComplexArray> schema = Schema.of(ComplexArray.class, complex ->
96 complex.arrayLen("length")
97 .array("complex",
98 array -> array.fields("real", "imag")));
99
100 static ComplexArray create(Accelerator accelerator, int length) {
101 return BoundSchema.of(accelerator, schema, length).allocate();
102 }
103 }
104
105 @Reflect
106 private static void dftKernel(KernelContext kc, ComplexArray input, ComplexArray output) {
107 int size = input.length();
108 int idx = kc.gix;
109 if (idx < kc.gsx) {
110 float sumReal = 0.0f;
111 float sumImag = 0.0f;
112 for (int k = 0; k < size; k++) {
113 float angle = -2 * HATMath.PI * ((k * idx) % size) / size;
114 Complex complexInput = input.complex(k);
115 float cReal = HATMath.native_cosf(angle);
116 float cImag = HATMath.native_sinf(angle);
117 sumReal += (complexInput.real() * cReal) - (complexInput.imag() * cImag);
118 sumImag += (complexInput.real() * cImag) + (complexInput.imag() * cReal);
119 }
120 Complex complexOutput = output.complex(idx);
121 complexOutput.real(sumReal);
122 complexOutput.imag(sumImag);
123 }
124 }
125
126 @Reflect
127 private static void dftCompute(@RW ComputeContext cc, @RO ComplexArray input, @WO ComplexArray output) {
128 var range = NDRange.of1D(input.length(), 256);
129 cc.dispatchKernel(range, kernelContext -> dftKernel(kernelContext, input, output));
130 }
131
132 @Reflect
133 private static void dftPlainKernel(KernelContext kc, F32Array inReal, F32Array inImag, F32Array outReal, F32Array outImag) {
134 int size = inReal.length();
135 int idx = kc.gix;
136 if (idx < kc.gsx) {
137 float sumReal = 0.0f;
138 float sumImag = 0.0f;
139 for (int k = 0; k < size; k++) {
140 float angle = -2 * HATMath.PI * ((idx * k) % size) / size;
141 float cReal = HATMath.native_cosf(angle);
142 float cImag = HATMath.native_sinf(angle);
143 sumReal += (inReal.array(k) * cReal) - (inImag.array(k) * cImag);
144 sumImag += (inReal.array(k) * cImag) + (inImag.array(k) * cReal);
145 }
146 outReal.array(idx, sumReal);
147 outImag.array(idx, sumImag);
148 }
149 }
150
151 @Reflect
152 private static void dftPlainCompute(@RW ComputeContext cc, @RO F32Array inReal, @RO F32Array inImag, @WO F32Array outReal, @WO F32Array outImag) {
153 var range = NDRange.of1D(inReal.length(), 256);
154 cc.dispatchKernel(range, kernelContext -> dftPlainKernel(kernelContext, inReal, inImag, outReal, outImag));
155 }
156
157 private static void dftJava(ComplexArray input, ComplexArray output) {
158 int size = input.length();
159 for (int k = 0; k < size; k++) {
160 Complex complexOutput = output.complex(k);
161 complexOutput.real(0.0f);
162 complexOutput.imag(0.0f);
163 float sumReal = 0.0f;
164 float sumImag = 0.0f;
165 for (int j = 0; j < size; j++) {
166 float angle = -2 * HATMath.PI * ((j * k) % size) / size;
167 Complex complexInput = input.complex(j);
168 float cReal = HATMath.cosf(angle);
169 float cImag = HATMath.sinf(angle);
170 sumReal += (complexInput.real() * cReal) - (complexInput.imag() * cImag);
171 sumImag += (complexInput.real() * cImag) + (complexInput.imag() * cReal);
172 }
173 complexOutput.real(sumReal);
174 complexOutput.imag(sumImag);
175 }
176 }
177
178 private static void dftJavaStreams(ComplexArray input, ComplexArray output) {
179 int size = input.length();
180 IntStream.range(0, size).parallel().forEach(idx -> {
181 float sumReal = 0.0f;
182 float sumImag = 0.0f;
183 for (int k = 0; k < size; k++) {
184 float angle = -2 * HATMath.PI * ((idx * k) % size) / size;
185 Complex complexInput = input.complex(k);
186 float cReal = HATMath.cosf(angle);
187 float cImag = HATMath.sinf(angle);
188 sumReal += (complexInput.real() * cReal) - (complexInput.imag() * cImag);
189 sumImag += (complexInput.real() * cImag) + (complexInput.imag() * cReal);
190 }
191 Complex complexOutput = output.complex(idx);
192 complexOutput.real(sumReal);
193 complexOutput.imag(sumImag);
194 });
195 }
196
197 private static boolean checkResult(ComplexArray expected, ComplexArray obtained) {
198 for (int i = 0; i < expected.length(); i++) {
199 if (Math.abs(expected.complex(i).real() - obtained.complex(i).real()) > DELTA) {
200 IO.println(expected.complex(i).real() + " vs " + obtained.complex(i).real());
201 return false;
202 }
203 if (Math.abs(expected.complex(i).imag() - obtained.complex(i).imag()) > DELTA) {
204 IO.println(expected.complex(i).imag() + " vs " + obtained.complex(i).imag());
205 return false;
206 }
207 }
208 return true;
209 }
210
211 // Just for debugging
212 private static void printSignal(ComplexArray signal) {
213 for (int i = 0; i < signal.length(); i++) {
214 IO.println(signal.complex(i).real() + "," + signal.complex(i).imag());
215 }
216 }
217
218 private static boolean checkResult(ComplexArray outputSeq, F32Array outReal, F32Array outImag) {
219 for (int i = 0; i < outputSeq.length(); i++) {
220 if (Math.abs(outputSeq.complex(i).real() - outReal.array(i)) > DELTA) {
221 IO.println(outputSeq.complex(i).real() + " vs " + outReal.array(i));
222 return false;
223 }
224 if (Math.abs(outputSeq.complex(i).imag() - outImag.array(i)) > DELTA) {
225 IO.println(outputSeq.complex(i).imag() + " vs " + outImag.array(i));
226 return false;
227 }
228 }
229 return true;
230 }
231
232 static void main(String[] args) {
233 IO.println("=========================================");
234 IO.println("Example: Discrete Fourier Transform (DFT)");
235 IO.println("=========================================");
236
237 int size = 32768;
238 int iterations = 100;
239 ParseArgs parseArgs = new ParseArgs(args);
240 ParseArgs.Options options = parseArgs.parseWithDefaults(size, iterations);
241
242 boolean verbose = options.verbose();
243 size = options.size();
244 iterations = options.iterations();
245 boolean skipSequential = options.skipSequential();
246 IO.println("Input Size = " + size);
247 IO.println("Num Iterations = " + iterations);
248
249 var lookup = MethodHandles.lookup();
250 var accelerator = new Accelerator(lookup, Backend.FIRST);
251 ComplexArray input = ComplexArray.create(accelerator, size);
252 ComplexArray outputSeq = ComplexArray.create(accelerator, size);
253 ComplexArray.create(accelerator, size);
254 ComplexArray outputStreams = ComplexArray.create(accelerator, size);
255 ComplexArray outputHAT = ComplexArray.create(accelerator, size);
256
257 F32Array inReal = F32Array.create(accelerator, size);
258 F32Array inImag = F32Array.create(accelerator, size);
259 F32Array outReal = F32Array.create(accelerator, size);
260 F32Array outImag = F32Array.create(accelerator, size);
261
262 // Initialize
263 Random r = new Random(71);
264 for (int i = 0; i < size; i++) {
265 Complex c = input.complex(i);
266 c.real(r.nextFloat());
267 c.imag(r.nextFloat());
268 inReal.array(i, c.real());
269 inImag.array(i, c.imag());
270 }
271
272 List<Long> timersJavaDFT = new ArrayList<>();
273 List<Long> timersStreams = new ArrayList<>();
274 List<Long> timersDFTHat = new ArrayList<>();
275 List<Long> timersDFTHatFlatten = new ArrayList<>();
276
277 // Run Java sequential version, DFT
278 if (!skipSequential) {
279 for (int i = 0; i < iterations; i++) {
280 long start = System.nanoTime();
281 dftJava(input, outputSeq);
282 long end = System.nanoTime();
283 timersJavaDFT.add((end - start));
284 if (verbose) {
285 IO.println("[Timer] Java DFT: " + (end - start));
286 }
287 }
288 }
289
290 // Run Java Parallel Stream Version of the DFT
291 for (int i = 0; i < iterations; i++) {
292 long start = System.nanoTime();
293 dftJavaStreams(input, outputStreams);
294 long end = System.nanoTime();
295 timersStreams.add((end-start));
296 if (verbose) {
297 IO.println("[Timer] Parallel Stream: " + (end-start));
298 }
299 }
300
301 // HAT: Initial version (DFT)
302 for (int i = 0; i < iterations; i++) {
303 long start = System.nanoTime();
304 accelerator.compute((@Reflect Compute) computeContext -> dftCompute(computeContext, input, outputHAT));
305 long end = System.nanoTime();
306 timersDFTHat.add((end-start));
307 if (verbose) {
308 IO.println("[Timer] HAT with CDS: " + (end-start));
309 }
310 }
311
312 // HAT: DFT using plain arrays instead of a custom data structure
313 for (int i = 0; i < iterations; i++) {
314 long start = System.nanoTime();
315 accelerator.compute((@Reflect Compute) computeContext -> dftPlainCompute(computeContext, inReal, inImag, outReal, outImag));
316 long end = System.nanoTime();
317 timersDFTHatFlatten.add((end - start));
318 if (verbose) {
319 IO.println("[Timer] HAT-Plain: " + (end - start));
320 }
321 }
322
323 // Check results
324 ComplexArray baseline = skipSequential ? outputStreams : outputSeq;
325 boolean isStreamCorrect = checkResult(baseline, outputStreams);
326 boolean isHATCorrect = checkResult(baseline, outputHAT);
327 boolean isHATPlainCorrect = checkResult(baseline, outReal, outImag);
328 printCheckResult(isStreamCorrect, "Java-Stream");
329 printCheckResult(isHATCorrect, "HAT-Naive");
330 printCheckResult(isHATPlainCorrect, "HAT-Plain");
331
332 // Print Performance Metrics
333 final int skip = iterations / 2;
334 double averageJavaTimer = computeAverage(timersJavaDFT, skip);
335 double averageJavaStreamTimer = computeAverage(timersStreams, skip);
336 double averageHATTimers = computeAverage(timersDFTHat, skip);
337 double averageHATTimersFlatten = computeAverage(timersDFTHatFlatten, skip);
338
339 IO.println("\nAverage elapsed time:");
340 IO.println("Average Java-Seq DFT : " + averageJavaTimer);
341 IO.println("Average Java-Streams DFT : " + averageJavaStreamTimer);
342 IO.println("Average HAT DFT : " + averageHATTimers);
343 IO.println("Average HAT DFT (Flatten) : " + averageHATTimersFlatten);
344
345 if (!skipSequential) {
346 IO.println("\nSpeedups vs Java:");
347 IO.println("Java / Java Parallel Stream : " + computeSpeedup(averageJavaTimer, averageJavaStreamTimer) + "x");
348 IO.println("Java / HAT : " + computeSpeedup(averageJavaTimer, averageHATTimers) + "x");
349 IO.println("Java / HAT-Flatten : " + computeSpeedup(averageJavaTimer, averageHATTimersFlatten) + "x");
350 }
351
352 IO.println("\nSpeedups vs Java Parallel Streams:");
353 IO.println("Java / HAT : " + computeSpeedup(averageJavaStreamTimer, averageHATTimers) + "x");
354 IO.println("Java / HAT-Flatten : " + computeSpeedup(averageJavaStreamTimer, averageHATTimersFlatten) + "x");
355
356 IO.println("\nSpeedups vs HAT-DFT:");
357 IO.println("HAT / HAT-Flatten : " + computeSpeedup(averageHATTimers, averageHATTimersFlatten) + "x");
358
359 // Write CSV table with all results
360 List<List<Long>> timers = List.of(timersJavaDFT, timersStreams, timersDFTHat, timersDFTHatFlatten);
361 List<String> header = List.of("Java-fp32-" + size, "Streams-fp32-" + size, "HAT-UDT-fp32-" + size, "HAT-Plain-fp32-" + size);
362 String fileName = "table-results-dft-" + size + ".csv";
363 if (skipSequential) {
364 timers = List.of(timersStreams, timersDFTHat, timersDFTHatFlatten);
365 header = List.of("Streams-fp32-" + size, "HAT-UDT-fp32-" + size, "HAT-Plain-fp32-" + size);
366 }
367 dumpStatsToCSVFile(timers, header, fileName);
368 }
369 }