1 /*
  2  * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package hat.test;
 26 
 27 import hat.Accelerator;
 28 import hat.ComputeContext;
 29 import hat.NDRange;
 30 import hat.KernelContext;
 31 import hat.backend.Backend;
 32 import optkl.ifacemapper.BoundSchema;
 33 import optkl.ifacemapper.Buffer;
 34 import optkl.ifacemapper.Schema;
 35 import hat.test.exceptions.HATAssertionError;
 36 import hat.test.exceptions.HATExpectedPrecisionError;
 37 import jdk.incubator.code.Reflect;
 38 import hat.test.annotation.HatTest;
 39 import hat.test.exceptions.HATAsserts;
 40 
 41 import java.lang.invoke.MethodHandles;
 42 import java.util.Random;
 43 
 44 import static optkl.ifacemapper.MappableIface.RO;
 45 import static optkl.ifacemapper.MappableIface.RW;
 46 import static hat.test.TestNbody.Universe.*;
 47 
 48 public class TestNbody {
 49 
 50     public interface Universe extends Buffer {
 51         int length();
 52 
 53         interface Body extends Struct {
 54             float x();
 55             float y();
 56             float z();
 57             float vx();
 58             float vy();
 59             float vz();
 60             void x(float x);
 61             void y(float y);
 62             void z(float z);
 63             void vx(float vx);
 64             void vy(float vy);
 65             void vz(float vz);
 66         }
 67 
 68         Body body(long idx);
 69 
 70         Schema<Universe> schema = Schema.of(Universe.class, resultTable -> resultTable
 71                 .arrayLen("length").array("body", array -> array
 72                         .fields("x", "y", "z", "vx", "vy", "vz")
 73                 )
 74         );
 75         static Universe create(Accelerator accelerator, int length) {
 76             return BoundSchema.of(accelerator ,schema, length).allocate();
 77         }
 78     }
 79 
 80     @Reflect
 81     static public void nbodyKernel(@RO KernelContext kc, @RW Universe universe, float mass, float delT, float espSqr) {
 82         float accx = 0.0f;
 83         float accy = 0.0f;
 84         float accz = 0.0f;
 85         Body body = universe.body(kc.gix);
 86 
 87         for (int i = 0; i < universe.length(); i++) {
 88             Body otherBody = universe.body(i);
 89             float dx = otherBody.x() - body.x();
 90             float dy = otherBody.y() - body.y();
 91             float dz = otherBody.z() - body.z();
 92             float invDist = (float) (1.0f / Math.sqrt(((dx * dx) + (dy * dy) + (dz * dz) + espSqr)));
 93             float s = mass * invDist * invDist * invDist;
 94             accx = accx + (s * dx);
 95             accy = accy + (s * dy);
 96             accz = accz + (s * dz);
 97         }
 98         accx = accx * delT;
 99         accy = accy * delT;
100         accz = accz * delT;
101         body.x(body.x() + (body.vx() * delT) + accx * .5f * delT);
102         body.y(body.y() + (body.vy() * delT) + accy * .5f * delT);
103         body.z(body.z() + (body.vz() * delT) + accz * .5f * delT);
104         body.vx(body.vx() + accx);
105         body.vy(body.vy() + accy);
106         body.vz(body.vz() + accz);
107     }
108 
109     @Reflect
110     public static void nbodyCompute(@RO ComputeContext cc, @RW Universe universe, final float mass, final float delT, final float espSqr) {
111         cc.dispatchKernel(NDRange.of1D(universe.length()), kernelContext -> nbodyKernel(kernelContext, universe, mass, delT, espSqr));
112     }
113 
114     public static void computeSequential(Universe universe, float mass, float delT, float espSqr) {
115         float accx = 0.0f;
116         float accy = 0.0f;
117         float accz = 0.0f;
118         for (int j = 0; j < universe.length(); j++) {
119             Body body = universe.body(j);
120 
121             for (int i = 0; i < universe.length(); i++) {
122                 Body otherBody = universe.body(i);
123                 float dx = otherBody.x() - body.x();
124                 float dy = otherBody.y() - body.y();
125                 float dz = otherBody.z() - body.z();
126                 float invDist = (float) (1.0f / Math.sqrt(((dx * dx) + (dy * dy) + (dz * dz) + espSqr)));
127                 float s = mass * invDist * invDist * invDist;
128                 accx = accx + (s * dx);
129                 accy = accy + (s * dy);
130                 accz = accz + (s * dz);
131             }
132             accx = accx * delT;
133             accy = accy * delT;
134             accz = accz * delT;
135             body.x(body.x() + (body.vx() * delT) + accx * .5f * delT);
136             body.y(body.y() + (body.vy() * delT) + accy * .5f * delT);
137             body.z(body.z() + (body.vz() * delT) + accz * .5f * delT);
138             body.vx(body.vx() + accx);
139             body.vy(body.vy() + accy);
140             body.vz(body.vz() + accz);
141         }
142     }
143 
144     @HatTest
145     @Reflect
146     public void testNbody() {
147         final int NUM_BODIES = 1024;
148         var accelerator = new Accelerator(MethodHandles.lookup(), Backend.FIRST);
149         Universe universe = create(accelerator, NUM_BODIES);
150         Universe universeSeq = create(accelerator, NUM_BODIES);
151         final float delT = .1f;
152         final float espSqr = 0.1f;
153         final float mass = .5f;
154 
155         Random random = new Random(71);
156         for (int bodyIdx = 0; bodyIdx < NUM_BODIES; bodyIdx++) {
157             Body b = universe.body(bodyIdx);
158 
159             final float theta = (float) (Math.random() * Math.PI * 2);
160             final float phi = (float) (Math.random() * Math.PI * 2);
161             final float radius = (float) (Math.random() * 100.f);
162 
163             // get random 3D coordinates in sphere
164             b.x((float) (radius * Math.cos(theta) * Math.sin(phi)));
165             b.y((float) (radius * Math.sin(theta) * Math.sin(phi)));
166             b.z((float) (radius * Math.cos(phi)));
167             b.vx(random.nextFloat(1));
168             b.vy(random.nextFloat(1));
169             b.vz(random.nextFloat(1));
170 
171             // Copy random values into the other body to check results
172             Body seqBody = universeSeq.body(bodyIdx);
173             seqBody.x(b.x());
174             seqBody.y(b.y());
175             seqBody.z(b.z());
176 
177             seqBody.vx(b.vx());
178             seqBody.vy(b.vy());
179             seqBody.vz(b.vz());
180 
181         }
182 
183         accelerator.compute(computeContext -> {
184             TestNbody.nbodyCompute(computeContext, universe, mass, delT, espSqr);
185         });
186 
187         computeSequential(universeSeq, espSqr, mass, espSqr);
188 
189         // Check results
190         float delta = 0.1f;
191         for (int i = 0; i < NUM_BODIES; i++) {
192             Body hatBody = universe.body(i);
193             Body seqBody = universeSeq.body(i);
194             IO.println(i);
195             try {
196                 HATAsserts.assertEquals(seqBody.x(), hatBody.x(), delta);
197                 HATAsserts.assertEquals(seqBody.y(), hatBody.y(), delta);
198                 HATAsserts.assertEquals(seqBody.z(), hatBody.z(), delta);
199                 HATAsserts.assertEquals(seqBody.vx(), hatBody.vx(), delta);
200                 HATAsserts.assertEquals(seqBody.vy(), hatBody.vy(), delta);
201                 HATAsserts.assertEquals(seqBody.vz(), hatBody.vz(), delta);
202             } catch (HATAssertionError hatAssertionError) {
203                 throw new HATExpectedPrecisionError(hatAssertionError.getMessage());
204             }
205         }
206     }
207 }