1
2 /*
3 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
4 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
5 *
6 * This code is free software; you can redistribute it and/or modify it
7 * under the terms of the GNU General Public License version 2 only, as
8 * published by the Free Software Foundation. Oracle designates this
9 * particular file as subject to the "Classpath" exception as provided
10 * by Oracle in the LICENSE file that accompanied this code.
11 *
12 * This code is distributed in the hope that it will be useful, but WITHOUT
13 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
14 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
15 * version 2 for more details (a copy is included in the LICENSE file that
16 * accompanied this code).
17 *
18 * You should have received a copy of the GNU General Public License version
19 * 2 along with this work; if not, write to the Free Software Foundation,
20 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
21 *
22 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
23 * or visit www.oracle.com if you need additional information or have any
24 * questions.
25 */
26 package nbody;
27
28 import hat.Accelerator;
29 import hat.Accelerator.Compute;
30 import hat.ComputeContext;
31 import hat.KernelContext;
32 import hat.NDRange;
33 import hat.buffer.F32Array;
34 import hat.buffer.S32RGBAImage;
35
36 import jdk.incubator.code.Reflect;
37 import optkl.ifacemapper.MappableIface;
38
39 import javax.swing.JFrame;
40 import javax.swing.JPanel;
41 import java.awt.Dimension;
42 import java.awt.Graphics;
43 import java.awt.Graphics2D;
44 import java.awt.RenderingHints;
45 import java.awt.image.BufferedImage;
46 import java.awt.image.DataBufferInt;
47 import java.awt.image.WritableRaster;
48 import java.lang.invoke.MethodHandles;
49 import java.util.stream.IntStream;
50
51 import hat.types.vec3;
52 import static hat.types.vec3.*;
53 import hat.types.F32;
54 import static hat.types.F32.*;
55
56 public class Main extends JFrame implements Runnable {
/**
 * Selects how each simulation step is executed: through the HAT accelerator,
 * with a parallel Java stream, or with a sequential Java stream.
 */
public enum Mode {
    HAT,
    JavaMt,
    JavaSeq
}
58
59 public static class DirectRasterPanel extends JPanel {
60 private BufferedImage bufferedImage;
61 private WritableRaster writableRaster;
62 private DataBufferInt dataBuffer;
63 private S32RGBAImage image;
64
65 public DirectRasterPanel(Accelerator acc, int width, int height) {
66 setPreferredSize(new Dimension(width, height));
67 this.image = S32RGBAImage.create(acc, width, height);
68 this.bufferedImage = new BufferedImage(width, height, BufferedImage.TYPE_INT_RGB);
69 this.writableRaster = bufferedImage.getRaster();
70 this.dataBuffer = ((DataBufferInt) writableRaster.getDataBuffer());
71 }
72
73 @Override
74 protected void paintComponent(Graphics g) {
75 Graphics2D g2d = (Graphics2D) g;
76 g2d.setRenderingHint(RenderingHints.KEY_ANTIALIASING, RenderingHints.VALUE_ANTIALIAS_ON);
77 g2d.drawImage(bufferedImage, 0, 0, null);
78 }
79 }
80
81 final Accelerator acc;
82 final Mode mode;
83
84 final float delT = .9f;
85 final int width;
86 final int height;
87 final float espSqr = 0.9f;
88
89 final float mass = .9f;
90 final int bodies;
91
92 final F32Array xyzPosFloatArr;
93 final F32Array xyzVelFloatArr;
94 DirectRasterPanel panel;
95
/**
 * Builds the window, allocates the position and velocity buffers, places
 * every body at a random point around the centre of the image and then
 * starts the simulation thread.
 *
 * @param acc       accelerator used for buffer allocation and compute
 * @param mode      how each simulation step is executed
 * @param bodyCount number of bodies to simulate
 * @param width     image width in pixels
 * @param height    image height in pixels
 */
public Main(Accelerator acc, Mode mode, int bodyCount, int width, int height) {
    super("NBody Sans OpenGL");
    this.mode = mode;
    this.acc = acc;
    this.bodies = bodyCount;
    this.width = width;
    this.height = height;
    this.xyzPosFloatArr = F32Array.create(acc, bodies * 4);
    this.xyzVelFloatArr = F32Array.create(acc, bodies * 4);
    this.panel = new DirectRasterPanel(acc, width, height);

    // Integer halves: the centre of the cloud on each axis.
    final int centreX = width / 2;
    final int centreY = height / 2;
    final int centreZ = Math.min(width, height) / 2;
    final float maxRadius = centreX;

    int base = 0; // index of the current body's x component
    for (int n = 0; n < bodies; n++) {
        // Draw order (theta, phi, radius) is kept so the random sequence is consumed identically.
        final float theta = (float) (Math.random() * Math.PI * 2);
        final float phi = (float) (Math.random() * Math.PI * 2);
        final float radius = (float) (Math.random() * maxRadius);

        final float px = (float) (radius * Math.cos(theta) * Math.sin(phi) + centreX);
        final float py = (float) (radius * Math.sin(theta) * Math.sin(phi) + centreY);
        final float pz = (float) (radius * Math.cos(phi) + centreZ);

        xyzPosFloatArr.array(base, px);
        xyzPosFloatArr.array(base + 1, py);
        xyzPosFloatArr.array(base + 2, pz);
        base += 4;
    }

    add(panel);
    pack();
    setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
    setVisible(true);

    final Thread simulation = new Thread(this);
    simulation.start();
}
133
134 @Reflect
135 public static void runSansVec(int bodyIdx, int bodies, @MappableIface.RW F32Array xyzPos, @MappableIface.RW F32Array xyzVel, @MappableIface.RW S32RGBAImage image, int imageWidth, float mass, float delT, float espSqr) {
136 final int FAR = 500;
137 final int MID = 300;
138 final int NEAR = 100;
139 float accx = 0.f;
140 float accy = 0.f;
141 float accz = 0.f;
142
143 final float myPosx = xyzPos.array(bodyIdx * 4 + 0);
144 final float myPosy = xyzPos.array(bodyIdx * 4 + 1);
145 final float myPosz = xyzPos.array(bodyIdx * 4 + 2);
146 final float myVelx = xyzVel.array(bodyIdx * 4 + 0);
147 final float myVely = xyzVel.array(bodyIdx * 4 + 1);
148 final float myVelz = xyzVel.array(bodyIdx * 4 + 2);
149
150 for (int body = 0; body < bodies; body++) {
151 final float dx = xyzPos.array(body * 4 + 0) - myPosx;
152 final float dy = xyzPos.array(body * 4 + 1) - myPosy;
153 final float dz = xyzPos.array(body * 4 + 2) - myPosz;
154 final float invDist = 1 / (float) Math.sqrt((dx * dx) + (dy * dy) + (dz * dz) + espSqr);
155 final float s = mass * invDist * invDist * invDist;
156 accx = accx + (s * dx);
157 accy = accy + (s * dy);
158 accz = accz + (s * dz);
159 }
160 accx = accx * delT;
161 accy = accy * delT;
162 accz = accz * delT;
163
164 float fx = myPosx + (myVelx + accx * .5f) * delT;
165 float fy = myPosy + (myVely + accx * .5f) * delT;
166 float fz = myPosz + (myVelz + accx * .5f) * delT;
167 xyzPos.array(bodyIdx * 4 + 0, fx);
168 xyzPos.array(bodyIdx * 4 + 1, fy);
169 xyzPos.array(bodyIdx * 4 + 2, fz);
170
171 xyzVel.array(bodyIdx * 4 + 0, myVelx + accx);
172 xyzVel.array(bodyIdx * 4 + 1, myVely + accy);
173 xyzVel.array(bodyIdx * 4 + 2, myVelz + accz);
174
175 int x = (int) fx;
176 int y = (int) fy;
177 int z = (int) fz;
178 if (x > 1 && x < imageWidth - 2 && y > 1 && y < imageWidth - 2) {
179 // Calculate brightness based on depth (Z)
180 int brightness = (255 - (z / imageWidth * 255));
181 int color = (brightness << 16) | (brightness << 8) | brightness;
182 int pos = ((y * imageWidth) + x);
183 image.data(pos, color);
184 if (z < FAR) {
185 image.data(pos + 1, color);
186 image.data(pos - 1, color);
187 image.data(pos + imageWidth, color);
188 image.data(pos - imageWidth, color);
189 if (z < MID) {
190 image.data(pos + imageWidth + 1, color);
191 image.data(pos + imageWidth - 1, color);
192 image.data(pos - imageWidth + 1, color);
193 image.data(pos - imageWidth - 1, color);
194 if (z < NEAR) {
195 image.data(pos + imageWidth * 2 + 2, color);
196 image.data(pos + imageWidth * 2 - 2, color);
197 image.data(pos - imageWidth * 2 + 2, color);
198 image.data(pos - imageWidth * 2 - 2, color);
199 }
200 }
201 }
202 }
203 }
204
205 @Reflect
206 public static void run(int bodyIdx, int bodies, @MappableIface.RW F32Array xyzPos, @MappableIface.RW F32Array xyzVel, @MappableIface.RW S32RGBAImage image, int imageWidth, float mass, float delT, float espSqr) {
207 final int FAR = 500;
208 final int MID = 300;
209 final int NEAR = 100;
210 var acc = vec3(0f);
211 var idx = bodyIdx * 4;
212 var myPos = vec3(xyzPos.array(idx),xyzPos.array(idx+1),xyzPos.array(idx+2));
213 var myVel = vec3(xyzVel.array(idx),xyzVel.array(idx+1),xyzVel.array(idx+2));
214 for (int i = 0; i < bodies*4; i+=4) {
215 var delta = sub(vec3(xyzPos.array(i),xyzPos.array(i+1),xyzPos.array(i+2)),myPos);
216 float invDist = inversesqrt(dot(delta,delta)+espSqr);
217 acc = vec3.add(acc,mul(delta,pow(invDist,3f)*mass));
218 }
219 acc = mul(acc,delT);
220 var f = vec3.add(myPos,mul(vec3.add(myVel,mul(acc,.5f)),delT));
221
222 xyzPos.array(idx, f.x());
223 xyzPos.array(idx+1, f.y());
224 xyzPos.array(idx+2, f.z());
225
226 xyzVel.array(idx, myVel.x() + acc.x());
227 xyzVel.array(idx+1, myVel.y() + acc.y());
228 xyzVel.array(idx+2, myVel.z() + acc.z());
229
230 int x = (int) f.x();
231 int y = (int) f.y();
232 int z = (int) f.z();
233 if (x > 1 && x < imageWidth - 2 && y > 1 && y < imageWidth - 2) {
234 // Calculate brightness based on depth (Z)
235 int brightness = (255 - (z / imageWidth * 255));
236 int color = (brightness << 16) | (brightness << 8) | brightness;
237 int pos = ((y * imageWidth) + x);
238 image.data(pos, color);
239 if (z < FAR) {
240 image.data(pos + 1, color);
241 image.data(pos - 1, color);
242 image.data(pos + imageWidth, color);
243 image.data(pos - imageWidth, color);
244 if (z < MID) {
245 image.data(pos + imageWidth + 1, color);
246 image.data(pos + imageWidth - 1, color);
247 image.data(pos - imageWidth + 1, color);
248 image.data(pos - imageWidth - 1, color);
249 if (z < NEAR) {
250 image.data(pos + imageWidth * 2 + 2, color);
251 image.data(pos + imageWidth * 2 - 2, color);
252 image.data(pos - imageWidth * 2 + 2, color);
253 image.data(pos - imageWidth * 2 - 2, color);
254 }
255 }
256 }
257 }
258 }
259
260 @Reflect
261
262 static public void nbodyKernel(
263 @MappableIface.RO KernelContext kc,
264 int bodies,
265 @MappableIface.RW F32Array xyzPos,
266 @MappableIface.RW F32Array xyzVel,
267 @MappableIface.RW S32RGBAImage image,
268 int imageWidth,
269 float mass,
270 float delT,
271 float espSqr
272 ) {
273 run(kc.gix, bodies, xyzPos, xyzVel, image, imageWidth, mass, delT, espSqr);
274 }
275
276 @Reflect
277
278 static public void clearImage(
279 @MappableIface.RO KernelContext kc,
280 @MappableIface.RW S32RGBAImage image
281 ) {
282 image.data(kc.gix, 0);
283 }
284
285 @Reflect
286 public static void nbodyCompute(
287 @MappableIface.RO ComputeContext cc,
288 int bodies,
289 @MappableIface.RW F32Array xyzPos,
290 @MappableIface.RW F32Array xyzVel,
291 @MappableIface.RW S32RGBAImage image,
292 int imageWidth,
293 float mass,
294 float delT,
295 float espSqr
296 ) {
297 float cmass = mass;
298 float cdelT = delT;
299 float cespSqr = espSqr;
300 int cbodies = bodies;
301 int cimageWidth = imageWidth;
302
303 cc.dispatchKernel(NDRange.of1D(imageWidth * image.height()), kc -> clearImage(kc, image));
304
305 cc.dispatchKernel(NDRange.of1D(bodies), kc -> nbodyKernel(kc, cbodies, xyzPos, xyzVel, image, cimageWidth, cmass, cdelT, cespSqr));
306 }
307
308
309 public void run() {
310 while (true) {
311 // long startNs = System.nanoTime();
312 float cmass = mass;
313 float cdelT = delT;
314 float cespSqr = espSqr;
315 int cbodies = bodies;
316 int cimageWidth = width;
317 F32Array cxyzPosFloatArr = xyzPosFloatArr;
318 F32Array cxyzVelFloatArr = xyzVelFloatArr;
319 S32RGBAImage cimage = panel.image;
320
321 switch (mode) {
322 case HAT -> {
323 acc.compute((@Reflect Compute)
324 cc -> nbodyCompute(cc, cbodies, cxyzPosFloatArr, cxyzVelFloatArr, cimage, cimageWidth, cmass, cdelT, cespSqr));
325 }
326 case JavaMt -> {
327 MappableIface.getMemorySegment(panel.image).fill((byte) 0x00); // Dont do this if using HAT! ;)
328 IntStream.range(0, bodies).parallel().forEach(
329 i -> run(i, cbodies, cxyzPosFloatArr, cxyzVelFloatArr, cimage, cimageWidth, cmass, cdelT, cespSqr));
330 }
331 case JavaSeq -> {
332 MappableIface.getMemorySegment(panel.image).fill((byte) 0x00); // Dont do this if using HAT! ;)
333 IntStream.range(0, bodies).forEach(
334 i -> run(i, cbodies, cxyzPosFloatArr, cxyzVelFloatArr, cimage, cimageWidth, cmass, cdelT, cespSqr));
335 }
336 }
337 panel.image.syncToRasterDataBuffer(panel.dataBuffer);
338
339 // long endNs = System.nanoTime();
340 // System.out.println((endNs - startNs) / 1000000 + "ms");
341 repaint();
342 try {
343 Thread.sleep(1);
344 } catch (Exception e) {
345 }
346 }
347 }
348
349 static void main(String[] args) {
350 new Main(new Accelerator(MethodHandles.lookup()),
351 args.length > 0 ?
352 switch (args[0]) {
353 case "HAT" -> Mode.HAT;
354 case "JavaMT" -> Mode.JavaMt;
355 case "JavaSeq" -> Mode.JavaSeq;
356 default -> throw new IllegalStateException("No such mode as " + args[0]);
357 } : Mode.HAT
358 , 4096 * 4, 1024, 1024);
359
360 }
361 }
362