1 
  2 /*
  3  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  4  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  5  *
  6  * This code is free software; you can redistribute it and/or modify it
  7  * under the terms of the GNU General Public License version 2 only, as
  8  * published by the Free Software Foundation.  Oracle designates this
  9  * particular file as subject to the "Classpath" exception as provided
 10  * by Oracle in the LICENSE file that accompanied this code.
 11  *
 12  * This code is distributed in the hope that it will be useful, but WITHOUT
 13  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 14  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 15  * version 2 for more details (a copy is included in the LICENSE file that
 16  * accompanied this code).
 17  *
 18  * You should have received a copy of the GNU General Public License version
 19  * 2 along with this work; if not, write to the Free Software Foundation,
 20  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 21  *
 22  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 23  * or visit www.oracle.com if you need additional information or have any
 24  * questions.
 25  */
 26 package nbody;
 27 
 28 import hat.Accelerator;
 29 import hat.Accelerator.Compute;
 30 import hat.ComputeContext;
 31 import hat.KernelContext;
 32 import hat.NDRange;
 33 import hat.buffer.F32Array;
 34 import hat.buffer.S32RGBAImage;
 35 
 36 import jdk.incubator.code.Reflect;
 37 import optkl.ifacemapper.MappableIface;
 38 
 39 import javax.swing.JFrame;
 40 import javax.swing.JPanel;
 41 import java.awt.Dimension;
 42 import java.awt.Graphics;
 43 import java.awt.Graphics2D;
 44 import java.awt.RenderingHints;
 45 import java.awt.image.BufferedImage;
 46 import java.awt.image.DataBufferInt;
 47 import java.awt.image.WritableRaster;
 48 import java.lang.invoke.MethodHandles;
 49 import java.util.stream.IntStream;
 50 
 51 import  hat.types.vec3;
 52 import static hat.types.vec3.*;
 53 import hat.types.F32;
 54 import static hat.types.F32.*;
 55 
 56 public class Main extends JFrame implements Runnable {
 57     public enum Mode {HAT, JavaMt, JavaSeq}
 58 
 59     public static class DirectRasterPanel extends JPanel {
 60         private BufferedImage bufferedImage;
 61         private WritableRaster writableRaster;
 62         private DataBufferInt dataBuffer;
 63         private S32RGBAImage image;
 64 
 65         public DirectRasterPanel(Accelerator acc, int width, int height) {
 66             setPreferredSize(new Dimension(width, height));
 67             this.image = S32RGBAImage.create(acc, width, height);
 68             this.bufferedImage = new BufferedImage(width, height, BufferedImage.TYPE_INT_RGB);
 69             this.writableRaster = bufferedImage.getRaster();
 70             this.dataBuffer = ((DataBufferInt) writableRaster.getDataBuffer());
 71         }
 72 
 73         @Override
 74         protected void paintComponent(Graphics g) {
 75             Graphics2D g2d = (Graphics2D) g;
 76             g2d.setRenderingHint(RenderingHints.KEY_ANTIALIASING, RenderingHints.VALUE_ANTIALIAS_ON);
 77             g2d.drawImage(bufferedImage, 0, 0, null);
 78         }
 79     }
 80 
 81     final Accelerator acc;
 82     final Mode mode;
 83 
 84     final float delT = .9f;
 85     final int width;
 86     final int height;
 87     final float espSqr = 0.9f;
 88 
 89     final float mass = .9f;
 90     final int bodies;
 91 
 92     final F32Array xyzPosFloatArr;
 93     final F32Array xyzVelFloatArr;
 94     DirectRasterPanel panel;
 95 
 96     public Main(Accelerator acc, Mode mode, int bodyCount, int width, int height) {
 97         super("NBody Sans OpenGL");
 98         this.mode = mode;
 99         this.acc = acc;
100         this.bodies = bodyCount;
101         this.width = width;
102         this.height = height;
103         this.xyzPosFloatArr = F32Array.create(acc, bodies * 4);
104         this.xyzVelFloatArr = F32Array.create(acc, bodies * 4);
105         panel = new DirectRasterPanel(acc, width, height);
106         final float maxDist = width / 2;
107 
108         //System.out.println(bodies + " particles");
109 
110         for (int body = 0; body < bodies; body++) {
111             final float theta = (float) (Math.random() * Math.PI * 2);
112             final float phi = (float) (Math.random() * Math.PI * 2);
113             final float radius = (float) (Math.random() * maxDist);
114 
115             var radialx =
116                     (float) (radius * Math.cos(theta) * Math.sin(phi) + width / 2);
117             var radialy =
118                     (float) (radius * Math.sin(theta) * Math.sin(phi) + height / 2);
119             var radialz =
120                     (float) (radius * Math.cos(phi) + Math.min(width, height) / 2);
121 
122             xyzPosFloatArr.array(body * 4 + 0, radialx);
123             xyzPosFloatArr.array(body * 4 + 1, radialy);
124             xyzPosFloatArr.array(body * 4 + 2, radialz);
125         }
126 
127         add(panel);
128         pack();
129         setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
130         setVisible(true);
131         new Thread(this).start();
132     }
133 
134     @Reflect
135     public static void runSansVec(int bodyIdx, int bodies, @MappableIface.RW F32Array xyzPos, @MappableIface.RW F32Array xyzVel, @MappableIface.RW S32RGBAImage image, int imageWidth, float mass, float delT, float espSqr) {
136         final int FAR = 500;
137         final int MID = 300;
138         final int NEAR = 100;
139         float accx = 0.f;
140         float accy = 0.f;
141         float accz = 0.f;
142 
143         final float myPosx = xyzPos.array(bodyIdx * 4 + 0);
144         final float myPosy = xyzPos.array(bodyIdx * 4 + 1);
145         final float myPosz = xyzPos.array(bodyIdx * 4 + 2);
146         final float myVelx = xyzVel.array(bodyIdx * 4 + 0);
147         final float myVely = xyzVel.array(bodyIdx * 4 + 1);
148         final float myVelz = xyzVel.array(bodyIdx * 4 + 2);
149 
150         for (int body = 0; body < bodies; body++) {
151             final float dx = xyzPos.array(body * 4 + 0) - myPosx;
152             final float dy = xyzPos.array(body * 4 + 1) - myPosy;
153             final float dz = xyzPos.array(body * 4 + 2) - myPosz;
154             final float invDist = 1 / (float) Math.sqrt((dx * dx) + (dy * dy) + (dz * dz) + espSqr);
155             final float s = mass * invDist * invDist * invDist;
156             accx = accx + (s * dx);
157             accy = accy + (s * dy);
158             accz = accz + (s * dz);
159         }
160         accx = accx * delT;
161         accy = accy * delT;
162         accz = accz * delT;
163 
164         float fx = myPosx + (myVelx + accx * .5f) * delT;
165         float fy = myPosy + (myVely + accx * .5f) * delT;
166         float fz = myPosz + (myVelz + accx * .5f) * delT;
167         xyzPos.array(bodyIdx * 4 + 0, fx);
168         xyzPos.array(bodyIdx * 4 + 1, fy);
169         xyzPos.array(bodyIdx * 4 + 2, fz);
170 
171         xyzVel.array(bodyIdx * 4 + 0, myVelx + accx);
172         xyzVel.array(bodyIdx * 4 + 1, myVely + accy);
173         xyzVel.array(bodyIdx * 4 + 2, myVelz + accz);
174 
175         int x = (int) fx;
176         int y = (int) fy;
177         int z = (int) fz;
178         if (x > 1 && x < imageWidth - 2 && y > 1 && y < imageWidth - 2) {
179             // Calculate brightness based on depth (Z)
180             int brightness = (255 - (z / imageWidth * 255));
181             int color = (brightness << 16) | (brightness << 8) | brightness;
182             int pos = ((y * imageWidth) + x);
183             image.data(pos, color);
184             if (z < FAR) {
185                 image.data(pos + 1, color);
186                 image.data(pos - 1, color);
187                 image.data(pos + imageWidth, color);
188                 image.data(pos - imageWidth, color);
189                 if (z < MID) {
190                     image.data(pos + imageWidth + 1, color);
191                     image.data(pos + imageWidth - 1, color);
192                     image.data(pos - imageWidth + 1, color);
193                     image.data(pos - imageWidth - 1, color);
194                     if (z < NEAR) {
195                         image.data(pos + imageWidth * 2 + 2, color);
196                         image.data(pos + imageWidth * 2 - 2, color);
197                         image.data(pos - imageWidth * 2 + 2, color);
198                         image.data(pos - imageWidth * 2 - 2, color);
199                     }
200                 }
201             }
202         }
203     }
204 
205     @Reflect
206     public static void run(int bodyIdx, int bodies, @MappableIface.RW F32Array xyzPos, @MappableIface.RW F32Array xyzVel, @MappableIface.RW S32RGBAImage image, int imageWidth, float mass, float delT, float espSqr) {
207         final int FAR = 500;
208         final int MID = 300;
209         final int NEAR = 100;
210         var acc = vec3(0f);
211         var idx = bodyIdx * 4;
212         var myPos = vec3(xyzPos.array(idx),xyzPos.array(idx+1),xyzPos.array(idx+2));
213         var myVel = vec3(xyzVel.array(idx),xyzVel.array(idx+1),xyzVel.array(idx+2));
214         for (int i = 0; i < bodies*4; i+=4) {
215             var delta = sub(vec3(xyzPos.array(i),xyzPos.array(i+1),xyzPos.array(i+2)),myPos);
216             float invDist = inversesqrt(dot(delta,delta)+espSqr);
217             acc = vec3.add(acc,mul(delta,pow(invDist,3f)*mass));
218         }
219         acc = mul(acc,delT);
220         var f = vec3.add(myPos,mul(vec3.add(myVel,mul(acc,.5f)),delT));
221 
222         xyzPos.array(idx, f.x());
223         xyzPos.array(idx+1, f.y());
224         xyzPos.array(idx+2, f.z());
225 
226         xyzVel.array(idx, myVel.x() + acc.x());
227         xyzVel.array(idx+1, myVel.y() + acc.y());
228         xyzVel.array(idx+2, myVel.z() + acc.z());
229 
230         int x = (int) f.x();
231         int y = (int) f.y();
232         int z = (int) f.z();
233         if (x > 1 && x < imageWidth - 2 && y > 1 && y < imageWidth - 2) {
234             // Calculate brightness based on depth (Z)
235             int brightness = (255 - (z / imageWidth * 255));
236             int color = (brightness << 16) | (brightness << 8) | brightness;
237             int pos = ((y * imageWidth) + x);
238             image.data(pos, color);
239             if (z < FAR) {
240                 image.data(pos + 1, color);
241                 image.data(pos - 1, color);
242                 image.data(pos + imageWidth, color);
243                 image.data(pos - imageWidth, color);
244                 if (z < MID) {
245                     image.data(pos + imageWidth + 1, color);
246                     image.data(pos + imageWidth - 1, color);
247                     image.data(pos - imageWidth + 1, color);
248                     image.data(pos - imageWidth - 1, color);
249                     if (z < NEAR) {
250                         image.data(pos + imageWidth * 2 + 2, color);
251                         image.data(pos + imageWidth * 2 - 2, color);
252                         image.data(pos - imageWidth * 2 + 2, color);
253                         image.data(pos - imageWidth * 2 - 2, color);
254                     }
255                 }
256             }
257         }
258     }
259 
260     @Reflect
261 
262     static public void nbodyKernel(
263             @MappableIface.RO KernelContext kc,
264             int bodies,
265             @MappableIface.RW F32Array xyzPos,
266             @MappableIface.RW F32Array xyzVel,
267             @MappableIface.RW S32RGBAImage image,
268             int imageWidth,
269             float mass,
270             float delT,
271             float espSqr
272     ) {
273         run(kc.gix, bodies, xyzPos, xyzVel, image, imageWidth, mass, delT, espSqr);
274     }
275 
276     @Reflect
277 
278     static public void clearImage(
279             @MappableIface.RO KernelContext kc,
280             @MappableIface.RW S32RGBAImage image
281     ) {
282         image.data(kc.gix, 0);
283     }
284 
285     @Reflect
286     public static void nbodyCompute(
287             @MappableIface.RO ComputeContext cc,
288             int bodies,
289             @MappableIface.RW F32Array xyzPos,
290             @MappableIface.RW F32Array xyzVel,
291             @MappableIface.RW S32RGBAImage image,
292             int imageWidth,
293             float mass,
294             float delT,
295             float espSqr
296     ) {
297         float cmass = mass;
298         float cdelT = delT;
299         float cespSqr = espSqr;
300         int cbodies = bodies;
301         int cimageWidth = imageWidth;
302 
303         cc.dispatchKernel(NDRange.of1D(imageWidth * image.height()), kc -> clearImage(kc, image));
304 
305         cc.dispatchKernel(NDRange.of1D(bodies), kc -> nbodyKernel(kc, cbodies, xyzPos, xyzVel, image, cimageWidth, cmass, cdelT, cespSqr));
306     }
307 
308 
309     public void run() {
310         while (true) {
311            // long startNs = System.nanoTime();
312             float cmass = mass;
313             float cdelT = delT;
314             float cespSqr = espSqr;
315             int cbodies = bodies;
316             int cimageWidth = width;
317             F32Array cxyzPosFloatArr = xyzPosFloatArr;
318             F32Array cxyzVelFloatArr = xyzVelFloatArr;
319             S32RGBAImage cimage = panel.image;
320 
321             switch (mode) {
322                 case HAT -> {
323                     acc.compute((@Reflect Compute)
324                             cc -> nbodyCompute(cc, cbodies, cxyzPosFloatArr, cxyzVelFloatArr, cimage, cimageWidth, cmass, cdelT, cespSqr));
325                 }
326                 case JavaMt -> {
327                     MappableIface.getMemorySegment(panel.image).fill((byte) 0x00); // Dont do this if using HAT! ;)
328                     IntStream.range(0, bodies).parallel().forEach(
329                             i -> run(i, cbodies, cxyzPosFloatArr, cxyzVelFloatArr, cimage, cimageWidth, cmass, cdelT, cespSqr));
330                 }
331                 case JavaSeq -> {
332                     MappableIface.getMemorySegment(panel.image).fill((byte) 0x00); // Dont do this if using HAT! ;)
333                     IntStream.range(0, bodies).forEach(
334                             i -> run(i, cbodies, cxyzPosFloatArr, cxyzVelFloatArr, cimage, cimageWidth, cmass, cdelT, cespSqr));
335                 }
336             }
337             panel.image.syncToRasterDataBuffer(panel.dataBuffer);
338 
339           //  long endNs = System.nanoTime();
340           //  System.out.println((endNs - startNs) / 1000000 + "ms");
341             repaint();
342             try {
343                 Thread.sleep(1);
344             } catch (Exception e) {
345             }
346         }
347     }
348 
349     static void main(String[] args) {
350         new Main(new Accelerator(MethodHandles.lookup()),
351                 args.length > 0 ?
352                         switch (args[0]) {
353                             case "HAT" -> Mode.HAT;
354                             case "JavaMT" -> Mode.JavaMt;
355                             case "JavaSeq" -> Mode.JavaSeq;
356                             default -> throw new IllegalStateException("No such mode as " + args[0]);
357                         } : Mode.HAT
358                 , 4096 * 4, 1024, 1024);
359 
360     }
361 }
362