1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package experiments;
 26 
 27 import hat.Accelerator;
 28 import hat.ComputeContext;
 29 import hat.KernelContext;
 30 import hat.backend.ffi.OpenCLBackend;
 31 import static hat.backend.ffi.Config.*;
 32 import hat.ifacemapper.BoundSchema;
 33 import hat.ifacemapper.Schema;
 34 import hat.buffer.Buffer;
 35 
 36 import java.lang.foreign.GroupLayout;
 37 import java.lang.foreign.MemoryLayout;
 38 import java.lang.invoke.MethodHandles;
 39 import jdk.incubator.code.CodeReflection;
 40 import java.util.Random;
 41 
 42 import static java.lang.foreign.ValueLayout.JAVA_INT;
 43 
 44 public class Mesh {
 45     public interface MeshData extends Buffer {
 46         interface Point3D extends Struct {
 47             int x();
 48 
 49             void x(int x);
 50 
 51             int y();
 52 
 53             void y(int y);
 54 
 55             int z();
 56 
 57             void z(int z);
 58 
 59         }
 60 
 61        int points();
 62 
 63       //  void points(int points);
 64 
 65         Point3D point(long idx);
 66 
 67         interface Vertex3D extends Struct {
 68             int from();
 69 
 70             void from(int id);
 71 
 72             int to();
 73 
 74             void to(int id);
 75 
 76         }
 77 
 78         int vertices();
 79 
 80        // void vertices(int vertices);
 81 
 82         Vertex3D vertex(long idx);
 83 
 84 
 85          GroupLayout LAYOUT = MemoryLayout.structLayout(
 86                 JAVA_INT.withName("points"),
 87                 MemoryLayout.sequenceLayout(100,
 88                 MemoryLayout.structLayout(
 89                 JAVA_INT.withName("x"),
 90                 JAVA_INT.withName("y"),
 91                 JAVA_INT.withName("z")
 92                 ).withName(Point3D.class.getSimpleName())
 93                 ).withName("point"),
 94                     JAVA_INT.withName("vertices"),
 95                             MemoryLayout.sequenceLayout(10,
 96                             MemoryLayout.structLayout(
 97                             JAVA_INT.withName("from"),
 98                             JAVA_INT.withName("to")
 99                             ).withName(Vertex3D.class.getSimpleName())
100                 ).withName("vertex")
101             ).withName(MeshData.class.getSimpleName());
102 
103         static GroupLayout getLayout() {
104             return LAYOUT;
105         }
106 
107 
108         Schema<MeshData> schema = Schema.of(MeshData.class, cascade -> cascade
109                 .arrayLen("points").array("point", p -> p.fields("x", "y", "z"))
110                 .arrayLen("vertices").array("vertex", v -> v.fields("from", "to"))
111         );
112         static  MeshData create(Accelerator accelerator) {
113             return schema.allocate(accelerator,100,10);
114         }
115     }
116 
117     public static class Compute {
118         @CodeReflection
119         public static void initPoints(KernelContext kc, MeshData mesh) {
120             if (kc.x < kc.maxX) {
121                 MeshData.Point3D point = mesh.point(kc.x);
122                 point.x(kc.x);
123                 point.y(0);
124                 point.z(0);
125             }
126         }
127 
128         @CodeReflection
129         public static void buildMesh(ComputeContext cc, MeshData meshData) {
130             cc.dispatchKernel(meshData.points(),
131                     kc -> initPoints(kc, meshData)
132             );
133 
134         }
135     }
136 
137 
138     public static void main(String[] args) {
139         Accelerator accelerator = new Accelerator(MethodHandles.lookup()
140                 ,new OpenCLBackend(of(PROFILE(),  TRACE())));
141                 //,new DebugBackend(
142                 //DebugBackend.HowToRunCompute.REFLECT,
143                 //DebugBackend.HowToRunKernel.BABYLON_INTERPRETER));
144       //  MeshData.schema.toText(t -> System.out.print(t));
145 
146         var boundSchema = new BoundSchema<>(MeshData.schema, 100, 10);
147         var meshDataNew = boundSchema.allocate(accelerator.lookup,accelerator);
148         var meshDataOld = MeshData.create(accelerator);
149 
150         String layoutNew = Buffer.getLayout(meshDataNew).toString();
151         String layoutOld = Buffer.getLayout(meshDataOld).toString();
152         if (layoutOld.equals(layoutNew)) {
153             MeshData meshData = MeshData.create(accelerator);
154             Random random = new Random(System.currentTimeMillis());
155             for (int p = 0; p < meshData.points(); p++) {
156                 var point3D = meshData.point(p);
157                 point3D.x(random.nextInt(100));
158                 point3D.y(random.nextInt(100));
159                 point3D.z(random.nextInt(100));
160             }
161             for (int v = 0; v < meshData.vertices(); v++) {
162                 var vertex3D = meshData.vertex(v);
163                 vertex3D.from(random.nextInt(meshData.points()));
164                 vertex3D.to(random.nextInt(meshData.points()));
165             }
166 
167             accelerator.compute(cc -> Compute.buildMesh(cc, meshData));
168         }else{
169             System.out.println("layouts differ");
170         }
171 
172     }
173 }