1 /*
  2  * Copyright (c) 2020, 2023, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 /*
 25  * @test
 26  * @enablePreview
 27  * @run testng TestMismatch
 28  */
 29 
 30 import java.lang.foreign.Arena;
 31 import java.util.ArrayList;
 32 import java.util.List;
 33 import java.util.concurrent.atomic.AtomicReference;
 34 
 35 import java.lang.foreign.MemorySegment;
 36 import java.lang.foreign.ValueLayout;
 37 import java.util.function.IntFunction;
 38 import java.util.stream.Stream;
 39 
 40 import org.testng.annotations.DataProvider;
 41 import org.testng.annotations.Test;
 42 import static java.lang.System.out;
 43 import static org.testng.Assert.assertEquals;
 44 import static org.testng.Assert.assertThrows;
 45 
 46 public class TestMismatch {
 47 
 48     // stores a increasing sequence of values into the memory of the given segment
 49     static MemorySegment initializeSegment(MemorySegment segment) {
 50         for (int i = 0 ; i < segment.byteSize() ; i++) {
 51             segment.set(ValueLayout.JAVA_BYTE, i, (byte)i);
 52         }
 53         return segment;
 54     }
 55 
 56     @Test(dataProvider = "slices", expectedExceptions = IndexOutOfBoundsException.class)
 57     public void testNegativeSrcFromOffset(MemorySegment s1, MemorySegment s2) {
 58         MemorySegment.mismatch(s1, -1, 0, s2, 0, 0);
 59     }
 60 
 61     @Test(dataProvider = "slices", expectedExceptions = IndexOutOfBoundsException.class)
 62     public void testNegativeDstFromOffset(MemorySegment s1, MemorySegment s2) {
 63         MemorySegment.mismatch(s1, 0, 0, s2, -1, 0);
 64     }
 65 
 66     @Test(dataProvider = "slices", expectedExceptions = IndexOutOfBoundsException.class)
 67     public void testNegativeSrcToOffset(MemorySegment s1, MemorySegment s2) {
 68         MemorySegment.mismatch(s1, 0, -1, s2, 0, 0);
 69     }
 70 
 71     @Test(dataProvider = "slices", expectedExceptions = IndexOutOfBoundsException.class)
 72     public void testNegativeDstToOffset(MemorySegment s1, MemorySegment s2) {
 73         MemorySegment.mismatch(s1, 0, 0, s2, 0, -1);
 74     }
 75 
 76     @Test(dataProvider = "slices", expectedExceptions = IndexOutOfBoundsException.class)
 77     public void testNegativeSrcLength(MemorySegment s1, MemorySegment s2) {
 78         MemorySegment.mismatch(s1, 3, 2, s2, 0, 0);
 79     }
 80 
 81     @Test(dataProvider = "slices", expectedExceptions = IndexOutOfBoundsException.class)
 82     public void testNegativeDstLength(MemorySegment s1, MemorySegment s2) {
 83         MemorySegment.mismatch(s1, 0, 0, s2, 3, 2);
 84     }
 85 
 86     @Test(dataProvider = "slices")
 87     public void testSameValues(MemorySegment ss1, MemorySegment ss2) {
 88         out.format("testSameValues s1:%s, s2:%s\n", ss1, ss2);
 89         MemorySegment s1 = initializeSegment(ss1);
 90         MemorySegment s2 = initializeSegment(ss2);
 91 
 92         if (s1.byteSize() == s2.byteSize()) {
 93             assertEquals(s1.mismatch(s2), -1);  // identical
 94             assertEquals(s2.mismatch(s1), -1);
 95         } else if (s1.byteSize() > s2.byteSize()) {
 96             assertEquals(s1.mismatch(s2), s2.byteSize());  // proper prefix
 97             assertEquals(s2.mismatch(s1), s2.byteSize());
 98         } else {
 99             assert s1.byteSize() < s2.byteSize();
100             assertEquals(s1.mismatch(s2), s1.byteSize());  // proper prefix
101             assertEquals(s2.mismatch(s1), s1.byteSize());
102         }
103     }
104 
105     @Test(dataProvider = "slicesStatic")
106     public void testSameValuesStatic(SliceOffsetAndSize ss1, SliceOffsetAndSize ss2) {
107         out.format("testSameValuesStatic s1:%s, s2:%s\n", ss1, ss2);
108         MemorySegment s1 = initializeSegment(ss1.toSlice());
109         MemorySegment s2 = initializeSegment(ss2.toSlice());
110 
111         for (long i = ss2.offset ; i < ss2.size ; i++) {
112             long bytes = i - ss2.offset;
113             long expected = (bytes == ss1.size) ?
114                     -1 : Long.min(ss1.size, bytes);
115             assertEquals(MemorySegment.mismatch(ss1.segment, ss1.offset, ss1.endOffset(), ss2.segment, ss2.offset, i), expected);
116         }
117         for (long i = ss1.offset ; i < ss1.size ; i++) {
118             long bytes = i - ss1.offset;
119             long expected = (bytes == ss2.size) ?
120                     -1 : Long.min(ss2.size, bytes);
121             assertEquals(MemorySegment.mismatch(ss2.segment, ss2.offset, ss2.endOffset(), ss1.segment, ss1.offset, i), expected);
122         }
123     }
124 
125     @Test(dataProvider = "slices")
126     public void testDifferentValues(MemorySegment s1, MemorySegment s2) {
127         out.format("testDifferentValues s1:%s, s2:%s\n", s1, s2);
128         s1 = initializeSegment(s1);
129         s2 = initializeSegment(s2);
130 
131         for (long i = s2.byteSize() -1 ; i >= 0; i--) {
132             long expectedMismatchOffset = i;
133             s2.set(ValueLayout.JAVA_BYTE, i, (byte) 0xFF);
134 
135             if (s1.byteSize() == s2.byteSize()) {
136                 assertEquals(s1.mismatch(s2), expectedMismatchOffset);
137                 assertEquals(s2.mismatch(s1), expectedMismatchOffset);
138             } else if (s1.byteSize() > s2.byteSize()) {
139                 assertEquals(s1.mismatch(s2), expectedMismatchOffset);
140                 assertEquals(s2.mismatch(s1), expectedMismatchOffset);
141             } else {
142                 assert s1.byteSize() < s2.byteSize();
143                 var off = Math.min(s1.byteSize(), expectedMismatchOffset);
144                 assertEquals(s1.mismatch(s2), off);  // proper prefix
145                 assertEquals(s2.mismatch(s1), off);
146             }
147         }
148     }
149 
150     @Test(dataProvider = "slicesStatic")
151     public void testDifferentValuesStatic(SliceOffsetAndSize ss1, SliceOffsetAndSize ss2) {
152         out.format("testDifferentValues s1:%s, s2:%s\n", ss1, ss2);
153 
154         for (long i = ss2.size - 1 ; i >= 0; i--) {
155             if (i >= ss1.size) continue;
156             initializeSegment(ss1.toSlice());
157             initializeSegment(ss2.toSlice());
158             long expectedMismatchOffset = i;
159             ss2.toSlice().set(ValueLayout.JAVA_BYTE, i, (byte) 0xFF);
160 
161             for (long j = expectedMismatchOffset + 1 ; j < ss2.size ; j++) {
162                 assertEquals(MemorySegment.mismatch(ss1.segment, ss1.offset, ss1.endOffset(), ss2.segment, ss2.offset, j + ss2.offset), expectedMismatchOffset);
163             }
164             for (long j = expectedMismatchOffset + 1 ; j < ss1.size ; j++) {
165                 assertEquals(MemorySegment.mismatch(ss2.segment, ss2.offset, ss2.endOffset(), ss1.segment, ss1.offset, j + ss1.offset), expectedMismatchOffset);
166             }
167         }
168     }
169 
170     @Test
171     public void testEmpty() {
172         var s1 = MemorySegment.ofArray(new byte[0]);
173         assertEquals(s1.mismatch(s1), -1);
174         try (Arena arena = Arena.ofConfined()) {
175             var nativeSegment = arena.allocate(4, 4);;
176             var s2 = nativeSegment.asSlice(0, 0);
177             assertEquals(s1.mismatch(s2), -1);
178             assertEquals(s2.mismatch(s1), -1);
179         }
180     }
181 
182     @Test
183     public void testLarge() {
184         // skip if not on 64 bits
185         if (ValueLayout.ADDRESS.byteSize() > 32) {
186             try (Arena arena = Arena.ofConfined()) {
187                 var s1 = arena.allocate((long) Integer.MAX_VALUE + 10L, 8);;
188                 var s2 = arena.allocate((long) Integer.MAX_VALUE + 10L, 8);;
189                 assertEquals(s1.mismatch(s1), -1);
190                 assertEquals(s1.mismatch(s2), -1);
191                 assertEquals(s2.mismatch(s1), -1);
192 
193                 testLargeAcrossMaxBoundary(s1, s2);
194 
195                 testLargeMismatchAcrossMaxBoundary(s1, s2);
196             }
197         }
198     }
199 
200     private void testLargeAcrossMaxBoundary(MemorySegment s1, MemorySegment s2) {
201         for (long i = s2.byteSize() -1 ; i >= Integer.MAX_VALUE - 10L; i--) {
202             var s3 = s1.asSlice(0, i);
203             var s4 = s2.asSlice(0, i);
204             // instance
205             assertEquals(s3.mismatch(s3), -1);
206             assertEquals(s3.mismatch(s4), -1);
207             assertEquals(s4.mismatch(s3), -1);
208             // static
209             assertEquals(MemorySegment.mismatch(s1, 0, s1.byteSize(), s1, 0, i), -1);
210             assertEquals(MemorySegment.mismatch(s2, 0, s1.byteSize(), s1, 0, i), -1);
211             assertEquals(MemorySegment.mismatch(s1, 0, s1.byteSize(), s2, 0, i), -1);
212         }
213     }
214 
215     private void testLargeMismatchAcrossMaxBoundary(MemorySegment s1, MemorySegment s2) {
216         for (long i = s2.byteSize() -1 ; i >= Integer.MAX_VALUE - 10L; i--) {
217             s2.set(ValueLayout.JAVA_BYTE, i, (byte) 0xFF);
218             long expectedMismatchOffset = i;
219             assertEquals(s1.mismatch(s2), expectedMismatchOffset);
220             assertEquals(s2.mismatch(s1), expectedMismatchOffset);
221         }
222     }
223 
224     static final Class<IllegalStateException> ISE = IllegalStateException.class;
225     static final Class<UnsupportedOperationException> UOE = UnsupportedOperationException.class;
226 
227     @Test
228     public void testClosed() {
229         MemorySegment s1, s2;
230         try (Arena arena = Arena.ofConfined()) {
231             s1 = arena.allocate(4, 1);
232             s2 = arena.allocate(4, 1);;
233         }
234         assertThrows(ISE, () -> s1.mismatch(s1));
235         assertThrows(ISE, () -> s1.mismatch(s2));
236         assertThrows(ISE, () -> s2.mismatch(s1));
237     }
238 
239     @Test
240     public void testThreadAccess() throws Exception {
241         try (Arena arena = Arena.ofConfined()) {
242             var segment = arena.allocate(4, 1);;
243             {
244                 AtomicReference<RuntimeException> exception = new AtomicReference<>();
245                 Runnable action = () -> {
246                     try {
247                         MemorySegment.ofArray(new byte[4]).mismatch(segment);
248                     } catch (RuntimeException e) {
249                         exception.set(e);
250                     }
251                 };
252                 Thread thread = new Thread(action);
253                 thread.start();
254                 thread.join();
255 
256                 RuntimeException e = exception.get();
257                 if (!(e instanceof WrongThreadException)) {
258                     throw e;
259                 }
260             }
261             {
262                 AtomicReference<RuntimeException> exception = new AtomicReference<>();
263                 Runnable action = () -> {
264                     try {
265                         segment.mismatch(MemorySegment.ofArray(new byte[4]));
266                     } catch (RuntimeException e) {
267                         exception.set(e);
268                     }
269                 };
270                 Thread thread = new Thread(action);
271                 thread.start();
272                 thread.join();
273 
274                 RuntimeException e = exception.get();
275                 if (!(e instanceof WrongThreadException)) {
276                     throw e;
277                 }
278             }
279         }
280     }
281 
282     enum SegmentKind {
283         NATIVE(i -> Arena.ofAuto().allocate(i, 1)),
284         ARRAY(i -> MemorySegment.ofArray(new byte[i]));
285 
286         final IntFunction<MemorySegment> segmentFactory;
287 
288         SegmentKind(IntFunction<MemorySegment> segmentFactory) {
289             this.segmentFactory = segmentFactory;
290         }
291 
292         MemorySegment makeSegment(int elems) {
293             return segmentFactory.apply(elems);
294         }
295     }
296 
297     record SliceOffsetAndSize(MemorySegment segment, long offset, long size) {
298         MemorySegment toSlice() {
299             return segment.asSlice(offset, size);
300         }
301         long endOffset() {
302             return offset + size;
303         }
304     };
305 
306     @DataProvider(name = "slicesStatic")
307     static Object[][] slicesStatic() {
308         int[] sizes = { 16, 8, 1 };
309         List<SliceOffsetAndSize> aSliceOffsetAndSizes = new ArrayList<>();
310         List<SliceOffsetAndSize> bSliceOffsetAndSizes = new ArrayList<>();
311         for (List<SliceOffsetAndSize> slices : List.of(aSliceOffsetAndSizes, bSliceOffsetAndSizes)) {
312             for (SegmentKind kind : SegmentKind.values()) {
313                 MemorySegment segment = kind.makeSegment(16);
314                 //compute all slices
315                 for (int size : sizes) {
316                     for (int index = 0 ; index < 16 ; index += size) {
317                         slices.add(new SliceOffsetAndSize(segment, index, size));
318                     }
319                 }
320             }
321         }
322         assert aSliceOffsetAndSizes.size() == bSliceOffsetAndSizes.size();
323         Object[][] sliceArray = new Object[aSliceOffsetAndSizes.size() * bSliceOffsetAndSizes.size()][];
324         for (int i = 0 ; i < aSliceOffsetAndSizes.size() ; i++) {
325             for (int j = 0 ; j < bSliceOffsetAndSizes.size() ; j++) {
326                 sliceArray[i * aSliceOffsetAndSizes.size() + j] = new Object[] { aSliceOffsetAndSizes.get(i), bSliceOffsetAndSizes.get(j) };
327             }
328         }
329         return sliceArray;
330     }
331 
332     @DataProvider(name = "slices")
333     static Object[][] slices() {
334         Object[][] slicesStatic = slicesStatic();
335         return Stream.of(slicesStatic)
336                 .map(arr -> new Object[]{
337                         ((SliceOffsetAndSize) arr[0]).toSlice(),
338                         ((SliceOffsetAndSize) arr[1]).toSlice()
339                 }).toArray(Object[][]::new);
340     }
341 }