1 /* 2 * Copyright (c) 2020, 2022, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. 8 * 9 * This code is distributed in the hope that it will be useful, but WITHOUT 10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 12 * version 2 for more details (a copy is included in the LICENSE file that 13 * accompanied this code). 14 * 15 * You should have received a copy of the GNU General Public License version 16 * 2 along with this work; if not, write to the Free Software Foundation, 17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 18 * 19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 20 * or visit www.oracle.com if you need additional information or have any 21 * questions. 22 */ 23 24 package org.openjdk.bench.java.lang.foreign; 25 26 import java.lang.foreign.*; 27 28 import sun.misc.Unsafe; 29 import org.openjdk.jmh.annotations.Benchmark; 30 import org.openjdk.jmh.annotations.BenchmarkMode; 31 import org.openjdk.jmh.annotations.Fork; 32 import org.openjdk.jmh.annotations.Measurement; 33 import org.openjdk.jmh.annotations.Mode; 34 import org.openjdk.jmh.annotations.OutputTimeUnit; 35 import org.openjdk.jmh.annotations.Setup; 36 import org.openjdk.jmh.annotations.State; 37 import org.openjdk.jmh.annotations.TearDown; 38 import org.openjdk.jmh.annotations.Warmup; 39 40 import java.util.LinkedList; 41 import java.util.List; 42 import java.util.Optional; 43 import java.util.Spliterator; 44 import java.util.concurrent.CountedCompleter; 45 import java.util.concurrent.RecursiveTask; 46 import java.util.concurrent.TimeUnit; 47 import java.util.function.Predicate; 48 import java.util.function.ToIntFunction; 49 50 @BenchmarkMode(Mode.AverageTime) 51 @Warmup(iterations = 5, time = 500, timeUnit = TimeUnit.MILLISECONDS) 52 @Measurement(iterations = 10, time = 500, timeUnit = TimeUnit.MILLISECONDS) 53 @State(org.openjdk.jmh.annotations.Scope.Thread) 54 @OutputTimeUnit(TimeUnit.MILLISECONDS) 55 @Fork(value = 3, jvmArgsAppend = "--enable-preview") 56 public class ParallelSum extends JavaLayouts { 57 58 final static int CARRIER_SIZE = 4; 59 final static int ALLOC_SIZE = CARRIER_SIZE * 1024 * 1024 * 256; 60 final static int ELEM_SIZE = ALLOC_SIZE / CARRIER_SIZE; 61 62 final static MemoryLayout ELEM_LAYOUT = ValueLayout.JAVA_INT; 63 final static int BULK_FACTOR = 512; 64 final static SequenceLayout ELEM_LAYOUT_BULK = MemoryLayout.sequenceLayout(BULK_FACTOR, ELEM_LAYOUT); 65 66 static final Unsafe unsafe = Utils.unsafe; 67 68 Arena arena; 69 MemorySegment segment; 70 long address; 71 72 @Setup 73 public void setup() { 74 address = unsafe.allocateMemory(ALLOC_SIZE); 75 for (int i = 0; i < ELEM_SIZE; i++) { 76 unsafe.putInt(address + (i * CARRIER_SIZE), i); 77 } 78 arena = Arena.ofShared(); 79 segment = arena.allocate(ALLOC_SIZE, CARRIER_SIZE); 80 for (int i = 0; i < ELEM_SIZE; i++) { 81 VH_INT.set(segment, (long) i, i); 82 } 83 } 84 85 @TearDown 86 public void tearDown() throws Throwable { 87 unsafe.freeMemory(address); 88 arena.close(); 89 } 90 91 @Benchmark 92 public int segment_serial() { 93 int res = 0; 94 for (int i = 0; i < ELEM_SIZE; i++) { 95 res += (int)VH_INT.get(segment, (long) i); 96 } 97 return res; 98 } 99 100 @Benchmark 101 public int unsafe_serial() { 102 int res = 0; 103 for (int i = 0; i < ELEM_SIZE; i++) { 104 res += unsafe.getInt(address + (i * CARRIER_SIZE)); 105 } 106 return res; 107 } 108 109 @Benchmark 110 public int segment_parallel() { 111 return new SumSegment(segment.spliterator(ELEM_LAYOUT), SEGMENT_TO_INT).invoke(); 112 } 113 114 @Benchmark 115 public int segment_parallel_bulk() { 116 return new SumSegment(segment.spliterator(ELEM_LAYOUT_BULK), SEGMENT_TO_INT_BULK).invoke(); 117 } 118 119 @Benchmark 120 public int segment_stream_parallel() { 121 return segment.elements(ELEM_LAYOUT).parallel().mapToInt(SEGMENT_TO_INT).sum(); 122 } 123 124 @Benchmark 125 public int segment_stream_parallel_bulk() { 126 return segment.elements(ELEM_LAYOUT_BULK).parallel().mapToInt(SEGMENT_TO_INT_BULK).sum(); 127 } 128 129 final static ToIntFunction<MemorySegment> SEGMENT_TO_INT = slice -> 130 (int) VH_INT.get(slice, 0L); 131 132 final static ToIntFunction<MemorySegment> SEGMENT_TO_INT_BULK = slice -> { 133 int res = 0; 134 for (int i = 0; i < BULK_FACTOR ; i++) { 135 res += (int)VH_INT.get(slice, (long) i); 136 } 137 return res; 138 }; 139 140 @Benchmark 141 public Optional<MemorySegment> segment_stream_findany_serial() { 142 return segment.elements(ELEM_LAYOUT) 143 .filter(FIND_SINGLE) 144 .findAny(); 145 } 146 147 @Benchmark 148 public Optional<MemorySegment> segment_stream_findany_parallel() { 149 return segment.elements(ELEM_LAYOUT).parallel() 150 .filter(FIND_SINGLE) 151 .findAny(); 152 } 153 154 @Benchmark 155 public Optional<MemorySegment> segment_stream_findany_serial_bulk() { 156 return segment.elements(ELEM_LAYOUT_BULK) 157 .filter(FIND_BULK) 158 .findAny(); 159 } 160 161 @Benchmark 162 public Optional<MemorySegment> segment_stream_findany_parallel_bulk() { 163 return segment.elements(ELEM_LAYOUT_BULK).parallel() 164 .filter(FIND_BULK) 165 .findAny(); 166 } 167 168 final static Predicate<MemorySegment> FIND_SINGLE = slice -> 169 (int)VH_INT.get(slice, 0L) == (ELEM_SIZE - 1); 170 171 final static Predicate<MemorySegment> FIND_BULK = slice -> { 172 for (int i = 0; i < BULK_FACTOR ; i++) { 173 if ((int)VH_INT.get(slice, (long)i) == (ELEM_SIZE - 1)) { 174 return true; 175 } 176 } 177 return false; 178 }; 179 180 @Benchmark 181 public int unsafe_parallel() { 182 return new SumUnsafe(address, 0, ALLOC_SIZE / CARRIER_SIZE).invoke(); 183 } 184 185 static class SumUnsafe extends RecursiveTask<Integer> { 186 187 final static int SPLIT_THRESHOLD = 4 * 1024 * 8; 188 189 private final long address; 190 private final int start, length; 191 192 SumUnsafe(long address, int start, int length) { 193 this.address = address; 194 this.start = start; 195 this.length = length; 196 } 197 198 @Override 199 protected Integer compute() { 200 if (length > SPLIT_THRESHOLD) { 201 int rem = length % 2; 202 int split = length / 2; 203 int lobound = split; 204 int hibound = lobound + rem; 205 SumUnsafe s1 = new SumUnsafe(address, start, lobound); 206 SumUnsafe s2 = new SumUnsafe(address, start + lobound, hibound); 207 s1.fork(); 208 s2.fork(); 209 return s1.join() + s2.join(); 210 } else { 211 int res = 0; 212 for (int i = 0; i < length; i ++) { 213 res += unsafe.getInt(address + (start + i) * CARRIER_SIZE); 214 } 215 return res; 216 } 217 } 218 } 219 220 static class SumSegment extends CountedCompleter<Integer> { 221 222 final static int SPLIT_THRESHOLD = 1024 * 8; 223 224 int localSum = 0; 225 private final ToIntFunction<MemorySegment> mapper; 226 List<SumSegment> children = new LinkedList<>(); 227 228 private Spliterator<MemorySegment> segmentSplitter; 229 230 SumSegment(Spliterator<MemorySegment> segmentSplitter, ToIntFunction<MemorySegment> mapper) { 231 this(null, segmentSplitter, mapper); 232 } 233 234 SumSegment(SumSegment parent, Spliterator<MemorySegment> segmentSplitter, ToIntFunction<MemorySegment> mapper) { 235 super(parent); 236 this.segmentSplitter = segmentSplitter; 237 this.mapper = mapper; 238 } 239 240 @Override 241 public void compute() { 242 Spliterator<MemorySegment> sub; 243 while (segmentSplitter.estimateSize() > SPLIT_THRESHOLD && 244 (sub = segmentSplitter.trySplit()) != null) { 245 addToPendingCount(1); 246 SumSegment child = new SumSegment(this, sub, mapper); 247 children.add(child); 248 child.fork(); 249 } 250 segmentSplitter.forEachRemaining(s -> { 251 localSum += mapper.applyAsInt(s); 252 }); 253 propagateCompletion(); 254 } 255 256 @Override 257 public Integer getRawResult() { 258 int sum = localSum; 259 for (SumSegment c : children) { 260 sum += c.getRawResult(); 261 } 262 children = null; 263 return sum; 264 } 265 } 266 }