1 /*
2 * Copyright (c) 2018, 2021, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25 package org.openjdk.bench.jdk.incubator.vector.operation;
26
27 import jdk.incubator.vector.*;
28
29 import org.openjdk.jmh.annotations.*;
30
31 import java.util.concurrent.TimeUnit;
32
33
34 // Inspired by "SIMDized sum of all bytes in the array"
35 // http://0x80.pl/notesen/2018-10-24-sse-sumbytes.html
36 //
37 // C/C++ equivalent: https://github.com/WojciechMula/toys/tree/master/sse-sumbytes
38 //
39 @BenchmarkMode(Mode.Throughput)
40 @Warmup(iterations = 3, time = 1)
41 @Measurement(iterations = 5, time = 1)
42 @OutputTimeUnit(TimeUnit.MILLISECONDS)
43 @State(Scope.Benchmark)
44 @Fork(value = 1, jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
45 public class SumOfUnsignedBytes extends AbstractVectorBenchmark {
46
47 @Param({"64", "1024", "4096"})
48 int size;
49
50 private byte[] data;
51
52 @Setup
53 public void init() {
54 size = size + size % 32; // FIXME: process tails
55 data = fillByte(size, i -> (byte)(int)i);
56
57 int sum = scalar();
58 assert vectorInt() == sum;
59 assert vectorShort() == sum;
60 //assert vectorByte() == sum;
61 //assert vectorSAD() == sum;
62 }
63
64 @Benchmark
65 public int scalar() {
66 int sum = 0;
67 for (int i = 0; i < data.length; i++) {
68 sum += data[i] & 0xFF;
69 }
70 return sum;
71 }
72
73 // 1. 32-bit accumulators
74 @Benchmark
75 public int vectorInt() {
76 final var lobyte_mask = IntVector.broadcast(I256, 0x000000FF);
77
78 var acc = IntVector.zero(I256);
79 for (int i = 0; i < data.length; i += B256.length()) {
80 var vb = ByteVector.fromArray(B256, data, i);
81 var vi = (IntVector)vb.reinterpretAsInts();
82 for (int j = 0; j < 4; j++) {
83 var tj = vi.lanewise(VectorOperators.LSHR, j * 8).and(lobyte_mask);
84 acc = acc.add(tj);
85 }
86 }
87 return (int)Integer.toUnsignedLong(acc.reduceLanes(VectorOperators.ADD));
88 }
89
90 // 2. 16-bit accumulators
91 @Benchmark
92 public int vectorShort() {
93 final var lobyte_mask = ShortVector.broadcast(S256, (short) 0x00FF);
94
95 // FIXME: overflow
96 var acc = ShortVector.zero(S256);
97 for (int i = 0; i < data.length; i += B256.length()) {
98 var vb = ByteVector.fromArray(B256, data, i);
99 var vs = (ShortVector)vb.reinterpretAsShorts();
100 for (int j = 0; j < 2; j++) {
101 var tj = vs.lanewise(VectorOperators.LSHR, j * 8).and(lobyte_mask);
102 acc = acc.add(tj);
103 }
104 }
105
106 int mid = S128.length();
107 var accLo = ((IntVector)(acc.reinterpretShape(S128, 0).castShape(I256, 0))).and(0xFFFF); // low half as ints
108 var accHi = ((IntVector)(acc.reinterpretShape(S128, 1).castShape(I256, 0))).and(0xFFFF); // high half as ints
109 return accLo.reduceLanes(VectorOperators.ADD) + accHi.reduceLanes(VectorOperators.ADD);
110 }
111
112 /*
113 // 3. 8-bit halves (MISSING: _mm_adds_epu8)
114 @Benchmark
115 public int vectorByte() {
116 int window = 256;
117 var acc_hi = IntVector.zero(I256);
118 var acc8_lo = ByteVector.zero(B256);
119 for (int i = 0; i < data.length; i += window) {
120 var acc8_hi = ByteVector.zero(B256);
121 int limit = Math.min(window, data.length - i);
122 for (int j = 0; j < limit; j += B256.length()) {
123 var vb = ByteVector.fromArray(B256, data, i + j);
124
125 var t0 = acc8_lo.add(vb);
126 var t1 = addSaturated(acc8_lo, vb); // MISSING
127 var overflow = t0.notEqual(t1);
128
129 acc8_lo = t0;
130 acc8_hi = acc8_hi.add((byte) 1, overflow);
131 }
132 acc_hi = acc_hi.add(sum(acc8_hi));
133 }
134 return sum(acc8_lo)
135 .add(acc_hi.mul(256)) // overflow
136 .addAll();
137 }
138
139 // 4. Sum Of Absolute Differences (SAD) (MISSING: VPSADBW, _mm256_sad_epu8)
140 public int vectorSAD() {
141 var acc = IntVector.zero(I256);
142 for (int i = 0; i < data.length; i += B256.length()) {
143 var v = ByteVector.fromArray(B256, data, i);
144 var sad = sumOfAbsoluteDifferences(v, ByteVector.zero(B256)); // MISSING
145 acc = acc.add(sad);
146 }
147 return acc.addAll();
148 } */
149
150 // Helpers
151 /*
152 static ByteVector addSaturated(ByteVector va, ByteVector vb) {
153 var vc = ByteVector.zero(B256);
154 for (int i = 0; i < B256.length(); i++) {
155 if ((va.get(i) & 0xFF) + (vb.get(i) & 0xFF) < 0xFF) {
156 vc = vc.withLane(i, (byte)(va.get(i) + vb.get(i)));
157 } else {
158 vc = vc.withLane(i, (byte)0xFF);
159 }
160 }
161 return vc;
162 }
163 IntVector sumOfAbsoluteDifferences(ByteVector va, ByteVector vb) {
164 var vc = ByteVector.zero(B256);
165 for (int i = 0; i < B256.length(); i++) {
166 if ((va.get(i) & 0xFF) > (vb.get(i) & 0xFF)) {
167 vc = vc.withLane(i, (byte)(va.get(i) - vb.get(i)));
168 } else {
169 vc = vc.withLane(i, (byte)(vb.get(i) - va.get(i)));
170 }
171 }
172 return sum(vc);
173 } */
174 }