1 /*
  2  * Copyright (c) 2018, 2021, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package org.openjdk.bench.jdk.incubator.vector.operation;
 26 
 27 import jdk.incubator.vector.*;
 28 
 29 import org.openjdk.jmh.annotations.*;
 30 
 31 import java.util.concurrent.TimeUnit;
 32 
 33 
 34 // Inspired by "SIMDized sum of all bytes in the array"
 35 //   http://0x80.pl/notesen/2018-10-24-sse-sumbytes.html
 36 //
 37 // C/C++ equivalent: https://github.com/WojciechMula/toys/tree/master/sse-sumbytes
 38 //
 39 @BenchmarkMode(Mode.Throughput)
 40 @Warmup(iterations = 3, time = 1)
 41 @Measurement(iterations = 5, time = 1)
 42 @OutputTimeUnit(TimeUnit.MILLISECONDS)
 43 @State(Scope.Benchmark)
 44 @Fork(value = 1, jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
 45 public class SumOfUnsignedBytes extends AbstractVectorBenchmark {
 46 
 47     @Param({"64", "1024", "4096"})
 48     int size;
 49 
 50     private byte[] data;
 51 
 52     @Setup
 53     public void init() {
 54         size = size + size % 32; // FIXME: process tails
 55         data = fillByte(size, i -> (byte)(int)i);
 56 
 57         int sum = scalar();
 58         assert vectorInt() == sum;
 59         assert vectorShort() == sum;
 60         //assert vectorByte() == sum;
 61         //assert vectorSAD() == sum;
 62     }
 63 
 64     @Benchmark
 65     public int scalar() {
 66         int sum = 0;
 67         for (int i = 0; i < data.length; i++) {
 68             sum += data[i] & 0xFF;
 69         }
 70         return sum;
 71     }
 72 
 73     // 1. 32-bit accumulators
 74     @Benchmark
 75     public int vectorInt() {
 76         final var lobyte_mask = IntVector.broadcast(I256, 0x000000FF);
 77 
 78         var acc = IntVector.zero(I256);
 79         for (int i = 0; i < data.length; i += B256.length()) {
 80             var vb = ByteVector.fromArray(B256, data, i);
 81             var vi = (IntVector)vb.reinterpretAsInts();
 82             for (int j = 0; j < 4; j++) {
 83                 var tj = vi.lanewise(VectorOperators.LSHR, j * 8).and(lobyte_mask);
 84                 acc = acc.add(tj);
 85             }
 86         }
 87         return (int)Integer.toUnsignedLong(acc.reduceLanes(VectorOperators.ADD));
 88     }
 89 
 90     // 2. 16-bit accumulators
 91     @Benchmark
 92     public int vectorShort() {
 93         final var lobyte_mask = ShortVector.broadcast(S256, (short) 0x00FF);
 94 
 95         // FIXME: overflow
 96         var acc = ShortVector.zero(S256);
 97         for (int i = 0; i < data.length; i += B256.length()) {
 98             var vb = ByteVector.fromArray(B256, data, i);
 99             var vs = (ShortVector)vb.reinterpretAsShorts();
100             for (int j = 0; j < 2; j++) {
101                 var tj = vs.lanewise(VectorOperators.LSHR, j * 8).and(lobyte_mask);
102                 acc = acc.add(tj);
103             }
104         }
105 
106         int mid = S128.length();
107         var accLo = ((IntVector)(acc.reinterpretShape(S128, 0).castShape(I256, 0))).and(0xFFFF); // low half as ints
108         var accHi = ((IntVector)(acc.reinterpretShape(S128, 1).castShape(I256, 0))).and(0xFFFF); // high half as ints
109         return accLo.reduceLanes(VectorOperators.ADD) + accHi.reduceLanes(VectorOperators.ADD);
110     }
111 
112     /*
113     // 3. 8-bit halves (MISSING: _mm_adds_epu8)
114     @Benchmark
115     public int vectorByte() {
116         int window = 256;
117         var acc_hi  = IntVector.zero(I256);
118         var acc8_lo = ByteVector.zero(B256);
119         for (int i = 0; i < data.length; i += window) {
120             var acc8_hi = ByteVector.zero(B256);
121             int limit = Math.min(window, data.length - i);
122             for (int j = 0; j < limit; j += B256.length()) {
123                 var vb = ByteVector.fromArray(B256, data, i + j);
124 
125                 var t0 = acc8_lo.add(vb);
126                 var t1 = addSaturated(acc8_lo, vb); // MISSING
127                 var overflow = t0.notEqual(t1);
128 
129                 acc8_lo = t0;
130                 acc8_hi = acc8_hi.add((byte) 1, overflow);
131             }
132             acc_hi = acc_hi.add(sum(acc8_hi));
133         }
134         return sum(acc8_lo)
135                 .add(acc_hi.mul(256)) // overflow
136                 .addAll();
137     }
138 
139     // 4. Sum Of Absolute Differences (SAD) (MISSING: VPSADBW, _mm256_sad_epu8)
140     public int vectorSAD() {
141         var acc = IntVector.zero(I256);
142         for (int i = 0; i < data.length; i += B256.length()) {
143             var v = ByteVector.fromArray(B256, data, i);
144             var sad = sumOfAbsoluteDifferences(v, ByteVector.zero(B256)); // MISSING
145             acc = acc.add(sad);
146         }
147         return acc.addAll();
148     } */
149 
150     // Helpers
151     /*
152     static ByteVector addSaturated(ByteVector va, ByteVector vb) {
153         var vc = ByteVector.zero(B256);
154         for (int i = 0; i < B256.length(); i++) {
155             if ((va.get(i) & 0xFF) + (vb.get(i) & 0xFF) < 0xFF) {
156                 vc = vc.withLane(i, (byte)(va.get(i) + vb.get(i)));
157             } else {
158                 vc = vc.withLane(i, (byte)0xFF);
159             }
160         }
161         return vc;
162     }
163     IntVector sumOfAbsoluteDifferences(ByteVector va, ByteVector vb) {
164         var vc = ByteVector.zero(B256);
165         for (int i = 0; i < B256.length(); i++) {
166             if ((va.get(i) & 0xFF) > (vb.get(i) & 0xFF)) {
167                 vc = vc.withLane(i, (byte)(va.get(i) - vb.get(i)));
168             } else {
169                 vc = vc.withLane(i, (byte)(vb.get(i) - va.get(i)));
170             }
171         }
172         return sum(vc);
173     } */
174 }