1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package nbodygl.opencl;
 26 
 27 
 28 import hat.Accelerator;
 29 import hat.Accelerator.Compute;
 30 import hat.ComputeContext;
 31 import hat.KernelContext;
 32 import hat.types.Float4;
 33 import optkl.ifacemapper.BufferState;
 34 import hat.NDRange;
 35 import jdk.incubator.code.Reflect;
 36 import nbodygl.Mode;
 37 import nbodygl.NBodyGLWindow;
 38 import nbodygl.Universe;
 39 import wrap.opencl.CLPlatform;
 40 import wrap.opencl.CLWrapComputeContext;
 41 import wrap.opengl.GLTexture;
 42 
 43 import java.lang.foreign.Arena;
 44 import java.lang.invoke.MethodHandles;
 45 
 46 import static hat.backend.Backend.FIRST;
 47 import static optkl.ifacemapper.MappableIface.RO;
 48 import static optkl.ifacemapper.MappableIface.RW;
 49 import static opengl.opengl_h.glMatrixMode;
 50 import static opengl.opengl_h.glRasterPos2f;
 51 import static opengl.opengl_h.glScalef;
 52 import static opengl.opengl_h.glTexCoord2f;
 53 import static opengl.opengl_h.glVertex3f;
 54 import static opengl.opengl_h.glutBitmapCharacter;
 55 import static opengl.opengl_h.glutBitmapTimesRoman24$segment;
 56 import static opengl.opengl_h.glutSwapBuffers;
 57 import static opengl.opengl_h.glBindTexture;
 58 import static opengl.opengl_h.glClear;
 59 import static opengl.opengl_h.glClearColor;
 60 import static opengl.opengl_h.glColor3f;
 61 import static opengl.opengl_h.glDisable;
 62 import static opengl.opengl_h.glEnable;
 63 import static opengl.opengl_h.GL_COLOR_BUFFER_BIT;
 64 import static opengl.opengl_h.GL_DEPTH_BUFFER_BIT;
 65 import static opengl.opengl_h.GL_MODELVIEW;
 66 import static opengl.opengl_h.GL_TEXTURE_2D;
 67 
 68 
 69 public class OpenCLNBodyGLWindow extends NBodyGLWindow {
 70 
 71 
 72     @Reflect
 73     static public void nbodyKernel(KernelContext kc, Universe universe, float mass, float delT, float espSqr) {
 74         float accx = 0.0f;
 75         float accy = 0.0f;
 76         float accz = 0.0f;
 77         Universe.Body me = universe.body(kc.gix);
 78 
 79         for (int i = 0; i < kc.gsx; i++) {
 80             Universe.Body otherBody = universe.body(i);
 81             float dx = otherBody.x() - me.x();
 82             float dy = otherBody.y() - me.y();
 83             float dz = otherBody.z() - me.z();
 84             float invDist = (float) (1.0f / Math.sqrt(((dx * dx) + (dy * dy) + (dz * dz) + espSqr)));
 85             float s = mass * invDist * invDist * invDist;
 86             accx = accx + (s * dx);
 87             accy = accy + (s * dy);
 88             accz = accz + (s * dz);
 89         }
 90         accx = accx * delT;
 91         accy = accy * delT;
 92         accz = accz * delT;
 93         me.x(me.x() + (me.vx() * delT) + accx * .5f * delT);
 94         me.y(me.y() + (me.vy() * delT) + accy * .5f * delT);
 95         me.z(me.z() + (me.vz() * delT) + accz * .5f * delT);
 96         me.vx(me.vx() + accx);
 97         me.vy(me.vy() + accy);
 98         me.vz(me.vz() + accz);
 99     }
100     interface F32x4Arr {
101 
102     }
103     @Reflect
104     static public void nbodyKernelf4(KernelContext kc, Universe universe, float mass, float delT, float espSqr) {
105         var acc = Float4.of(0,0,0,0);
106         var posArr = universe.posArrView();
107         var velArr = universe.velArrView();
108         var pos = posArr[kc.gix];
109         var vel = velArr[kc.gix];
110         for (int i = 0; i < kc.gix; i++) {
111             var delta = posArr[i].sub(pos);
112             var delSqr = delta.sqr();
113             var delSqrSum = delSqr.x() + delSqr.y() + delSqr.z();
114             var invDist = 1f / (float) Math.sqrt(delSqrSum + espSqr);
115             var invDistCubed = invDist * invDist * invDist;
116             acc = acc.add(delta.mul(mass * invDistCubed));
117         }
118         acc = acc.mul(delT);
119         pos = pos.add(vel.mul(delT).add(acc.mul(.5f * delT)));
120         vel = vel.add(acc);
121         posArr[kc.gix] = pos;
122         velArr[kc.gix] = vel;
123     }
124 
125     @Reflect
126     public static void nbodyCompute(@RO ComputeContext cc, @RW Universe universe, float mass, float delT, float espSqr) {
127         float cmass = mass;
128         float cdelT = delT;
129         float cespSqr = espSqr;
130 
131         cc.dispatchKernel(NDRange.of1D(universe.length()), kc -> nbodyKernel(kc, universe, cmass, cdelT, cespSqr));
132     }
133 
134     final CLPlatform.CLDevice.CLContext.CLProgram.CLKernel kernel;
135     final CLWrapComputeContext clWrapComputeContext;
136    // final CLWrapComputeContext.MemorySegmentState vel;
137    // final CLWrapComputeContext.MemorySegmentState pos;
138     final Accelerator accelerator;
139     final Universe universe;
140 
141     public OpenCLNBodyGLWindow(Arena arena, int width, int height, GLTexture particle, int bodyCount, Mode mode) {
142         super(arena, width, height, particle, bodyCount, mode);
143         if (mode == Mode.HAT) {
144             System.out.println("mode = "+mode);
145         }
146         final float maxDist = 80f;
147         accelerator = new Accelerator(MethodHandles.lookup(),FIRST);
148         universe = Universe.create(accelerator, bodyCount);
149         for (int body = 0; body < bodyCount; body++) {
150             Universe.Body b = universe.body(body);
151             final float theta = (float) (Math.random() * Math.PI * 2);
152             final float phi = (float) (Math.random() * Math.PI * 2);
153             final float radius = (float) (Math.random() * maxDist);
154 
155             // get random 3D coordinates in sphere
156             b.x((float) (radius * Math.cos(theta) * Math.sin(phi)));
157             b.y((float) (radius * Math.sin(theta) * Math.sin(phi)));
158             b.z((float) (radius * Math.cos(phi)));
159         }
160         if (mode.equals(Mode.HAT)) {
161             this.kernel=null;
162             this.clWrapComputeContext=null;
163         }else{
164             this.clWrapComputeContext = new CLWrapComputeContext(arena, 20);
165            // this.vel = clWrapComputeContext.register(xyzVelFloatArr.ptr());
166            // this.pos = clWrapComputeContext.register(xyzPosFloatArr.ptr());
167 
168             var platforms = CLPlatform.platforms(arena);
169             System.out.println("platforms " + platforms.size());
170             var platform = platforms.get(0);
171             platform.devices.forEach(device -> {
172                 System.out.println("   Compute Units     " + device.computeUnits());
173                 System.out.println("   Device Name       " + device.deviceName());
174                 System.out.println("   Device Vendor       " + device.deviceVendor());
175                 System.out.println("   Built In Kernels  " + device.builtInKernels());
176             });
177             var device = platform.devices.get(0);
178             System.out.println("   Compute Units     " + device.computeUnits());
179             System.out.println("   Device Name       " + device.deviceName());
180             System.out.println("   Device Vendor       " + device.deviceVendor());
181 
182             System.out.println("   Built In Kernels  " + device.builtInKernels());
183             var context = device.createContext();
184             String typedefs = """
185                      typedef struct Body_s{
186                           float x;
187                           float y;
188                           float z;
189                           float vx;
190                           float vy;
191                           float vz;
192                      } Body_t;
193 
194                      typedef struct Universe_s{
195                        int length;
196                        Body_t body[0];
197                      }Universe_t;
198                     """;
199             String code = switch (mode) {
200                 case Mode.OpenCL -> typedefs + """
201                         __kernel void nbody( __global Universe_t *universe, float mass, float delT, float espSqr ){
202                            __global Body_t * me = universe->body+get_global_id(0);
203                             float accx = 0.0;
204                             float accy = 0.0;
205                             float accz = 0.0;
206                             for (size_t i = 0; i < get_global_size(0); i++) {
207                                __global Body_t * otherBody = universe->body+i;
208                                 float dx = otherBody->x-me->x;
209                                 float dy = otherBody->y-me->y;
210                                 float dz = otherBody->z-me->z;
211                                 float invDist =  (float) 1.0/sqrt((float)((dx * dx) + (dy * dy) + (dz * dz) + espSqr));
212                                 float s = mass * invDist * invDist * invDist;
213                                 accx = accx + (s * dx);
214                                 accy = accy + (s * dy);
215                                 accz = accz + (s * dz);
216                             }
217                             accx = accx * delT;
218                             accy = accy * delT;
219                             accz = accz * delT;
220                             me->x = me->x+(me->vx*delT)+(accx * 0.5 * delT);
221                             me->y = me->y+(me->vy*delT)+(accy * 0.5 * delT);
222                             me->z = me->z+(me->vz*delT)+(accz * 0.5 * delT);
223                             me->vx = me->vx+accx;
224                             me->vy = me->vy+accy;
225                             me->vz = me->vz+accz;
226 
227                         }
228                         """;
229                    /* case Mode.OpenCL4 -> """
230                             __kernel void nbody( __global float4 *xyzPos ,__global float4* xyzVel, float mass, float delT, float espSqr ){
231                                 float4 acc = (0.0,0.0,0.0,0.0);
232                                 float4 myPos = xyzPos[get_global_id(0)];
233                                 float4 myVel = xyzVel[get_global_id(0)];
234                                 for (int i = 0; i < get_global_size(0); i++) {
235                                        float4 delta =  xyzPos[i] - myPos;
236                                        float invDist =  (float) 1.0/sqrt((float)((delta.x * delta.x) + (delta.y * delta.y) + (delta.z * delta.z) + espSqr));
237                                        float s = mass * invDist * invDist * invDist;
238                                        acc= acc + (s * delta);
239                                 }
240                                 acc = acc*delT;
241                                 myPos = myPos + (myVel * delT) + (acc * delT)/2;
242                                 myVel = myVel + acc;
243                                 xyzPos[get_global_id(0)] = myPos;
244                                 xyzVel[get_global_id(0)] = myVel;
245 
246                             }
247                             """;*/
248                 case Mode.OpenCL4 -> typedefs + """
249                         __kernel void nbody( __global float4 *xyzPos ,__global float4* xyzVel, float mass, float delT, float espSqr ){
250                             float4 acc = (0.0,0.0,0.0,0.0);
251                             float4 myPos = xyzPos[get_global_id(0)];
252                             float4 myVel = xyzVel[get_global_id(0)];
253                             for (int i = 0; i < get_global_size(0); i++) {
254                                    float4 delta =  xyzPos[i] - myPos;
255                                    float invDist =  (float) 1.0/sqrt((float)((delta.x * delta.x) + (delta.y * delta.y) + (delta.z * delta.z) + espSqr));
256                                    float s = mass * invDist * invDist * invDist;
257                                    acc= acc + (s * delta);
258                             }
259                             acc = acc*delT;
260                             myPos = myPos + (myVel * delT) + (acc * delT)/2;
261                             myVel = myVel + acc;
262                             xyzPos[get_global_id(0)] = myPos;
263                             xyzVel[get_global_id(0)] = myVel;
264 
265                         }
266                         """;
267                 /* Aspirational Java code
268                      void nbody(KernelContext kc, F32Arr xyzPosFloats ,F32Arr xyzVelFloats, float mass, float delT, float espSqr ){
269                                 float4 acc = float4.of(0.0,0.0,0.0,0.0);
270                                 float4[] xyzPos = xyzPosFloats.float4ArrayView();
271                                 float4[] xyzVel = xyzVelFloats.float4ArrayView();
272                                 float4 myPos = xyzPos[kc.gix];
273                                 float4 myVel = xyzVel[kc.gix];
274                                 for (int i = 0; i < kc.gsx; i++) {
275                                        float4 delta =  xyzPos[i].sub(myPos); // xyzPos[i] - myPos
276                                        float invDist =  (float) 1.0/sqrt((float)((delta.x * delta.x) + (delta.y * delta.y) + (delta.z * delta.z) + espSqr));
277                                        float s = mass * invDist * invDist * invDist;
278                                        acc  = acc.plus(delta.mul(s)); //  acc= acc + (s * delta);
279                                 }
280                                 acc = acc.mul(delT); //acc = acc*delT;
281                                 myPos = myPos.plus(myVel.mul(delT)).plus(acc.mul(delT/2); //  myPos = myPos + (myVel * delT) + (acc * delT)/2;
282                                 myVel = myVel.plus(acc)//   myVel = myVel + acc;
283                                 xyzPos[kc.gix] = myPos;
284                                 xyzVel[kc.gix] = myVel;
285 
286                             }
287                  */
288                 default -> throw new IllegalStateException();
289             };
290             var program = context.buildProgram(code);
291             kernel = program.getKernel("nbodygl");
292         }
293     }
294 
295     @Override
296     public void display() {
297         if (mode.equals(Mode.HAT) || mode.equals(Mode.OpenCL)) {
298             moveBodies();
299             glClearColor(0f, 0f, 0f, 0f);
300             glClear(GL_COLOR_BUFFER_BIT() | GL_DEPTH_BUFFER_BIT());
301             glEnable(GL_TEXTURE_2D()); // Annoyingly important,
302             glBindTexture(GL_TEXTURE_2D(), textureBuf.get(particle.idx));
303 
304             glPushMatrix1(() -> {
305                 glScalef(.01f, .01f, .01f);
306                 glColor3f(1f, 1f, 1f);
307                 glQuads(() -> {
308                     for (int bodyIdx = 0; bodyIdx < bodyCount; bodyIdx++) {
309                         var bodyf4 = universe.body(bodyIdx);//xyzPosFloatArr.get(bodyIdx);
310 
311                             /*
312                              * Textures are mapped to a quad by defining the vertices in
313                              * the order SW,NW,NE,SE
314                              &
315                              *   2--->3
316                              *   ^    |
317                              *   |    v
318                              *   1    4
319                              *
320                              * Here we are describing the 'texture plane' for the body.
321                              * Ideally we need to rotate this to point to the camera (see billboarding)
322                              */
323 
324                         glTexCoord2f(WEST, SOUTH);
325                         glVertex3f(bodyf4.x() + WEST + dx, bodyf4.y() + SOUTH + dy, bodyf4.z() + dz);
326                         glTexCoord2f(WEST, NORTH);
327                         glVertex3f(bodyf4.x() + WEST + dx, bodyf4.y() + NORTH + dy, bodyf4.z() + dz);
328                         glTexCoord2f(EAST, NORTH);
329                         glVertex3f(bodyf4.x() + EAST + dx, bodyf4.y() + NORTH + dy, bodyf4.z() + dz);
330                         glTexCoord2f(EAST, SOUTH);
331                         glVertex3f(bodyf4.x() + EAST + dx, bodyf4.y() + SOUTH + dy, bodyf4.z() + dz);
332 
333                     }
334                 });
335             });
336 
337             glDisable(GL_TEXTURE_2D()); // Annoyingly important .. took two days to work that out
338             //glUseProgram(0);
339             glMatrixMode(GL_MODELVIEW());
340             glPushMatrix1(() -> {
341                 glColor3f(0.0f, 1.0f, 0.0f);
342                 var font = glutBitmapTimesRoman24$segment();
343                 long elapsed = System.currentTimeMillis() - startTime;
344                 float secs = elapsed / 1000f;
345                 var FPS = "Mode: " + mode.toString() + " Bodies " + bodyCount + " FPS: " + ((frameCount / secs));
346                 glRasterPos2f(-.8f, .7f);
347                 for (int c : FPS.getBytes()) {
348                     glutBitmapCharacter(font, c);
349                 }
350             });
351             glutSwapBuffers();
352             frameCount++;
353         } else {
354             super.display();
355         }
356     }
357 
358 
359     @Override
360     protected void moveBodies() {
361         if (frameCount == 0) {
362             BufferState.of(universe).setState(BufferState.HOST_OWNED);
363         } else {
364             BufferState.of(universe).setState(BufferState.DEVICE_OWNED);
365         }
366         if (mode.equals(Mode.HAT)) {
367             float cmass = mass;
368             float cdelT = delT;
369             float cespSqr = espSqr;
370             Universe cuniverse = universe;
371             accelerator.compute((@Reflect Compute)
372                     cc -> nbodyCompute(cc, cuniverse, cmass, cdelT, cespSqr));
373         } else if (mode.equals(Mode.OpenCL4) || mode.equals(Mode.OpenCL)) {
374 
375             kernel.run(clWrapComputeContext, bodyCount, universe, mass, delT, espSqr);
376         } else {
377             super.moveBodies();
378         }
379     }
380 }
381 
382 
383