1 /*
  2  *  Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved.
  3  *  DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  *  This code is free software; you can redistribute it and/or modify it
  6  *  under the terms of the GNU General Public License version 2 only, as
  7  *  published by the Free Software Foundation.
  8  *
  9  *  This code is distributed in the hope that it will be useful, but WITHOUT
 10  *  ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  *  FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  *  version 2 for more details (a copy is included in the LICENSE file that
 13  *  accompanied this code).
 14  *
 15  *  You should have received a copy of the GNU General Public License version
 16  *  2 along with this work; if not, write to the Free Software Foundation,
 17  *  Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  *  Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  *  or visit www.oracle.com if you need additional information or have any
 21  *  questions.
 22  */
 23 
 24 /*
 25  * @test
 26  * @run testng/othervm --enable-native-access=ALL-UNNAMED TestMemoryAccessInstance
 27  */
 28 
 29 import jdk.incubator.foreign.MemoryAddress;
 30 import jdk.incubator.foreign.MemorySegment;
 31 
 32 import java.nio.ByteBuffer;
 33 import java.nio.ByteOrder;
 34 import java.util.function.Function;
 35 
 36 import jdk.incubator.foreign.ResourceScope;
 37 import jdk.incubator.foreign.ValueLayout;
 38 import org.testng.annotations.*;
 39 import static org.testng.Assert.*;
 40 
 41 public class TestMemoryAccessInstance {
 42 
 43     static class Accessor<T, X, L> {
 44 
 45         interface SegmentGetter<T, X, L> {
 46             X get(T buffer, L layout, long offset);
 47         }
 48 
 49         interface SegmentSetter<T, X, L> {
 50             void set(T buffer, L layout, long offset, X o);
 51         }
 52 
 53         interface BufferGetter<X> {
 54             X get(ByteBuffer segment, int offset);
 55         }
 56 
 57         interface BufferSetter<X> {
 58             void set(ByteBuffer buffer, int offset, X o);
 59         }
 60 
 61         final X value;
 62         final L layout;
 63         final Function<MemorySegment, T> transform;
 64         final SegmentGetter<T, X, L> segmentGetter;
 65         final SegmentSetter<T, X, L> segmentSetter;
 66         final BufferGetter<X> bufferGetter;
 67         final BufferSetter<X> bufferSetter;
 68 
 69         Accessor(Function<MemorySegment, T> transform, L layout, X value,
 70                  SegmentGetter<T, X, L> segmentGetter, SegmentSetter<T, X, L> segmentSetter,
 71                  BufferGetter<X> bufferGetter, BufferSetter<X> bufferSetter) {
 72             this.transform = transform;
 73             this.layout = layout;
 74             this.value = value;
 75             this.segmentGetter = segmentGetter;
 76             this.segmentSetter = segmentSetter;
 77             this.bufferGetter = bufferGetter;
 78             this.bufferSetter = bufferSetter;
 79         }
 80 
 81         void test() {
 82             try (ResourceScope scope = ResourceScope.newConfinedScope()) {
 83                 MemorySegment segment = MemorySegment.allocateNative(64, scope);
 84                 ByteBuffer buffer = segment.asByteBuffer();
 85                 T t = transform.apply(segment);
 86                 segmentSetter.set(t, layout, 4, value);
 87                 assertEquals(bufferGetter.get(buffer, 4), value);
 88                 bufferSetter.set(buffer, 4, value);
 89                 assertEquals(value, segmentGetter.get(t, layout, 4));
 90             }
 91         }
 92 
 93         static <L, X> Accessor<MemorySegment, X, L> ofSegment(L layout, X value,
 94                          SegmentGetter<MemorySegment, X, L> segmentGetter, SegmentSetter<MemorySegment, X, L> segmentSetter,
 95                          BufferGetter<X> bufferGetter, BufferSetter<X> bufferSetter) {
 96             return new Accessor<>(Function.identity(), layout, value, segmentGetter, segmentSetter, bufferGetter, bufferSetter);
 97         }
 98 
 99         static <L, X> Accessor<MemoryAddress, X, L> ofAddress(L layout, X value,
100                                                               SegmentGetter<MemoryAddress, X, L> segmentGetter, SegmentSetter<MemoryAddress, X, L> segmentSetter,
101                                                               BufferGetter<X> bufferGetter, BufferSetter<X> bufferSetter) {
102             return new Accessor<>(MemorySegment::address, layout, value, segmentGetter, segmentSetter, bufferGetter, bufferSetter);
103         }
104     }
105 
106     @Test(dataProvider = "segmentAccessors")
107     public void testSegmentAccess(String testName, Accessor<?, ?, ?> accessor) {
108         accessor.test();
109     }
110 
111     @Test(dataProvider = "addressAccessors")
112     public void testAddressAccess(String testName, Accessor<?, ?, ?> accessor) {
113         accessor.test();
114     }
115 
116     static final ByteOrder NE = ByteOrder.nativeOrder();
117 
118     @DataProvider(name = "segmentAccessors")
119     static Object[][] segmentAccessors() {
120         return new Object[][]{
121 
122                 {"byte", Accessor.ofSegment(ValueLayout.JAVA_BYTE, (byte) 42,
123                         MemorySegment::get, MemorySegment::set,
124                         ByteBuffer::get, ByteBuffer::put)
125                 },
126                 {"bool", Accessor.ofSegment(ValueLayout.JAVA_BOOLEAN, false,
127                         MemorySegment::get, MemorySegment::set,
128                         (bb, pos) -> bb.get(pos) != 0, (bb, pos, v) -> bb.put(pos, v ? (byte)1 : (byte)0))
129                 },
130                 {"char", Accessor.ofSegment(ValueLayout.JAVA_CHAR, (char) 42,
131                         MemorySegment::get, MemorySegment::set,
132                         (bb, pos) -> bb.order(NE).getChar(pos), (bb, pos, v) -> bb.order(NE).putChar(pos, v))
133                 },
134                 {"int", Accessor.ofSegment(ValueLayout.JAVA_INT, 42,
135                         MemorySegment::get, MemorySegment::set,
136                         (bb, pos) -> bb.order(NE).getInt(pos), (bb, pos, v) -> bb.order(NE).putInt(pos, v))
137                 },
138                 {"float", Accessor.ofSegment(ValueLayout.JAVA_FLOAT, 42f,
139                         MemorySegment::get, MemorySegment::set,
140                         (bb, pos) -> bb.order(NE).getFloat(pos), (bb, pos, v) -> bb.order(NE).putFloat(pos, v))
141                 },
142                 {"long", Accessor.ofSegment(ValueLayout.JAVA_LONG, 42L,
143                         MemorySegment::get, MemorySegment::set,
144                         (bb, pos) -> bb.order(NE).getLong(pos), (bb, pos, v) -> bb.order(NE).putLong(pos, v))
145                 },
146                 {"double", Accessor.ofSegment(ValueLayout.JAVA_DOUBLE, 42d,
147                         MemorySegment::get, MemorySegment::set,
148                         (bb, pos) -> bb.order(NE).getDouble(pos), (bb, pos, v) -> bb.order(NE).putDouble(pos, v))
149                 },
150                 { "address", Accessor.ofSegment(ValueLayout.ADDRESS, MemoryAddress.ofLong(42),
151                         MemorySegment::get, MemorySegment::set,
152                         (bb, pos) -> {
153                             ByteBuffer nb = bb.order(NE);
154                             long addr = ValueLayout.ADDRESS.byteSize() == 8 ?
155                                     nb.getLong(pos) : nb.getInt(pos);
156                             return MemoryAddress.ofLong(addr);
157                         },
158                         (bb, pos, v) -> {
159                             ByteBuffer nb = bb.order(NE);
160                             if (ValueLayout.ADDRESS.byteSize() == 8) {
161                                 nb.putLong(pos, v.toRawLongValue());
162                             } else {
163                                 nb.putInt(pos, (int)v.toRawLongValue());
164                             }
165                         })
166                 },
167 
168                 {"char/index", Accessor.ofSegment(ValueLayout.JAVA_CHAR, (char) 42,
169                         MemorySegment::getAtIndex, MemorySegment::setAtIndex,
170                         (bb, pos) -> bb.order(NE).getChar(pos * 2), (bb, pos, v) -> bb.order(NE).putChar(pos * 2, v))
171                 },
172                 {"int/index", Accessor.ofSegment(ValueLayout.JAVA_INT, 42,
173                         MemorySegment::getAtIndex, MemorySegment::setAtIndex,
174                         (bb, pos) -> bb.order(NE).getInt(pos * 4), (bb, pos, v) -> bb.order(NE).putInt(pos * 4, v))
175                 },
176                 {"float/index", Accessor.ofSegment(ValueLayout.JAVA_FLOAT, 42f,
177                         MemorySegment::getAtIndex, MemorySegment::setAtIndex,
178                         (bb, pos) -> bb.order(NE).getFloat(pos * 4), (bb, pos, v) -> bb.order(NE).putFloat(pos * 4, v))
179                 },
180                 {"long/index", Accessor.ofSegment(ValueLayout.JAVA_LONG, 42L,
181                         MemorySegment::getAtIndex, MemorySegment::setAtIndex,
182                         (bb, pos) -> bb.order(NE).getLong(pos * 8), (bb, pos, v) -> bb.order(NE).putLong(pos * 8, v))
183                 },
184                 {"double/index", Accessor.ofSegment(ValueLayout.JAVA_DOUBLE, 42d,
185                         MemorySegment::getAtIndex, MemorySegment::setAtIndex,
186                         (bb, pos) -> bb.order(NE).getDouble(pos * 8), (bb, pos, v) -> bb.order(NE).putDouble(pos * 8, v))
187                 },
188                 { "address/index", Accessor.ofSegment(ValueLayout.ADDRESS, MemoryAddress.ofLong(42),
189                         MemorySegment::getAtIndex, MemorySegment::setAtIndex,
190                         (bb, pos) -> {
191                             ByteBuffer nb = bb.order(NE);
192                             long addr = ValueLayout.ADDRESS.byteSize() == 8 ?
193                                     nb.getLong(pos * 8) : nb.getInt(pos * 4);
194                             return MemoryAddress.ofLong(addr);
195                         },
196                         (bb, pos, v) -> {
197                             ByteBuffer nb = bb.order(NE);
198                             if (ValueLayout.ADDRESS.byteSize() == 8) {
199                                 nb.putLong(pos * 8, v.toRawLongValue());
200                             } else {
201                                 nb.putInt(pos * 4, (int)v.toRawLongValue());
202                             }
203                         })
204                 },
205         };
206     }
207 
208     @DataProvider(name = "addressAccessors")
209     static Object[][] addressAccessors() {
210         return new Object[][]{
211 
212                 {"byte", Accessor.ofAddress(ValueLayout.JAVA_BYTE, (byte) 42,
213                         MemoryAddress::get, MemoryAddress::set,
214                         ByteBuffer::get, ByteBuffer::put)
215                 },
216                 {"bool", Accessor.ofAddress(ValueLayout.JAVA_BOOLEAN, false,
217                         MemoryAddress::get, MemoryAddress::set,
218                         (bb, pos) -> bb.get(pos) != 0, (bb, pos, v) -> bb.put(pos, v ? (byte)1 : (byte)0))
219                 },
220                 {"char", Accessor.ofAddress(ValueLayout.JAVA_CHAR, (char) 42,
221                         MemoryAddress::get, MemoryAddress::set,
222                         (bb, pos) -> bb.order(NE).getChar(pos), (bb, pos, v) -> bb.order(NE).putChar(pos, v))
223                 },
224                 {"int", Accessor.ofAddress(ValueLayout.JAVA_INT, 42,
225                         MemoryAddress::get, MemoryAddress::set,
226                         (bb, pos) -> bb.order(NE).getInt(pos), (bb, pos, v) -> bb.order(NE).putInt(pos, v))
227                 },
228                 {"float", Accessor.ofAddress(ValueLayout.JAVA_FLOAT, 42f,
229                         MemoryAddress::get, MemoryAddress::set,
230                         (bb, pos) -> bb.order(NE).getFloat(pos), (bb, pos, v) -> bb.order(NE).putFloat(pos, v))
231                 },
232                 {"long", Accessor.ofAddress(ValueLayout.JAVA_LONG, 42L,
233                         MemoryAddress::get, MemoryAddress::set,
234                         (bb, pos) -> bb.order(NE).getLong(pos), (bb, pos, v) -> bb.order(NE).putLong(pos, v))
235                 },
236                 {"double", Accessor.ofAddress(ValueLayout.JAVA_DOUBLE, 42d,
237                         MemoryAddress::get, MemoryAddress::set,
238                         (bb, pos) -> bb.order(NE).getDouble(pos), (bb, pos, v) -> bb.order(NE).putDouble(pos, v))
239                 },
240                 { "address", Accessor.ofAddress(ValueLayout.ADDRESS, MemoryAddress.ofLong(42),
241                         MemoryAddress::get, MemoryAddress::set,
242                         (bb, pos) -> {
243                             ByteBuffer nb = bb.order(NE);
244                             long addr = ValueLayout.ADDRESS.byteSize() == 8 ?
245                                     nb.getLong(pos) : nb.getInt(pos);
246                             return MemoryAddress.ofLong(addr);
247                         },
248                         (bb, pos, v) -> {
249                             ByteBuffer nb = bb.order(NE);
250                             if (ValueLayout.ADDRESS.byteSize() == 8) {
251                                 nb.putLong(pos, v.toRawLongValue());
252                             } else {
253                                 nb.putInt(pos, (int)v.toRawLongValue());
254                             }
255                         })
256                 },
257                 {"char/index", Accessor.ofAddress(ValueLayout.JAVA_CHAR, (char) 42,
258                         MemoryAddress::getAtIndex, MemoryAddress::setAtIndex,
259                         (bb, pos) -> bb.order(NE).getChar(pos * 2), (bb, pos, v) -> bb.order(NE).putChar(pos * 2, v))
260                 },
261                 {"int/index", Accessor.ofAddress(ValueLayout.JAVA_INT, 42,
262                         MemoryAddress::getAtIndex, MemoryAddress::setAtIndex,
263                         (bb, pos) -> bb.order(NE).getInt(pos * 4), (bb, pos, v) -> bb.order(NE).putInt(pos * 4, v))
264                 },
265                 {"float/index", Accessor.ofAddress(ValueLayout.JAVA_FLOAT, 42f,
266                         MemoryAddress::getAtIndex, MemoryAddress::setAtIndex,
267                         (bb, pos) -> bb.order(NE).getFloat(pos * 4), (bb, pos, v) -> bb.order(NE).putFloat(pos * 4, v))
268                 },
269                 {"long/index", Accessor.ofAddress(ValueLayout.JAVA_LONG, 42L,
270                         MemoryAddress::getAtIndex, MemoryAddress::setAtIndex,
271                         (bb, pos) -> bb.order(NE).getLong(pos * 8), (bb, pos, v) -> bb.order(NE).putLong(pos * 8, v))
272                 },
273                 {"double/index", Accessor.ofAddress(ValueLayout.JAVA_DOUBLE, 42d,
274                         MemoryAddress::getAtIndex, MemoryAddress::setAtIndex,
275                         (bb, pos) -> bb.order(NE).getDouble(pos * 8), (bb, pos, v) -> bb.order(NE).putDouble(pos * 8, v))
276                 },
277                 { "address/index", Accessor.ofAddress(ValueLayout.ADDRESS, MemoryAddress.ofLong(42),
278                         MemoryAddress::getAtIndex, MemoryAddress::setAtIndex,
279                         (bb, pos) -> {
280                             ByteBuffer nb = bb.order(NE);
281                             long addr = ValueLayout.ADDRESS.byteSize() == 8 ?
282                                     nb.getLong(pos * 8) : nb.getInt(pos * 4);
283                             return MemoryAddress.ofLong(addr);
284                         },
285                         (bb, pos, v) -> {
286                             ByteBuffer nb = bb.order(NE);
287                             if (ValueLayout.ADDRESS.byteSize() == 8) {
288                                 nb.putLong(pos * 8, v.toRawLongValue());
289                             } else {
290                                 nb.putInt(pos * 4, (int)v.toRawLongValue());
291                             }
292                         })
293                 }
294         };
295     }
296 }