1 /*
2 * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25 package hat.test;
26
27 import hat.Accelerator;
28 import hat.ComputeContext;
29 import hat.NDRange;
30 import hat.KernelContext;
31 import hat.backend.Backend;
32 import optkl.ifacemapper.BoundSchema;
33 import optkl.ifacemapper.Buffer;
34 import optkl.ifacemapper.Schema;
35 import hat.test.exceptions.HATAssertionError;
36 import hat.test.exceptions.HATExpectedPrecisionError;
37 import jdk.incubator.code.Reflect;
38 import hat.test.annotation.HatTest;
39 import hat.test.exceptions.HATAsserts;
40
41 import java.lang.invoke.MethodHandles;
42 import java.util.Random;
43
44 import static optkl.ifacemapper.MappableIface.RO;
45 import static optkl.ifacemapper.MappableIface.RW;
46 import static hat.test.TestNbody.Universe.*;
47
48 public class TestNbody {
49
50 public interface Universe extends Buffer {
51 int length();
52
53 interface Body extends Struct {
54 float x();
55 float y();
56 float z();
57 float vx();
58 float vy();
59 float vz();
60 void x(float x);
61 void y(float y);
62 void z(float z);
63 void vx(float vx);
64 void vy(float vy);
65 void vz(float vz);
66 }
67
68 Body body(long idx);
69
70 Schema<Universe> schema = Schema.of(Universe.class, resultTable -> resultTable
71 .arrayLen("length").array("body", array -> array
72 .fields("x", "y", "z", "vx", "vy", "vz")
73 )
74 );
75 static Universe create(Accelerator accelerator, int length) {
76 return BoundSchema.of(accelerator ,schema, length).allocate();
77 }
78 }
79
80 @Reflect
81 static public void nbodyKernel(@RO KernelContext kc, @RW Universe universe, float mass, float delT, float espSqr) {
82 float accx = 0.0f;
83 float accy = 0.0f;
84 float accz = 0.0f;
85 Body body = universe.body(kc.gix);
86
87 for (int i = 0; i < universe.length(); i++) {
88 Body otherBody = universe.body(i);
89 float dx = otherBody.x() - body.x();
90 float dy = otherBody.y() - body.y();
91 float dz = otherBody.z() - body.z();
92 float invDist = (float) (1.0f / Math.sqrt(((dx * dx) + (dy * dy) + (dz * dz) + espSqr)));
93 float s = mass * invDist * invDist * invDist;
94 accx = accx + (s * dx);
95 accy = accy + (s * dy);
96 accz = accz + (s * dz);
97 }
98 accx = accx * delT;
99 accy = accy * delT;
100 accz = accz * delT;
101 body.x(body.x() + (body.vx() * delT) + accx * .5f * delT);
102 body.y(body.y() + (body.vy() * delT) + accy * .5f * delT);
103 body.z(body.z() + (body.vz() * delT) + accz * .5f * delT);
104 body.vx(body.vx() + accx);
105 body.vy(body.vy() + accy);
106 body.vz(body.vz() + accz);
107 }
108
109 @Reflect
110 public static void nbodyCompute(@RO ComputeContext cc, @RW Universe universe, final float mass, final float delT, final float espSqr) {
111 cc.dispatchKernel(NDRange.of1D(universe.length()), kernelContext -> nbodyKernel(kernelContext, universe, mass, delT, espSqr));
112 }
113
114 public static void computeSequential(Universe universe, float mass, float delT, float espSqr) {
115 float accx = 0.0f;
116 float accy = 0.0f;
117 float accz = 0.0f;
118 for (int j = 0; j < universe.length(); j++) {
119 Body body = universe.body(j);
120
121 for (int i = 0; i < universe.length(); i++) {
122 Body otherBody = universe.body(i);
123 float dx = otherBody.x() - body.x();
124 float dy = otherBody.y() - body.y();
125 float dz = otherBody.z() - body.z();
126 float invDist = (float) (1.0f / Math.sqrt(((dx * dx) + (dy * dy) + (dz * dz) + espSqr)));
127 float s = mass * invDist * invDist * invDist;
128 accx = accx + (s * dx);
129 accy = accy + (s * dy);
130 accz = accz + (s * dz);
131 }
132 accx = accx * delT;
133 accy = accy * delT;
134 accz = accz * delT;
135 body.x(body.x() + (body.vx() * delT) + accx * .5f * delT);
136 body.y(body.y() + (body.vy() * delT) + accy * .5f * delT);
137 body.z(body.z() + (body.vz() * delT) + accz * .5f * delT);
138 body.vx(body.vx() + accx);
139 body.vy(body.vy() + accy);
140 body.vz(body.vz() + accz);
141 }
142 }
143
144 @HatTest
145 @Reflect
146 public void testNbody() {
147 final int NUM_BODIES = 1024;
148 var accelerator = new Accelerator(MethodHandles.lookup(), Backend.FIRST);
149 Universe universe = create(accelerator, NUM_BODIES);
150 Universe universeSeq = create(accelerator, NUM_BODIES);
151 final float delT = .1f;
152 final float espSqr = 0.1f;
153 final float mass = .5f;
154
155 Random random = new Random(71);
156 for (int bodyIdx = 0; bodyIdx < NUM_BODIES; bodyIdx++) {
157 Body b = universe.body(bodyIdx);
158
159 final float theta = (float) (Math.random() * Math.PI * 2);
160 final float phi = (float) (Math.random() * Math.PI * 2);
161 final float radius = (float) (Math.random() * 100.f);
162
163 // get random 3D coordinates in sphere
164 b.x((float) (radius * Math.cos(theta) * Math.sin(phi)));
165 b.y((float) (radius * Math.sin(theta) * Math.sin(phi)));
166 b.z((float) (radius * Math.cos(phi)));
167 b.vx(random.nextFloat(1));
168 b.vy(random.nextFloat(1));
169 b.vz(random.nextFloat(1));
170
171 // Copy random values into the other body to check results
172 Body seqBody = universeSeq.body(bodyIdx);
173 seqBody.x(b.x());
174 seqBody.y(b.y());
175 seqBody.z(b.z());
176
177 seqBody.vx(b.vx());
178 seqBody.vy(b.vy());
179 seqBody.vz(b.vz());
180
181 }
182
183 accelerator.compute(computeContext -> {
184 TestNbody.nbodyCompute(computeContext, universe, mass, delT, espSqr);
185 });
186
187 computeSequential(universeSeq, espSqr, mass, espSqr);
188
189 // Check results
190 float delta = 0.1f;
191 for (int i = 0; i < NUM_BODIES; i++) {
192 Body hatBody = universe.body(i);
193 Body seqBody = universeSeq.body(i);
194 IO.println(i);
195 try {
196 HATAsserts.assertEquals(seqBody.x(), hatBody.x(), delta);
197 HATAsserts.assertEquals(seqBody.y(), hatBody.y(), delta);
198 HATAsserts.assertEquals(seqBody.z(), hatBody.z(), delta);
199 HATAsserts.assertEquals(seqBody.vx(), hatBody.vx(), delta);
200 HATAsserts.assertEquals(seqBody.vy(), hatBody.vy(), delta);
201 HATAsserts.assertEquals(seqBody.vz(), hatBody.vz(), delta);
202 } catch (HATAssertionError hatAssertionError) {
203 throw new HATExpectedPrecisionError(hatAssertionError.getMessage());
204 }
205 }
206 }
207 }