1 /*
  2  * Copyright (c) 2020, 2023, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 /*
 25  * @test
 26  * @run testng TestMismatch
 27  */
 28 
 29 import java.lang.foreign.Arena;
 30 import java.util.ArrayList;
 31 import java.util.List;
 32 import java.util.concurrent.atomic.AtomicReference;
 33 
 34 import java.lang.foreign.MemorySegment;
 35 import java.lang.foreign.ValueLayout;
 36 import java.util.function.IntFunction;
 37 import java.util.stream.Stream;
 38 
 39 import org.testng.annotations.DataProvider;
 40 import org.testng.annotations.Test;
 41 import static java.lang.System.out;
 42 import static org.testng.Assert.assertEquals;
 43 import static org.testng.Assert.assertThrows;
 44 
 45 public class TestMismatch {
 46 
 47     // stores a increasing sequence of values into the memory of the given segment
 48     static MemorySegment initializeSegment(MemorySegment segment) {
 49         for (int i = 0 ; i < segment.byteSize() ; i++) {
 50             segment.set(ValueLayout.JAVA_BYTE, i, (byte)i);
 51         }
 52         return segment;
 53     }
 54 
 55     @Test(dataProvider = "slices", expectedExceptions = IndexOutOfBoundsException.class)
 56     public void testNegativeSrcFromOffset(MemorySegment s1, MemorySegment s2) {
 57         MemorySegment.mismatch(s1, -1, 0, s2, 0, 0);
 58     }
 59 
 60     @Test(dataProvider = "slices", expectedExceptions = IndexOutOfBoundsException.class)
 61     public void testNegativeDstFromOffset(MemorySegment s1, MemorySegment s2) {
 62         MemorySegment.mismatch(s1, 0, 0, s2, -1, 0);
 63     }
 64 
 65     @Test(dataProvider = "slices", expectedExceptions = IndexOutOfBoundsException.class)
 66     public void testNegativeSrcToOffset(MemorySegment s1, MemorySegment s2) {
 67         MemorySegment.mismatch(s1, 0, -1, s2, 0, 0);
 68     }
 69 
 70     @Test(dataProvider = "slices", expectedExceptions = IndexOutOfBoundsException.class)
 71     public void testNegativeDstToOffset(MemorySegment s1, MemorySegment s2) {
 72         MemorySegment.mismatch(s1, 0, 0, s2, 0, -1);
 73     }
 74 
 75     @Test(dataProvider = "slices", expectedExceptions = IndexOutOfBoundsException.class)
 76     public void testNegativeSrcLength(MemorySegment s1, MemorySegment s2) {
 77         MemorySegment.mismatch(s1, 3, 2, s2, 0, 0);
 78     }
 79 
 80     @Test(dataProvider = "slices", expectedExceptions = IndexOutOfBoundsException.class)
 81     public void testNegativeDstLength(MemorySegment s1, MemorySegment s2) {
 82         MemorySegment.mismatch(s1, 0, 0, s2, 3, 2);
 83     }
 84 
 85     @Test(dataProvider = "slices")
 86     public void testSameValues(MemorySegment ss1, MemorySegment ss2) {
 87         out.format("testSameValues s1:%s, s2:%s\n", ss1, ss2);
 88         MemorySegment s1 = initializeSegment(ss1);
 89         MemorySegment s2 = initializeSegment(ss2);
 90 
 91         if (s1.byteSize() == s2.byteSize()) {
 92             assertEquals(s1.mismatch(s2), -1);  // identical
 93             assertEquals(s2.mismatch(s1), -1);
 94         } else if (s1.byteSize() > s2.byteSize()) {
 95             assertEquals(s1.mismatch(s2), s2.byteSize());  // proper prefix
 96             assertEquals(s2.mismatch(s1), s2.byteSize());
 97         } else {
 98             assert s1.byteSize() < s2.byteSize();
 99             assertEquals(s1.mismatch(s2), s1.byteSize());  // proper prefix
100             assertEquals(s2.mismatch(s1), s1.byteSize());
101         }
102     }
103 
104     @Test(dataProvider = "slicesStatic")
105     public void testSameValuesStatic(SliceOffsetAndSize ss1, SliceOffsetAndSize ss2) {
106         out.format("testSameValuesStatic s1:%s, s2:%s\n", ss1, ss2);
107         MemorySegment s1 = initializeSegment(ss1.toSlice());
108         MemorySegment s2 = initializeSegment(ss2.toSlice());
109 
110         for (long i = ss2.offset ; i < ss2.size ; i++) {
111             long bytes = i - ss2.offset;
112             long expected = (bytes == ss1.size) ?
113                     -1 : Long.min(ss1.size, bytes);
114             assertEquals(MemorySegment.mismatch(ss1.segment, ss1.offset, ss1.endOffset(), ss2.segment, ss2.offset, i), expected);
115         }
116         for (long i = ss1.offset ; i < ss1.size ; i++) {
117             long bytes = i - ss1.offset;
118             long expected = (bytes == ss2.size) ?
119                     -1 : Long.min(ss2.size, bytes);
120             assertEquals(MemorySegment.mismatch(ss2.segment, ss2.offset, ss2.endOffset(), ss1.segment, ss1.offset, i), expected);
121         }
122     }
123 
124     @Test(dataProvider = "slices")
125     public void testDifferentValues(MemorySegment s1, MemorySegment s2) {
126         out.format("testDifferentValues s1:%s, s2:%s\n", s1, s2);
127         s1 = initializeSegment(s1);
128         s2 = initializeSegment(s2);
129 
130         for (long i = s2.byteSize() -1 ; i >= 0; i--) {
131             long expectedMismatchOffset = i;
132             s2.set(ValueLayout.JAVA_BYTE, i, (byte) 0xFF);
133 
134             if (s1.byteSize() == s2.byteSize()) {
135                 assertEquals(s1.mismatch(s2), expectedMismatchOffset);
136                 assertEquals(s2.mismatch(s1), expectedMismatchOffset);
137             } else if (s1.byteSize() > s2.byteSize()) {
138                 assertEquals(s1.mismatch(s2), expectedMismatchOffset);
139                 assertEquals(s2.mismatch(s1), expectedMismatchOffset);
140             } else {
141                 assert s1.byteSize() < s2.byteSize();
142                 var off = Math.min(s1.byteSize(), expectedMismatchOffset);
143                 assertEquals(s1.mismatch(s2), off);  // proper prefix
144                 assertEquals(s2.mismatch(s1), off);
145             }
146         }
147     }
148 
149     @Test(dataProvider = "slicesStatic")
150     public void testDifferentValuesStatic(SliceOffsetAndSize ss1, SliceOffsetAndSize ss2) {
151         out.format("testDifferentValues s1:%s, s2:%s\n", ss1, ss2);
152 
153         for (long i = ss2.size - 1 ; i >= 0; i--) {
154             if (i >= ss1.size) continue;
155             initializeSegment(ss1.toSlice());
156             initializeSegment(ss2.toSlice());
157             long expectedMismatchOffset = i;
158             ss2.toSlice().set(ValueLayout.JAVA_BYTE, i, (byte) 0xFF);
159 
160             for (long j = expectedMismatchOffset + 1 ; j < ss2.size ; j++) {
161                 assertEquals(MemorySegment.mismatch(ss1.segment, ss1.offset, ss1.endOffset(), ss2.segment, ss2.offset, j + ss2.offset), expectedMismatchOffset);
162             }
163             for (long j = expectedMismatchOffset + 1 ; j < ss1.size ; j++) {
164                 assertEquals(MemorySegment.mismatch(ss2.segment, ss2.offset, ss2.endOffset(), ss1.segment, ss1.offset, j + ss1.offset), expectedMismatchOffset);
165             }
166         }
167     }
168 
169     @Test
170     public void testEmpty() {
171         var s1 = MemorySegment.ofArray(new byte[0]);
172         assertEquals(s1.mismatch(s1), -1);
173         try (Arena arena = Arena.ofConfined()) {
174             var nativeSegment = arena.allocate(4, 4);;
175             var s2 = nativeSegment.asSlice(0, 0);
176             assertEquals(s1.mismatch(s2), -1);
177             assertEquals(s2.mismatch(s1), -1);
178         }
179     }
180 
181     @Test
182     public void testLarge() {
183         // skip if not on 64 bits
184         if (ValueLayout.ADDRESS.byteSize() > 32) {
185             try (Arena arena = Arena.ofConfined()) {
186                 var s1 = arena.allocate((long) Integer.MAX_VALUE + 10L, 8);;
187                 var s2 = arena.allocate((long) Integer.MAX_VALUE + 10L, 8);;
188                 assertEquals(s1.mismatch(s1), -1);
189                 assertEquals(s1.mismatch(s2), -1);
190                 assertEquals(s2.mismatch(s1), -1);
191 
192                 testLargeAcrossMaxBoundary(s1, s2);
193 
194                 testLargeMismatchAcrossMaxBoundary(s1, s2);
195             }
196         }
197     }
198 
199     private void testLargeAcrossMaxBoundary(MemorySegment s1, MemorySegment s2) {
200         for (long i = s2.byteSize() -1 ; i >= Integer.MAX_VALUE - 10L; i--) {
201             var s3 = s1.asSlice(0, i);
202             var s4 = s2.asSlice(0, i);
203             // instance
204             assertEquals(s3.mismatch(s3), -1);
205             assertEquals(s3.mismatch(s4), -1);
206             assertEquals(s4.mismatch(s3), -1);
207             // static
208             assertEquals(MemorySegment.mismatch(s1, 0, s1.byteSize(), s1, 0, i), -1);
209             assertEquals(MemorySegment.mismatch(s2, 0, s1.byteSize(), s1, 0, i), -1);
210             assertEquals(MemorySegment.mismatch(s1, 0, s1.byteSize(), s2, 0, i), -1);
211         }
212     }
213 
214     private void testLargeMismatchAcrossMaxBoundary(MemorySegment s1, MemorySegment s2) {
215         for (long i = s2.byteSize() -1 ; i >= Integer.MAX_VALUE - 10L; i--) {
216             s2.set(ValueLayout.JAVA_BYTE, i, (byte) 0xFF);
217             long expectedMismatchOffset = i;
218             assertEquals(s1.mismatch(s2), expectedMismatchOffset);
219             assertEquals(s2.mismatch(s1), expectedMismatchOffset);
220         }
221     }
222 
223     static final Class<IllegalStateException> ISE = IllegalStateException.class;
224     static final Class<UnsupportedOperationException> UOE = UnsupportedOperationException.class;
225 
226     @Test
227     public void testClosed() {
228         MemorySegment s1, s2;
229         try (Arena arena = Arena.ofConfined()) {
230             s1 = arena.allocate(4, 1);
231             s2 = arena.allocate(4, 1);;
232         }
233         assertThrows(ISE, () -> s1.mismatch(s1));
234         assertThrows(ISE, () -> s1.mismatch(s2));
235         assertThrows(ISE, () -> s2.mismatch(s1));
236     }
237 
238     @Test
239     public void testThreadAccess() throws Exception {
240         try (Arena arena = Arena.ofConfined()) {
241             var segment = arena.allocate(4, 1);;
242             {
243                 AtomicReference<RuntimeException> exception = new AtomicReference<>();
244                 Runnable action = () -> {
245                     try {
246                         MemorySegment.ofArray(new byte[4]).mismatch(segment);
247                     } catch (RuntimeException e) {
248                         exception.set(e);
249                     }
250                 };
251                 Thread thread = new Thread(action);
252                 thread.start();
253                 thread.join();
254 
255                 RuntimeException e = exception.get();
256                 if (!(e instanceof WrongThreadException)) {
257                     throw e;
258                 }
259             }
260             {
261                 AtomicReference<RuntimeException> exception = new AtomicReference<>();
262                 Runnable action = () -> {
263                     try {
264                         segment.mismatch(MemorySegment.ofArray(new byte[4]));
265                     } catch (RuntimeException e) {
266                         exception.set(e);
267                     }
268                 };
269                 Thread thread = new Thread(action);
270                 thread.start();
271                 thread.join();
272 
273                 RuntimeException e = exception.get();
274                 if (!(e instanceof WrongThreadException)) {
275                     throw e;
276                 }
277             }
278         }
279     }
280 
281     enum SegmentKind {
282         NATIVE(i -> Arena.ofAuto().allocate(i, 1)),
283         ARRAY(i -> MemorySegment.ofArray(new byte[i]));
284 
285         final IntFunction<MemorySegment> segmentFactory;
286 
287         SegmentKind(IntFunction<MemorySegment> segmentFactory) {
288             this.segmentFactory = segmentFactory;
289         }
290 
291         MemorySegment makeSegment(int elems) {
292             return segmentFactory.apply(elems);
293         }
294     }
295 
296     record SliceOffsetAndSize(MemorySegment segment, long offset, long size) {
297         MemorySegment toSlice() {
298             return segment.asSlice(offset, size);
299         }
300         long endOffset() {
301             return offset + size;
302         }
303     };
304 
305     @DataProvider(name = "slicesStatic")
306     static Object[][] slicesStatic() {
307         int[] sizes = { 16, 8, 1 };
308         List<SliceOffsetAndSize> aSliceOffsetAndSizes = new ArrayList<>();
309         List<SliceOffsetAndSize> bSliceOffsetAndSizes = new ArrayList<>();
310         for (List<SliceOffsetAndSize> slices : List.of(aSliceOffsetAndSizes, bSliceOffsetAndSizes)) {
311             for (SegmentKind kind : SegmentKind.values()) {
312                 MemorySegment segment = kind.makeSegment(16);
313                 //compute all slices
314                 for (int size : sizes) {
315                     for (int index = 0 ; index < 16 ; index += size) {
316                         slices.add(new SliceOffsetAndSize(segment, index, size));
317                     }
318                 }
319             }
320         }
321         assert aSliceOffsetAndSizes.size() == bSliceOffsetAndSizes.size();
322         Object[][] sliceArray = new Object[aSliceOffsetAndSizes.size() * bSliceOffsetAndSizes.size()][];
323         for (int i = 0 ; i < aSliceOffsetAndSizes.size() ; i++) {
324             for (int j = 0 ; j < bSliceOffsetAndSizes.size() ; j++) {
325                 sliceArray[i * aSliceOffsetAndSizes.size() + j] = new Object[] { aSliceOffsetAndSizes.get(i), bSliceOffsetAndSizes.get(j) };
326             }
327         }
328         return sliceArray;
329     }
330 
331     @DataProvider(name = "slices")
332     static Object[][] slices() {
333         Object[][] slicesStatic = slicesStatic();
334         return Stream.of(slicesStatic)
335                 .map(arr -> new Object[]{
336                         ((SliceOffsetAndSize) arr[0]).toSlice(),
337                         ((SliceOffsetAndSize) arr[1]).toSlice()
338                 }).toArray(Object[][]::new);
339     }
340 }