1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package wrap;
 26 
 27 import java.lang.foreign.AddressLayout;
 28 import java.lang.foreign.Arena;
 29 import java.lang.foreign.MemorySegment;
 30 import java.lang.foreign.ValueLayout;
 31 
 32 import static java.lang.foreign.ValueLayout.ADDRESS;
 33 import static java.lang.foreign.ValueLayout.JAVA_BYTE;
 34 import static java.lang.foreign.ValueLayout.JAVA_DOUBLE;
 35 import static java.lang.foreign.ValueLayout.JAVA_FLOAT;
 36 import static java.lang.foreign.ValueLayout.JAVA_INT;
 37 import static java.lang.foreign.ValueLayout.JAVA_LONG;
 38 import static java.lang.foreign.ValueLayout.JAVA_SHORT;
 39 
 40 public class Wrap {
 41     public interface Ptr{
 42         MemorySegment ptr();
 43         long sizeof();
 44     }
 45     public interface Arr extends Ptr{
 46         default long length(){
 47             return sizeof()/elementSize();
 48         }
 49         long elementSize();
 50     }
 51     public  record IntPtr(MemorySegment ptr) implements Ptr {
 52         public static IntPtr of(Arena arena, int value) {
 53             return new IntPtr(arena.allocateFrom(JAVA_INT, value));
 54         }
 55 
 56         public int set(int value) {
 57             ptr.set(JAVA_INT, 0, value);
 58             return value;
 59         }
 60 
 61         public int get() {
 62             return ptr.get(JAVA_INT, 0);
 63         }
 64 
 65         @Override public long sizeof(){
 66             return JAVA_INT.byteSize();
 67         }
 68     }
 69 
 70     public  record LongPtr(MemorySegment ptr)  implements Ptr{
 71         public static LongPtr of(Arena arena, long value) {
 72             return new LongPtr(arena.allocateFrom(JAVA_LONG, value));
 73         }
 74 
 75         public long set(long value) {
 76             ptr.set(JAVA_LONG, 0, value);
 77             return value;
 78         }
 79 
 80         public long get() {
 81             return ptr.get(JAVA_LONG, 0);
 82         }
 83 
 84         @Override
 85         public long sizeof(){
 86             return JAVA_LONG.byteSize();
 87         }
 88     }
 89     public  record DoublePtr(MemorySegment ptr)  implements Ptr{
 90         public static DoublePtr of(Arena arena, double value) {
 91             return new DoublePtr(arena.allocateFrom(JAVA_DOUBLE, value));
 92         }
 93 
 94         public double set(double value) {
 95             ptr.set(JAVA_DOUBLE, 0, value);
 96             return value;
 97         }
 98 
 99         public double get() {
100             return ptr.get(JAVA_DOUBLE, 0);
101         }
102 
103         @Override
104         public long sizeof(){
105             return JAVA_DOUBLE.byteSize();
106         }
107     }
108 
109     public  record PtrArr(MemorySegment ptr)  implements Arr{
110         public static PtrArr of(Arena arena, int size) {
111             return new PtrArr(arena.allocate(ADDRESS, size));
112         }
113         public static PtrArr of(Arena arena, MemorySegment ...memorySegments) {
114             var ptrArray=  new PtrArr(arena.allocate(ADDRESS, memorySegments.length));
115             for (int i = 0; i < memorySegments.length; i++) {
116                 ptrArray.set(i, memorySegments[i]);
117             }
118             return ptrArray;
119         }
120 
121         public static PtrArr of(Arena arena, Ptr ...ptrs) {
122             var ptrArray=  new PtrArr(arena.allocate(ADDRESS, ptrs.length));
123             for (int i = 0; i < ptrs.length; i++) {
124                 ptrArray.set(i, ptrs[i].ptr());
125             }
126             return ptrArray;
127         }
128         public static PtrArr of(Arena arena, String ...strings) {
129             var ptrArray=  new PtrArr(arena.allocate(ADDRESS, strings.length));
130             for (int i = 0; i < strings.length; i++) {
131                 ptrArray.set(i, CStrPtr.of(arena,strings[i]).ptr());
132             }
133             return ptrArray;
134         }
135 
136 
137         public MemorySegment set(int idx, MemorySegment value) {
138             ptr.set(AddressLayout.ADDRESS, idx* ADDRESS.byteSize(), value);
139             return value;
140         }
141 
142         public MemorySegment get(int idx) {
143             return ptr.get(AddressLayout.ADDRESS, idx* ADDRESS.byteSize());
144         }
145 
146         @Override
147         public long sizeof(){
148             return ptr.byteSize();
149         }
150         @Override public  long elementSize(){
151             return AddressLayout.ADDRESS.byteSize();
152         }
153     }
154 
155     public  record CStrPtr(MemorySegment ptr, int len)  implements Ptr{
156         public static CStrPtr of(Arena arena, int len) {
157             return new CStrPtr(arena.allocate(JAVA_BYTE, len), len);
158         }
159         public static CStrPtr of(Arena arena, String str) {
160             return new CStrPtr(arena.allocateFrom( str), str.length());
161         }
162         public static CStrPtr of( MemorySegment str) {
163             return new CStrPtr(str, (int)str.byteSize());
164         }
165 
166         public String get() {
167             return ptr.getString(0);
168         }
169         @Override
170         public long sizeof(){
171             return JAVA_BYTE.byteSize();
172         }
173 
174         @Override public String toString(){
175             return get();
176         }
177     }
178 
179     public record FloatPtr(MemorySegment ptr)  implements Ptr{
180         public static FloatPtr of(Arena arena, float value) {
181             return new FloatPtr(arena.allocateFrom(JAVA_FLOAT, value));
182         }
183 
184         public float set(float value) {
185             ptr.set(JAVA_FLOAT, 0, value);
186             return value;
187         }
188 
189         public float get() {
190             return ptr.get(JAVA_FLOAT, 0);
191         }
192 
193         @Override
194         public long sizeof(){
195             return JAVA_FLOAT.byteSize();
196         }
197     }
198     public record ShortPtr(MemorySegment ptr)  implements Ptr{
199         public static ShortPtr of(Arena arena, short value) {
200             return new ShortPtr(arena.allocateFrom(JAVA_SHORT, value));
201         }
202 
203         public short set(short value) {
204             ptr.set(JAVA_SHORT, 0, value);
205             return value;
206         }
207 
208         public short get() {
209             return ptr.get(JAVA_SHORT, 0);
210         }
211 
212         @Override
213         public long sizeof(){
214             return JAVA_SHORT.byteSize();
215         }
216     }
217 
218     public record FloatArr(MemorySegment ptr)  implements Arr{
219         public static FloatArr of(Arena arena, int length) {
220             return new FloatArr(arena.allocate(JAVA_FLOAT, length));
221         }
222         public static FloatArr of(Arena arena, float[] floats) {
223             return new FloatArr(arena.allocateFrom(JAVA_FLOAT, floats));
224         }
225 
226         public float set(int idx, float value) {
227             ptr.set(JAVA_FLOAT, idx*JAVA_FLOAT.byteSize(), value);
228             return value;
229         }
230 
231         public float get(int idx) {
232             return ptr.get(JAVA_FLOAT, JAVA_FLOAT.byteSize()*idx);
233         }
234 
235 
236         @Override public  long elementSize(){
237             return JAVA_FLOAT.byteSize();
238         }
239 
240         @Override
241         public long sizeof(){
242             return ptr.byteSize();
243         }
244     }
245 
246     public record IntArr(MemorySegment ptr)  implements Arr{
247         public static IntArr of(Arena arena, int length) {
248             return new IntArr(arena.allocate(JAVA_INT, length));
249         }
250         public static IntArr ofValues(Arena arena, int ...values ) {
251             return of(arena, values);
252         }
253         public static IntArr of(Arena arena, int[] floats) {
254             return new IntArr(arena.allocateFrom(JAVA_INT, floats));
255         }
256 
257         public int set(int idx, int value) {
258             ptr.set(JAVA_INT, idx*JAVA_INT.byteSize(), value);
259             return value;
260         }
261 
262         public int get(int idx) {
263             return ptr.get(JAVA_INT, JAVA_INT.byteSize()*idx);
264         }
265 
266         @Override public  long elementSize(){
267             return JAVA_INT.byteSize();
268         }
269 
270         @Override
271         public long sizeof(){
272             return ptr.byteSize();
273         }
274     }
275 
276     public record Float4Arr(MemorySegment ptr)  implements Arr{
277         public record float4(float x, float y, float z, float w ) {
278             public static float4 of(float x, float y, float z, float w) {
279                 return new float4(x, y, z, w);
280             }
281             static public  final float4 zero = new float4(0.f, 0.f, 0.f, 0.f);
282             public static float4 of() {
283                 return zero;
284             }
285             public float4 sub(float4 rhs){
286                 return of(x-rhs.x,y-rhs.y,z-rhs.z,w-rhs.w);
287             }
288             public float4 add(float4 rhs){
289                 return of(x+rhs.x,y+rhs.y,z+rhs.z,w+rhs.w);
290             }
291 
292             public float4 mul(float rhs) {
293                 return of(x*rhs,y*rhs,z*rhs,w*rhs);
294             }
295             public float4 mul(float4 rhs) {
296                 return of(x* rhs.x,y* rhs.y,z* rhs.z,w*rhs.w);
297             }
298         }
299 
300 
301         public static Float4Arr of(Arena arena, int length) {
302             return new Float4Arr(arena.allocate(JAVA_FLOAT, length*JAVA_FLOAT.byteSize()));
303         }
304         public static Float4Arr of(Arena arena, float[] floats) {
305             return new Float4Arr(arena.allocateFrom(JAVA_FLOAT, floats));
306         }
307 
308         public float4 get(int idx){
309             return float4.of(
310                     ptr.get(JAVA_FLOAT, elementSize()*idx),
311                     ptr.get(JAVA_FLOAT, elementSize()*idx+JAVA_FLOAT.byteSize()),
312                     ptr.get(JAVA_FLOAT, elementSize()*idx+JAVA_FLOAT.byteSize()*2),
313                     ptr.get(JAVA_FLOAT, elementSize()*idx+JAVA_FLOAT.byteSize()*3));
314 
315         }
316         public void set(int idx, float4 f4) {
317             ptr.set(JAVA_FLOAT, idx*elementSize(), f4.x());
318             ptr.set(JAVA_FLOAT, idx*elementSize()+JAVA_FLOAT.byteSize(), f4.y());
319             ptr.set(JAVA_FLOAT, idx*elementSize()+JAVA_FLOAT.byteSize()*2, f4.z());
320             ptr.set(JAVA_FLOAT, idx*elementSize()+JAVA_FLOAT.byteSize()*3, f4.w());
321         }
322         public float setx(int idx, float value) {
323             ptr.set(JAVA_FLOAT, idx*elementSize(), value);
324             return value;
325         }
326         public float sety(int idx, float value) {
327             ptr.set(JAVA_FLOAT, idx*elementSize()+JAVA_FLOAT.byteSize(), value);
328             return value;
329         }
330         public float setz(int idx, float value) {
331             ptr.set(JAVA_FLOAT, idx*elementSize()+JAVA_FLOAT.byteSize()*2, value);
332             return value;
333         }
334         public float setw(int idx, float value) {
335             ptr.set(JAVA_FLOAT, idx*elementSize()+JAVA_FLOAT.byteSize()*3, value);
336             return value;
337         }
338 
339         public float getx(int idx) {
340             return ptr.get(JAVA_FLOAT, elementSize()*idx);
341         }
342         public float gety(int idx) {
343             return ptr.get(JAVA_FLOAT, elementSize()*idx+JAVA_FLOAT.byteSize());
344         }
345         public float getz(int idx) {
346             return ptr.get(JAVA_FLOAT, elementSize()*idx+JAVA_FLOAT.byteSize()*2);
347         }
348         public float getw(int idx) {
349             return ptr.get(JAVA_FLOAT, elementSize()*idx+JAVA_FLOAT.byteSize()*3);
350         }
351 
352         @Override public  long elementSize(){
353             return JAVA_FLOAT.byteSize()*4;
354         }
355 
356         @Override
357         public long sizeof(){
358             return ptr.byteSize();
359         }
360     }
361 
362     public static void dump(MemorySegment s, int bytes){
363         char[] chars = new char[16];
364         boolean end=false;
365         for (int i = 0; !end && i < bytes; i++) {
366             int signed = s.get(ValueLayout.JAVA_BYTE, i);
367             int unsigned =  ((signed<0)?signed+256:signed)&0xff;
368             chars[i%16] = (char)unsigned;
369 
370             System.out.printf("%02x ", unsigned);
371             if (unsigned == 0){
372                 end=true;
373             }
374             if (i>0 && i%16==0){
375                 System.out.print(" | ");
376                 for (int c=0; c<16; c++){
377                     if (chars[c]<32){
378                         System.out.print(switch (chars[c]){
379                             case '\0'->"\\0";
380                             case '\n'->"\\n";
381                             case '\r'->"\\r";
382                             default -> chars[c]+"";
383                         });
384 
385                     }else {
386                         System.out.print(chars[c]);
387                     }
388                 }
389                 System.out.println();
390             }
391         }
392     }
393 }