1 /*
  2  * Copyright (c) 2019, 2022, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 package org.openjdk.bench.jdk.incubator.vector.crypto;
 25 
 26 import org.openjdk.jmh.annotations.*;
 27 import java.lang.foreign.MemorySegment;
 28 import jdk.incubator.vector.*;
 29 
 30 import java.nio.ByteOrder;
 31 import java.util.Arrays;
 32 
 33 @State(Scope.Thread)
 34 @BenchmarkMode(Mode.Throughput)
 35 @Fork(value = 1, jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
 36 @Warmup(iterations = 3, time = 3)
 37 @Measurement(iterations = 8, time = 2)
 38 public class ChaChaBench {
 39 
 40     @Param({"16384", "65536"})
 41     private int dataSize;
 42 
 43     private ChaChaVector cc20_S128 = makeCC20(VectorShape.S_128_BIT);
 44     private ChaChaVector cc20_S256 = makeCC20(VectorShape.S_256_BIT);
 45     private ChaChaVector cc20_S512 = makeCC20(VectorShape.S_512_BIT);
 46 
 47     private MemorySegment in;
 48     private MemorySegment out;
 49 
 50     private byte[] key = new byte[32];
 51     private byte[] nonce = new byte[12];
 52     private long counter = 0;
 53 
 54     private static ChaChaVector makeCC20(VectorShape shape) {
 55         ChaChaVector cc20 = new ChaChaVector(shape);
 56         runKAT(cc20);
 57         return cc20;
 58     }
 59 
 60     @Setup
 61     public void setup() {
 62         in = MemorySegment.ofArray(new byte[dataSize]);
 63         out = MemorySegment.ofArray(new byte[dataSize]);
 64     }
 65 
 66     @Benchmark
 67     public void encrypt128() {
 68         cc20_S128.chacha20(key, nonce, counter, in, out);
 69     }
 70 
 71     @Benchmark
 72     public void encrypt256() {
 73         cc20_S256.chacha20(key, nonce, counter, in, out);
 74     }
 75 
 76     @Benchmark
 77     public void encrypt512() {
 78         cc20_S512.chacha20(key, nonce, counter, in, out);
 79     }
 80 
 81     private static class ChaChaVector {
 82 
 83         private static final int[] STATE_CONSTANTS =
 84             new int[]{0x61707865, 0x3320646e, 0x79622d32, 0x6b206574};
 85 
 86         private final VectorSpecies<Integer> intSpecies;
 87         private final int numBlocks;
 88 
 89         private final VectorShuffle<Integer> rot1;
 90         private final VectorShuffle<Integer> rot2;
 91         private final VectorShuffle<Integer> rot3;
 92 
 93         private final IntVector counterAdd;
 94 
 95         private final VectorShuffle<Integer> shuf0;
 96         private final VectorShuffle<Integer> shuf1;
 97         private final VectorShuffle<Integer> shuf2;
 98         private final VectorShuffle<Integer> shuf3;
 99 
100         private final VectorMask<Integer> mask0;
101         private final VectorMask<Integer> mask1;
102         private final VectorMask<Integer> mask2;
103         private final VectorMask<Integer> mask3;
104 
105         private final int[] state;
106 
107         public ChaChaVector(VectorShape shape) {
108             this.intSpecies = VectorSpecies.of(int.class, shape);
109             this.numBlocks = intSpecies.length() / 4;
110 
111             this.rot1 = makeRotate(1);
112             this.rot2 = makeRotate(2);
113             this.rot3 = makeRotate(3);
114 
115             this.counterAdd = makeCounterAdd();
116 
117             this.shuf0 = makeRearrangeShuffle(0);
118             this.shuf1 = makeRearrangeShuffle(1);
119             this.shuf2 = makeRearrangeShuffle(2);
120             this.shuf3 = makeRearrangeShuffle(3);
121 
122             this.mask0 = makeRearrangeMask(0);
123             this.mask1 = makeRearrangeMask(1);
124             this.mask2 = makeRearrangeMask(2);
125             this.mask3 = makeRearrangeMask(3);
126 
127             this.state = new int[numBlocks * 16];
128         }
129 
130         private VectorShuffle<Integer>  makeRotate(int amount) {
131             int[] shuffleArr = new int[intSpecies.length()];
132 
133             for (int i = 0; i < intSpecies.length(); i ++) {
134                 int offset = (i / 4) * 4;
135                 shuffleArr[i] = offset + ((i + amount) % 4);
136             }
137 
138             return VectorShuffle.fromValues(intSpecies, shuffleArr);
139         }
140 
141         private IntVector makeCounterAdd() {
142             int[] addArr = new int[intSpecies.length()];
143             for(int i = 0; i < numBlocks; i++) {
144                 addArr[4 * i] = numBlocks;
145             }
146             return IntVector.fromArray(intSpecies, addArr, 0);
147         }
148 
149         private VectorShuffle<Integer>  makeRearrangeShuffle(int order) {
150             int[] shuffleArr = new int[intSpecies.length()];
151             int start = order * 4;
152             for (int i = 0; i < shuffleArr.length; i++) {
153                 shuffleArr[i] = (i % 4) + start;
154             }
155             return VectorShuffle.fromArray(intSpecies, shuffleArr, 0);
156         }
157 
158         private VectorMask<Integer> makeRearrangeMask(int order) {
159             boolean[] maskArr = new boolean[intSpecies.length()];
160             int start = order * 4;
161             if (start < maskArr.length) {
162                 for (int i = 0; i < 4; i++) {
163                     maskArr[i + start] = true;
164                 }
165             }
166 
167             return VectorMask.fromValues(intSpecies, maskArr);
168         }
169 
170         public void makeState(byte[] key, byte[] nonce, long counter,
171             int[] out) {
172 
173             // first field is constants
174             for (int i = 0; i < 4; i++) {
175                 for (int j = 0; j < numBlocks; j++) {
176                     out[4*j + i] = STATE_CONSTANTS[i];
177                 }
178             }
179 
180             // second field is first part of key
181             int fieldStart = 4 * numBlocks;
182             for (int i = 0; i < 4; i++) {
183                 int keyInt = 0;
184                 for (int j = 0; j < 4; j++) {
185                     keyInt += (0xFF & key[4 * i + j]) << 8 * j;
186                 }
187                 for (int j = 0; j < numBlocks; j++) {
188                     out[fieldStart + j*4 + i] = keyInt;
189                 }
190             }
191 
192             // third field is second part of key
193             fieldStart = 8 * numBlocks;
194             for (int i = 0; i < 4; i++) {
195                 int keyInt = 0;
196                 for (int j = 0; j < 4; j++) {
197                     keyInt += (0xFF & key[4 * (i + 4) + j]) << 8 * j;
198                 }
199 
200                 for (int j = 0; j < numBlocks; j++) {
201                     out[fieldStart + j*4 + i] = keyInt;
202                 }
203             }
204 
205             // fourth field is counter and nonce
206             fieldStart = 12 * numBlocks;
207             for (int j = 0; j < numBlocks; j++) {
208                 out[fieldStart + j*4] = (int) (counter + j);
209             }
210 
211             for (int i = 0; i < 3; i++) {
212                 int nonceInt = 0;
213                 for (int j = 0; j < 4; j++) {
214                     nonceInt += (0xFF & nonce[4 * i + j]) << 8 * j;
215                 }
216 
217                 for (int j = 0; j < numBlocks; j++) {
218                     out[fieldStart + j*4 + 1 + i] = nonceInt;
219                 }
220             }
221         }
222 
223         public void chacha20(byte[] key, byte[] nonce, long counter,
224             MemorySegment in, MemorySegment out) {
225 
226             makeState(key, nonce, counter, state);
227 
228             int len = intSpecies.length();
229 
230             IntVector sa = IntVector.fromArray(intSpecies, state, 0);
231             IntVector sb = IntVector.fromArray(intSpecies, state, len);
232             IntVector sc = IntVector.fromArray(intSpecies, state, 2 * len);
233             IntVector sd = IntVector.fromArray(intSpecies, state, 3 * len);
234 
235             int stateLenBytes = state.length * 4;
236             int numStates = (((int) in.byteSize()) + stateLenBytes - 1) / stateLenBytes;
237             for (int j = 0; j < numStates; j++){
238 
239                 IntVector a = sa;
240                 IntVector b = sb;
241                 IntVector c = sc;
242                 IntVector d = sd;
243 
244                 for (int i = 0; i < 10; i++) {
245                     // first round
246                     a = a.add(b);
247                     d = d.lanewise(VectorOperators.XOR, a);
248                     d = d.lanewise(VectorOperators.ROL, 16);
249 
250                     c = c.add(d);
251                     b = b.lanewise(VectorOperators.XOR, c);
252                     b = b.lanewise(VectorOperators.ROL,12);
253 
254                     a = a.add(b);
255                     d = d.lanewise(VectorOperators.XOR, a);
256                     d = d.lanewise(VectorOperators.ROL,8);
257 
258                     c = c.add(d);
259                     b = b.lanewise(VectorOperators.XOR, c);
260                     b = b.lanewise(VectorOperators.ROL,7);
261 
262                     // makeRotate
263                     b = b.rearrange(rot1);
264                     c = c.rearrange(rot2);
265                     d = d.rearrange(rot3);
266 
267                     // second round
268                     a = a.add(b);
269                     d = d.lanewise(VectorOperators.XOR, a);
270                     d = d.lanewise(VectorOperators.ROL,16);
271 
272                     c = c.add(d);
273                     b = b.lanewise(VectorOperators.XOR, c);
274                     b = b.lanewise(VectorOperators.ROL,12);
275 
276                     a = a.add(b);
277                     d = d.lanewise(VectorOperators.XOR, a);
278                     d = d.lanewise(VectorOperators.ROL,8);
279 
280                     c = c.add(d);
281                     b = b.lanewise(VectorOperators.XOR, c);
282                     b = b.lanewise(VectorOperators.ROL,7);
283 
284                     // makeRotate
285                     b = b.rearrange(rot3);
286                     c = c.rearrange(rot2);
287                     d = d.rearrange(rot1);
288                 }
289 
290                 a = a.add(sa);
291                 b = b.add(sb);
292                 c = c.add(sc);
293                 d = d.add(sd);
294 
295                 // rearrange the vectors
296                 if (intSpecies.length() == 4) {
297                     // no rearrange needed
298                 } else if (intSpecies.length() == 8) {
299                     IntVector a_r =
300                             a.rearrange(shuf0).blend(b.rearrange(shuf0), mask1);
301                     IntVector b_r =
302                             c.rearrange(shuf0).blend(d.rearrange(shuf0), mask1);
303                     IntVector c_r =
304                             a.rearrange(shuf1).blend(b.rearrange(shuf1), mask1);
305                     IntVector d_r =
306                             c.rearrange(shuf1).blend(d.rearrange(shuf1), mask1);
307 
308                     a = a_r;
309                     b = b_r;
310                     c = c_r;
311                     d = d_r;
312                 } else if (intSpecies.length() == 16) {
313                     IntVector a_r = a;
314                     a_r = a_r.blend(b.rearrange(shuf0), mask1);
315                     a_r = a_r.blend(c.rearrange(shuf0), mask2);
316                     a_r = a_r.blend(d.rearrange(shuf0), mask3);
317 
318                     IntVector b_r = b;
319                     b_r = b_r.blend(a.rearrange(shuf1), mask0);
320                     b_r = b_r.blend(c.rearrange(shuf1), mask2);
321                     b_r = b_r.blend(d.rearrange(shuf1), mask3);
322 
323                     IntVector c_r = c;
324                     c_r = c_r.blend(a.rearrange(shuf2), mask0);
325                     c_r = c_r.blend(b.rearrange(shuf2), mask1);
326                     c_r = c_r.blend(d.rearrange(shuf2), mask3);
327 
328                     IntVector d_r = d;
329                     d_r = d_r.blend(a.rearrange(shuf3), mask0);
330                     d_r = d_r.blend(b.rearrange(shuf3), mask1);
331                     d_r = d_r.blend(c.rearrange(shuf3), mask2);
332 
333                     a = a_r;
334                     b = b_r;
335                     c = c_r;
336                     d = d_r;
337                 } else {
338                     throw new RuntimeException("not supported");
339                 }
340 
341                 // xor keystream with input
342                 int inOff = stateLenBytes * j;
343                 IntVector ina = IntVector.fromMemorySegment(intSpecies, in, inOff, ByteOrder.LITTLE_ENDIAN);
344                 IntVector inb = IntVector.fromMemorySegment(intSpecies, in, inOff + 4L * len, ByteOrder.LITTLE_ENDIAN);
345                 IntVector inc = IntVector.fromMemorySegment(intSpecies, in, inOff + 8L * len, ByteOrder.LITTLE_ENDIAN);
346                 IntVector ind = IntVector.fromMemorySegment(intSpecies, in, inOff + 12L * len, ByteOrder.LITTLE_ENDIAN);
347 
348                 ina.lanewise(VectorOperators.XOR, a).intoMemorySegment(out, inOff, ByteOrder.LITTLE_ENDIAN);
349                 inb.lanewise(VectorOperators.XOR, b).intoMemorySegment(out, inOff + 4L * len, ByteOrder.LITTLE_ENDIAN);
350                 inc.lanewise(VectorOperators.XOR, c).intoMemorySegment(out, inOff + 8L * len, ByteOrder.LITTLE_ENDIAN);
351                 ind.lanewise(VectorOperators.XOR, d).intoMemorySegment(out, inOff + 12L * len, ByteOrder.LITTLE_ENDIAN);
352 
353                 // increment counter
354                 sd = sd.add(counterAdd);
355             }
356         }
357 
358         public int implBlockSize() {
359             return numBlocks * 64;
360         }
361     }
362 
363     private static byte[] hexStringToByteArray(String str) {
364         byte[] result = new byte[str.length() / 2];
365         for (int i = 0; i < result.length; i++) {
366             result[i] = (byte) Character.digit(str.charAt(2 * i), 16);
367             result[i] <<= 4;
368             result[i] += Character.digit(str.charAt(2 * i + 1), 16);
369         }
370         return result;
371     }
372 
373     private static void runKAT(ChaChaVector cc20, String keyStr,
374         String nonceStr, long counter, String inStr, String outStr) {
375 
376         byte[] key = hexStringToByteArray(keyStr);
377         byte[] nonce = hexStringToByteArray(nonceStr);
378         byte[] in = hexStringToByteArray(inStr);
379         byte[] expOut = hexStringToByteArray(outStr);
380 
381         // implementation only works at multiples of some size
382         int blockSize = cc20.implBlockSize();
383 
384         int length = blockSize * ((in.length + blockSize - 1) / blockSize);
385         in = Arrays.copyOf(in, length);
386         byte[] out = new byte[length];
387 
388         cc20.chacha20(key, nonce, counter, MemorySegment.ofArray(in), MemorySegment.ofArray(out));
389 
390         byte[] actOut = new byte[expOut.length];
391         System.arraycopy(out, 0, actOut, 0, expOut.length);
392 
393         if (!Arrays.equals(out, 0, expOut.length, expOut, 0, expOut.length)) {
394             throw new RuntimeException("Incorrect result");
395         }
396     }
397 
398     /*
399      * ChaCha20 Known Answer Tests to ensure that the implementation is correct.
400      */
401     private static void runKAT(ChaChaVector cc20) {
402         runKAT(cc20,
403         "0000000000000000000000000000000000000000000000000000000000000001",
404         "000000000000000000000002",
405         1,
406         "416e79207375626d697373696f6e20746f20746865204945544620696e74656e" +
407         "6465642062792074686520436f6e7472696275746f7220666f72207075626c69" +
408         "636174696f6e20617320616c6c206f722070617274206f6620616e2049455446" +
409         "20496e7465726e65742d4472616674206f722052464320616e6420616e792073" +
410         "746174656d656e74206d6164652077697468696e2074686520636f6e74657874" +
411         "206f6620616e204945544620616374697669747920697320636f6e7369646572" +
412         "656420616e20224945544620436f6e747269627574696f6e222e205375636820" +
413         "73746174656d656e747320696e636c756465206f72616c2073746174656d656e" +
414         "747320696e20494554462073657373696f6e732c2061732077656c6c20617320" +
415         "7772697474656e20616e6420656c656374726f6e696320636f6d6d756e696361" +
416         "74696f6e73206d61646520617420616e792074696d65206f7220706c6163652c" +
417         "207768696368206172652061646472657373656420746f",
418         "a3fbf07df3fa2fde4f376ca23e82737041605d9f4f4f57bd8cff2c1d4b7955ec" +
419         "2a97948bd3722915c8f3d337f7d370050e9e96d647b7c39f56e031ca5eb6250d" +
420         "4042e02785ececfa4b4bb5e8ead0440e20b6e8db09d881a7c6132f420e527950" +
421         "42bdfa7773d8a9051447b3291ce1411c680465552aa6c405b7764d5e87bea85a" +
422         "d00f8449ed8f72d0d662ab052691ca66424bc86d2df80ea41f43abf937d3259d" +
423         "c4b2d0dfb48a6c9139ddd7f76966e928e635553ba76c5c879d7b35d49eb2e62b" +
424         "0871cdac638939e25e8a1e0ef9d5280fa8ca328b351c3c765989cbcf3daa8b6c" +
425         "cc3aaf9f3979c92b3720fc88dc95ed84a1be059c6499b9fda236e7e818b04b0b" +
426         "c39c1e876b193bfe5569753f88128cc08aaa9b63d1a16f80ef2554d7189c411f" +
427         "5869ca52c5b83fa36ff216b9c1d30062bebcfd2dc5bce0911934fda79a86f6e6" +
428         "98ced759c3ff9b6477338f3da4f9cd8514ea9982ccafb341b2384dd902f3d1ab" +
429         "7ac61dd29c6f21ba5b862f3730e37cfdc4fd806c22f221"
430         );
431     }
432 }