1 /*
  2  * Copyright (c) 2026, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package dft;
 26 
 27 import dft.Main.ComplexArray.Complex;
 28 import hat.Accelerator;
 29 import hat.Accelerator.Compute;
 30 import hat.ComputeContext;
 31 import hat.HATMath;
 32 import hat.KernelContext;
 33 import hat.NDRange;
 34 import hat.backend.Backend;
 35 import hat.buffer.F32Array;
 36 import hat.examples.common.ParseArgs;
 37 import jdk.incubator.code.Reflect;
 38 import optkl.ifacemapper.BoundSchema;
 39 import optkl.ifacemapper.Buffer;
 40 import optkl.ifacemapper.MappableIface.RO;
 41 import optkl.ifacemapper.MappableIface.RW;
 42 import optkl.ifacemapper.MappableIface.WO;
 43 import optkl.ifacemapper.Schema;
 44 
 45 import java.lang.invoke.MethodHandles;
 46 import java.util.ArrayList;
 47 import java.util.List;
 48 import java.util.Random;
 49 import java.util.stream.IntStream;
 50 
 51 import static hat.examples.common.StatUtils.computeAverage;
 52 import static hat.examples.common.StatUtils.computeSpeedup;
 53 import static hat.examples.common.StatUtils.dumpStatsToCSVFile;
 54 import static hat.examples.common.StatUtils.printCheckResult;
 55 
/**
 * How to run?
 *
 * <p>
 * With the OpenCL Backend:
 * </p>
 * <pre>{@code
 *     java -cp hat/job.jar hat.java run ffi-opencl dft --size=<size> --iterations=<iterations> --verbose
 * }</pre>
 *
 * <p>
 * With the CUDA Backend:
 * </p>
 * <pre>{@code
 *     java -cp hat/job.jar hat.java run ffi-cuda dft --size=<size> --iterations=<iterations> --verbose
 * }</pre>
 *
 * <p>
 * Link to DFT: <a href="https://en.wikipedia.org/wiki/Discrete_Fourier_transform">Discrete Fourier transform (Wikipedia)</a>
 * </p>
 */
 78 public class Main {
 79 
    /** Largest absolute per-component difference tolerated by the {@code checkResult} comparisons. */
    public static final float DELTA = 0.001f;
 81 
 82     // Use a custom data structure for dealing with Array of Complex Numbers
 83     public interface ComplexArray extends Buffer {
 84         int length();
 85 
 86         interface Complex extends Struct {
 87             float real();
 88             float imag();
 89             void real(float real);
 90             void imag(float imag);
 91         }
 92 
 93         Complex complex(long index);
 94 
 95         Schema<ComplexArray> schema = Schema.of(ComplexArray.class, complex ->
 96                 complex.arrayLen("length")
 97                         .array("complex",
 98                                 array -> array.fields("real", "imag")));
 99 
100         static ComplexArray create(Accelerator accelerator, int length) {
101             return BoundSchema.of(accelerator, schema, length).allocate();
102         }
103     }
104 
105     @Reflect
106     private static void dftKernel(KernelContext kc, ComplexArray input, ComplexArray output) {
107         int size = input.length();
108         int idx = kc.gix;
109         if (idx < kc.gsx) {
110             float sumReal = 0.0f;
111             float sumImag = 0.0f;
112             for (int k = 0; k < size; k++) {
113                 float angle = -2 * HATMath.PI * ((k * idx) % size) / size;
114                 Complex complexInput = input.complex(k);
115                 float cReal = HATMath.native_cosf(angle);
116                 float cImag = HATMath.native_sinf(angle);
117                 sumReal += (complexInput.real() * cReal) - (complexInput.imag() * cImag);
118                 sumImag += (complexInput.real() * cImag) + (complexInput.imag() * cReal);
119             }
120             Complex complexOutput = output.complex(idx);
121             complexOutput.real(sumReal);
122             complexOutput.imag(sumImag);
123         }
124     }
125 
126     @Reflect
127     private static void dftCompute(@RW ComputeContext cc, @RO ComplexArray input, @WO ComplexArray output) {
128         var range = NDRange.of1D(input.length(), 256);
129         cc.dispatchKernel(range, kernelContext -> dftKernel(kernelContext, input, output));
130     }
131 
132     @Reflect
133     private static void dftPlainKernel(KernelContext kc, F32Array inReal, F32Array inImag, F32Array outReal, F32Array outImag) {
134         int size = inReal.length();
135         int idx = kc.gix;
136         if (idx < kc.gsx) {
137             float sumReal = 0.0f;
138             float sumImag = 0.0f;
139             for (int k = 0; k < size; k++) {
140                 float angle = -2 * HATMath.PI * ((idx * k) % size) / size;
141                 float cReal = HATMath.native_cosf(angle);
142                 float cImag = HATMath.native_sinf(angle);
143                 sumReal += (inReal.array(k) * cReal) - (inImag.array(k) * cImag);
144                 sumImag += (inReal.array(k) * cImag) + (inImag.array(k) * cReal);
145             }
146             outReal.array(idx, sumReal);
147             outImag.array(idx, sumImag);
148         }
149     }
150 
151     @Reflect
152     private static void dftPlainCompute(@RW ComputeContext cc, @RO F32Array inReal, @RO F32Array inImag, @WO F32Array outReal, @WO F32Array outImag) {
153         var range = NDRange.of1D(inReal.length(), 256);
154         cc.dispatchKernel(range, kernelContext -> dftPlainKernel(kernelContext, inReal, inImag, outReal, outImag));
155     }
156 
157     private static void dftJava(ComplexArray input, ComplexArray output) {
158         int size = input.length();
159         for (int k = 0; k < size; k++) {
160             Complex complexOutput = output.complex(k);
161             complexOutput.real(0.0f);
162             complexOutput.imag(0.0f);
163             float sumReal = 0.0f;
164             float sumImag = 0.0f;
165             for (int j = 0; j < size; j++) {
166                 float angle = -2 * HATMath.PI * ((j * k) % size) / size;
167                 Complex complexInput = input.complex(j);
168                 float cReal = HATMath.cosf(angle);
169                 float cImag = HATMath.sinf(angle);
170                 sumReal += (complexInput.real() * cReal) - (complexInput.imag() * cImag);
171                 sumImag += (complexInput.real() * cImag) + (complexInput.imag() * cReal);
172             }
173             complexOutput.real(sumReal);
174             complexOutput.imag(sumImag);
175         }
176     }
177 
178     private static void dftJavaStreams(ComplexArray input, ComplexArray output) {
179         int size = input.length();
180         IntStream.range(0, size).parallel().forEach(idx -> {
181             float sumReal = 0.0f;
182             float sumImag = 0.0f;
183             for (int k = 0; k < size; k++) {
184                 float angle = -2 * HATMath.PI * ((idx * k) % size) / size;
185                 Complex complexInput = input.complex(k);
186                 float cReal = HATMath.cosf(angle);
187                 float cImag = HATMath.sinf(angle);
188                 sumReal += (complexInput.real() * cReal) - (complexInput.imag() * cImag);
189                 sumImag += (complexInput.real() * cImag) + (complexInput.imag() * cReal);
190             }
191             Complex complexOutput = output.complex(idx);
192             complexOutput.real(sumReal);
193             complexOutput.imag(sumImag);
194         });
195     }
196 
197     private static boolean checkResult(ComplexArray expected, ComplexArray obtained) {
198         for (int i = 0; i < expected.length(); i++) {
199             if (Math.abs(expected.complex(i).real() - obtained.complex(i).real()) > DELTA) {
200                 IO.println(expected.complex(i).real() + " vs " + obtained.complex(i).real());
201                 return false;
202             }
203             if (Math.abs(expected.complex(i).imag() - obtained.complex(i).imag()) > DELTA) {
204                 IO.println(expected.complex(i).imag() + " vs " + obtained.complex(i).imag());
205                 return false;
206             }
207         }
208         return true;
209     }
210 
211     // Just for debugging
212     private static void printSignal(ComplexArray signal) {
213         for (int i = 0; i < signal.length(); i++) {
214             IO.println(signal.complex(i).real() + "," + signal.complex(i).imag());
215         }
216     }
217 
218     private static boolean checkResult(ComplexArray outputSeq, F32Array outReal, F32Array outImag) {
219         for (int i = 0; i < outputSeq.length(); i++) {
220             if (Math.abs(outputSeq.complex(i).real() - outReal.array(i)) > DELTA) {
221                 IO.println(outputSeq.complex(i).real() + " vs " + outReal.array(i));
222                 return false;
223             }
224             if (Math.abs(outputSeq.complex(i).imag() - outImag.array(i)) > DELTA) {
225                 IO.println(outputSeq.complex(i).imag() + " vs " + outImag.array(i));
226                 return false;
227             }
228         }
229         return true;
230     }
231 
232     static void main(String[] args) {
233         IO.println("=========================================");
234         IO.println("Example: Discrete Fourier Transform (DFT)");
235         IO.println("=========================================");
236 
237         int size = 32768;
238         int iterations = 100;
239         ParseArgs parseArgs = new ParseArgs(args);
240         ParseArgs.Options options = parseArgs.parseWithDefaults(size, iterations);
241 
242         boolean verbose = options.verbose();
243         size = options.size();
244         iterations = options.iterations();
245         boolean skipSequential = options.skipSequential();
246         IO.println("Input Size     = " + size);
247         IO.println("Num Iterations = " + iterations);
248 
249         var lookup = MethodHandles.lookup();
250         var accelerator = new Accelerator(lookup, Backend.FIRST);
251         ComplexArray input = ComplexArray.create(accelerator, size);
252         ComplexArray outputSeq = ComplexArray.create(accelerator, size);
253         ComplexArray.create(accelerator, size);
254         ComplexArray outputStreams = ComplexArray.create(accelerator, size);
255         ComplexArray outputHAT = ComplexArray.create(accelerator, size);
256 
257         F32Array inReal = F32Array.create(accelerator, size);
258         F32Array inImag = F32Array.create(accelerator, size);
259         F32Array outReal = F32Array.create(accelerator, size);
260         F32Array outImag = F32Array.create(accelerator, size);
261 
262         // Initialize
263         Random r = new Random(71);
264         for (int i = 0; i < size; i++) {
265             Complex c = input.complex(i);
266             c.real(r.nextFloat());
267             c.imag(r.nextFloat());
268             inReal.array(i, c.real());
269             inImag.array(i, c.imag());
270         }
271 
272         List<Long> timersJavaDFT = new ArrayList<>();
273         List<Long> timersStreams = new ArrayList<>();
274         List<Long> timersDFTHat = new ArrayList<>();
275         List<Long> timersDFTHatFlatten = new ArrayList<>();
276 
277         // Run Java sequential version, DFT
278         if (!skipSequential) {
279             for (int i = 0; i < iterations; i++) {
280                 long start = System.nanoTime();
281                 dftJava(input, outputSeq);
282                 long end = System.nanoTime();
283                 timersJavaDFT.add((end - start));
284                 if (verbose) {
285                     IO.println("[Timer] Java DFT: " + (end - start));
286                 }
287             }
288         }
289 
290         // Run Java Parallel Stream Version of the DFT
291         for (int i = 0; i < iterations; i++) {
292             long start = System.nanoTime();
293             dftJavaStreams(input, outputStreams);
294             long end = System.nanoTime();
295             timersStreams.add((end-start));
296             if (verbose) {
297                 IO.println("[Timer] Parallel Stream: " + (end-start));
298             }
299         }
300 
301         // HAT: Initial version (DFT)
302         for (int i = 0; i < iterations; i++) {
303             long start = System.nanoTime();
304             accelerator.compute((@Reflect Compute) computeContext -> dftCompute(computeContext, input, outputHAT));
305             long end = System.nanoTime();
306             timersDFTHat.add((end-start));
307             if (verbose) {
308                 IO.println("[Timer] HAT with CDS: " + (end-start));
309             }
310         }
311 
312         // HAT: DFT using plain arrays instead of a custom data structure
313         for (int i = 0; i < iterations; i++) {
314             long start = System.nanoTime();
315             accelerator.compute((@Reflect Compute) computeContext -> dftPlainCompute(computeContext, inReal, inImag, outReal, outImag));
316             long end = System.nanoTime();
317             timersDFTHatFlatten.add((end - start));
318             if (verbose) {
319                 IO.println("[Timer] HAT-Plain: " + (end - start));
320             }
321         }
322 
323         // Check results
324         ComplexArray baseline = skipSequential ? outputStreams : outputSeq;
325         boolean isStreamCorrect = checkResult(baseline, outputStreams);
326         boolean isHATCorrect = checkResult(baseline, outputHAT);
327         boolean isHATPlainCorrect = checkResult(baseline, outReal, outImag);
328         printCheckResult(isStreamCorrect, "Java-Stream");
329         printCheckResult(isHATCorrect, "HAT-Naive");
330         printCheckResult(isHATPlainCorrect, "HAT-Plain");
331 
332         // Print Performance Metrics
333         final int skip = iterations / 2;
334         double averageJavaTimer = computeAverage(timersJavaDFT, skip);
335         double averageJavaStreamTimer = computeAverage(timersStreams, skip);
336         double averageHATTimers = computeAverage(timersDFTHat, skip);
337         double averageHATTimersFlatten = computeAverage(timersDFTHatFlatten, skip);
338 
339         IO.println("\nAverage elapsed time:");
340         IO.println("Average Java-Seq DFT           : " + averageJavaTimer);
341         IO.println("Average Java-Streams DFT       : " + averageJavaStreamTimer);
342         IO.println("Average HAT DFT                : " + averageHATTimers);
343         IO.println("Average HAT DFT (Flatten)      : " + averageHATTimersFlatten);
344 
345         if (!skipSequential) {
346             IO.println("\nSpeedups vs Java:");
347             IO.println("Java / Java Parallel Stream    : " + computeSpeedup(averageJavaTimer, averageJavaStreamTimer) + "x");
348             IO.println("Java / HAT                     : " + computeSpeedup(averageJavaTimer, averageHATTimers) + "x");
349             IO.println("Java / HAT-Flatten             : " + computeSpeedup(averageJavaTimer, averageHATTimersFlatten) + "x");
350         }
351 
352         IO.println("\nSpeedups vs Java Parallel Streams:");
353         IO.println("Java / HAT                     : " + computeSpeedup(averageJavaStreamTimer, averageHATTimers) + "x");
354         IO.println("Java / HAT-Flatten             : " + computeSpeedup(averageJavaStreamTimer, averageHATTimersFlatten) + "x");
355 
356         IO.println("\nSpeedups vs HAT-DFT:");
357         IO.println("HAT / HAT-Flatten              : " + computeSpeedup(averageHATTimers, averageHATTimersFlatten) + "x");
358 
359         // Write CSV table with all results
360         List<List<Long>> timers = List.of(timersJavaDFT, timersStreams, timersDFTHat, timersDFTHatFlatten);
361         List<String> header = List.of("Java-fp32-" + size, "Streams-fp32-" + size, "HAT-UDT-fp32-" + size, "HAT-Plain-fp32-" + size);
362         String fileName = "table-results-dft-" + size + ".csv";
363         if (skipSequential) {
364             timers = List.of(timersStreams, timersDFTHat, timersDFTHatFlatten);
365             header = List.of("Streams-fp32-" + size, "HAT-UDT-fp32-" + size, "HAT-Plain-fp32-" + size);
366         }
367         dumpStatsToCSVFile(timers, header, fileName);
368     }
369 }