1 /* 2 * Copyright (c) 2020, 2023, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. 8 * 9 * This code is distributed in the hope that it will be useful, but WITHOUT 10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 12 * version 2 for more details (a copy is included in the LICENSE file that 13 * accompanied this code). 14 * 15 * You should have received a copy of the GNU General Public License version 16 * 2 along with this work; if not, write to the Free Software Foundation, 17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 18 * 19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 20 * or visit www.oracle.com if you need additional information or have any 21 * questions. 22 */ 23 24 /* 25 * @test 26 * @enablePreview 27 * @run testng TestSpliterator 28 */ 29 30 import java.lang.foreign.*; 31 32 import java.lang.invoke.VarHandle; 33 import java.util.LinkedList; 34 import java.util.List; 35 import java.util.Spliterator; 36 import java.util.concurrent.CountedCompleter; 37 import java.util.concurrent.RecursiveTask; 38 import java.util.concurrent.atomic.AtomicLong; 39 import java.util.stream.LongStream; 40 41 import org.testng.annotations.*; 42 43 import static org.testng.Assert.*; 44 45 public class TestSpliterator { 46 47 static final VarHandle INT_HANDLE = ValueLayout.JAVA_INT.arrayElementVarHandle(); 48 49 final static int CARRIER_SIZE = 4; 50 51 @Test(dataProvider = "splits") 52 public void testSum(int size, int threshold) { 53 SequenceLayout layout = MemoryLayout.sequenceLayout(size, ValueLayout.JAVA_INT); 54 55 //setup 56 try (Arena arena = Arena.ofShared()) { 57 MemorySegment segment = arena.allocate(layout);; 58 for (int i = 0; i < layout.elementCount(); i++) { 59 INT_HANDLE.set(segment, (long) i, i); 60 } 61 long expected = LongStream.range(0, layout.elementCount()).sum(); 62 //serial 63 long serial = sum(0, segment); 64 assertEquals(serial, expected); 65 //parallel counted completer 66 long parallelCounted = new SumSegmentCounted(null, segment.spliterator(layout.elementLayout()), threshold).invoke(); 67 assertEquals(parallelCounted, expected); 68 //parallel recursive action 69 long parallelRecursive = new SumSegmentRecursive(segment.spliterator(layout.elementLayout()), threshold).invoke(); 70 assertEquals(parallelRecursive, expected); 71 //parallel stream 72 long streamParallel = segment.elements(layout.elementLayout()).parallel() 73 .reduce(0L, TestSpliterator::sumSingle, Long::sum); 74 assertEquals(streamParallel, expected); 75 } 76 } 77 78 @Test 79 public void testSumSameThread() { 80 SequenceLayout layout = MemoryLayout.sequenceLayout(1024, ValueLayout.JAVA_INT); 81 82 //setup 83 Arena scope = Arena.ofAuto(); 84 MemorySegment segment = scope.allocate(layout); 85 for (int i = 0; i < layout.elementCount(); i++) { 86 INT_HANDLE.set(segment, (long) i, i); 87 } 88 long expected = LongStream.range(0, layout.elementCount()).sum(); 89 90 //check that a segment w/o ACQUIRE access mode can still be used from same thread 91 AtomicLong spliteratorSum = new AtomicLong(); 92 segment.spliterator(layout.elementLayout()) 93 .forEachRemaining(s -> spliteratorSum.addAndGet(sumSingle(0L, s))); 94 assertEquals(spliteratorSum.get(), expected); 95 } 96 97 @Test(expectedExceptions = IllegalArgumentException.class) 98 public void testBadSpliteratorElementSizeTooBig() { 99 Arena scope = Arena.ofAuto(); 100 scope.allocate(2, 1) 101 .spliterator(ValueLayout.JAVA_INT); 102 } 103 104 @Test(expectedExceptions = IllegalArgumentException.class) 105 public void testBadStreamElementSizeTooBig() { 106 Arena scope = Arena.ofAuto(); 107 scope.allocate(2, 1) 108 .elements(ValueLayout.JAVA_INT); 109 } 110 111 @Test(expectedExceptions = IllegalArgumentException.class) 112 public void testBadSpliteratorElementSizeNotMultiple() { 113 Arena scope = Arena.ofAuto(); 114 scope.allocate(7, 1) 115 .spliterator(ValueLayout.JAVA_INT); 116 } 117 118 @Test(expectedExceptions = IllegalArgumentException.class) 119 public void testBadStreamElementSizeNotMultiple() { 120 Arena scope = Arena.ofAuto(); 121 scope.allocate(7, 1) 122 .elements(ValueLayout.JAVA_INT); 123 } 124 125 @Test 126 public void testSpliteratorElementSizeMultipleButNotPowerOfTwo() { 127 Arena scope = Arena.ofAuto(); 128 scope.allocate(12, 1) 129 .spliterator(ValueLayout.JAVA_INT); 130 } 131 132 @Test 133 public void testStreamElementSizeMultipleButNotPowerOfTwo() { 134 Arena scope = Arena.ofAuto(); 135 scope.allocate(12, 1) 136 .elements(ValueLayout.JAVA_INT); 137 } 138 139 @Test(expectedExceptions = IllegalArgumentException.class) 140 public void testBadSpliteratorElementSizeZero() { 141 Arena scope = Arena.ofAuto(); 142 scope.allocate(7, 1) 143 .spliterator(MemoryLayout.sequenceLayout(0, ValueLayout.JAVA_INT)); 144 } 145 146 @Test(expectedExceptions = IllegalArgumentException.class) 147 public void testBadStreamElementSizeZero() { 148 Arena scope = Arena.ofAuto(); 149 scope.allocate(7, 1) 150 .elements(MemoryLayout.sequenceLayout(0, ValueLayout.JAVA_INT)); 151 } 152 153 @Test(expectedExceptions = IllegalArgumentException.class) 154 public void testHyperAligned() { 155 Arena scope = Arena.ofAuto(); 156 MemorySegment segment = scope.allocate(8, 1); 157 // compute an alignment constraint (in bytes) which exceed that of the native segment 158 long bigByteAlign = Long.lowestOneBit(segment.address()) << 1; 159 segment.elements(MemoryLayout.sequenceLayout(2, ValueLayout.JAVA_INT.withByteAlignment(bigByteAlign))); 160 } 161 162 static long sumSingle(long acc, MemorySegment segment) { 163 return acc + (int)INT_HANDLE.get(segment, 0L); 164 } 165 166 static long sum(long start, MemorySegment segment) { 167 long sum = start; 168 int length = (int)segment.byteSize(); 169 for (int i = 0 ; i < length / CARRIER_SIZE ; i++) { 170 sum += (int)INT_HANDLE.get(segment, (long)i); 171 } 172 return sum; 173 } 174 175 static class SumSegmentCounted extends CountedCompleter<Long> { 176 177 final long threshold; 178 long localSum = 0; 179 List<SumSegmentCounted> children = new LinkedList<>(); 180 181 private Spliterator<MemorySegment> segmentSplitter; 182 183 SumSegmentCounted(SumSegmentCounted parent, Spliterator<MemorySegment> segmentSplitter, long threshold) { 184 super(parent); 185 this.segmentSplitter = segmentSplitter; 186 this.threshold = threshold; 187 } 188 189 @Override 190 public void compute() { 191 Spliterator<MemorySegment> sub; 192 while (segmentSplitter.estimateSize() > threshold && 193 (sub = segmentSplitter.trySplit()) != null) { 194 addToPendingCount(1); 195 SumSegmentCounted child = new SumSegmentCounted(this, sub, threshold); 196 children.add(child); 197 child.fork(); 198 } 199 segmentSplitter.forEachRemaining(slice -> { 200 localSum += sumSingle(0, slice); 201 }); 202 tryComplete(); 203 } 204 205 @Override 206 public Long getRawResult() { 207 long sum = localSum; 208 for (SumSegmentCounted c : children) { 209 sum += c.getRawResult(); 210 } 211 return sum; 212 } 213 } 214 215 static class SumSegmentRecursive extends RecursiveTask<Long> { 216 217 final long threshold; 218 private final Spliterator<MemorySegment> splitter; 219 private long result; 220 221 SumSegmentRecursive(Spliterator<MemorySegment> splitter, long threshold) { 222 this.splitter = splitter; 223 this.threshold = threshold; 224 } 225 226 @Override 227 protected Long compute() { 228 if (splitter.estimateSize() > threshold) { 229 SumSegmentRecursive sub = new SumSegmentRecursive(splitter.trySplit(), threshold); 230 sub.fork(); 231 return compute() + sub.join(); 232 } else { 233 splitter.forEachRemaining(slice -> { 234 result += sumSingle(0, slice); 235 }); 236 return result; 237 } 238 } 239 } 240 241 @DataProvider(name = "splits") 242 public Object[][] splits() { 243 return new Object[][] { 244 { 10, 1 }, 245 { 100, 1 }, 246 { 1000, 1 }, 247 { 10000, 1 }, 248 { 10, 10 }, 249 { 100, 10 }, 250 { 1000, 10 }, 251 { 10000, 10 }, 252 { 10, 100 }, 253 { 100, 100 }, 254 { 1000, 100 }, 255 { 10000, 100 }, 256 { 10, 1000 }, 257 { 100, 1000 }, 258 { 1000, 1000 }, 259 { 10000, 1000 }, 260 { 10, 10000 }, 261 { 100, 10000 }, 262 { 1000, 10000 }, 263 { 10000, 10000 }, 264 }; 265 } 266 }