1 /*
2 * Copyright (c) 2019, 2024, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation.
8 *
9 * This code is distributed in the hope that it will be useful, but WITHOUT
10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
12 * version 2 for more details (a copy is included in the LICENSE file that
13 * accompanied this code).
14 *
15 * You should have received a copy of the GNU General Public License version
16 * 2 along with this work; if not, write to the Free Software Foundation,
17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18 *
19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20 * or visit www.oracle.com if you need additional information or have any
21 * questions.
22 */
23
24 package org.openjdk.bench.jdk.incubator.vector.crypto;
25
import java.lang.foreign.MemorySegment;
import java.nio.ByteOrder;
import java.util.Arrays;
import java.util.HexFormat;

import jdk.incubator.vector.*;
import org.openjdk.jmh.annotations.*;
32
33 @State(Scope.Thread)
34 @BenchmarkMode(Mode.Throughput)
35 @Fork(value = 1, jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
36 @Warmup(iterations = 3, time = 3)
37 @Measurement(iterations = 8, time = 2)
38 public class ChaChaBench {
39
40 @Param({"16384", "65536"})
41 private int dataSize;
42
43 private ChaChaVector cc20_S128 = makeCC20(VectorShape.S_128_BIT);
44 private ChaChaVector cc20_S256 = makeCC20(VectorShape.S_256_BIT);
45 private ChaChaVector cc20_S512 = makeCC20(VectorShape.S_512_BIT);
46
47 private MemorySegment in;
48 private MemorySegment out;
49
50 private byte[] key = new byte[32];
51 private byte[] nonce = new byte[12];
52 private long counter = 0;
53
54 private static ChaChaVector makeCC20(VectorShape shape) {
55 ChaChaVector cc20 = new ChaChaVector(shape);
56 runKAT(cc20);
57 return cc20;
58 }
59
60 @Setup
61 public void setup() {
62 in = MemorySegment.ofArray(new byte[dataSize]);
63 out = MemorySegment.ofArray(new byte[dataSize]);
64 }
65
66 @Benchmark
67 public void encrypt128() {
68 cc20_S128.chacha20(key, nonce, counter, in, out);
69 }
70
71 @Benchmark
72 public void encrypt256() {
73 cc20_S256.chacha20(key, nonce, counter, in, out);
74 }
75
76 @Benchmark
77 public void encrypt512() {
78 cc20_S512.chacha20(key, nonce, counter, in, out);
79 }
80
81 private static class ChaChaVector {
82
83 private static final int[] STATE_CONSTANTS =
84 new int[]{0x61707865, 0x3320646e, 0x79622d32, 0x6b206574};
85
86 private final VectorSpecies<Integer> intSpecies;
87 private final int numBlocks;
88
89 private final VectorShuffle<Integer> rot1;
90 private final VectorShuffle<Integer> rot2;
91 private final VectorShuffle<Integer> rot3;
92
93 private final IntVector counterAdd;
94
95 private final VectorShuffle<Integer> shuf0;
96 private final VectorShuffle<Integer> shuf1;
97 private final VectorShuffle<Integer> shuf2;
98 private final VectorShuffle<Integer> shuf3;
99
100 private final VectorMask<Integer> mask0;
101 private final VectorMask<Integer> mask1;
102 private final VectorMask<Integer> mask2;
103 private final VectorMask<Integer> mask3;
104
105 private final int[] state;
106
107 public ChaChaVector(VectorShape shape) {
108 this.intSpecies = VectorSpecies.of(int.class, shape);
109 this.numBlocks = intSpecies.length() / 4;
110
111 this.rot1 = makeRotate(1);
112 this.rot2 = makeRotate(2);
113 this.rot3 = makeRotate(3);
114
115 this.counterAdd = makeCounterAdd();
116
117 this.shuf0 = makeRearrangeShuffle(0);
118 this.shuf1 = makeRearrangeShuffle(1);
119 this.shuf2 = makeRearrangeShuffle(2);
120 this.shuf3 = makeRearrangeShuffle(3);
121
122 this.mask0 = makeRearrangeMask(0);
123 this.mask1 = makeRearrangeMask(1);
124 this.mask2 = makeRearrangeMask(2);
125 this.mask3 = makeRearrangeMask(3);
126
127 this.state = new int[numBlocks * 16];
128 }
129
130 private VectorShuffle<Integer> makeRotate(int amount) {
131 int[] shuffleArr = new int[intSpecies.length()];
132
133 for (int i = 0; i < intSpecies.length(); i ++) {
134 int offset = (i / 4) * 4;
135 shuffleArr[i] = offset + ((i + amount) % 4);
136 }
137
138 return VectorShuffle.fromValues(intSpecies, shuffleArr);
139 }
140
141 private IntVector makeCounterAdd() {
142 int[] addArr = new int[intSpecies.length()];
143 for(int i = 0; i < numBlocks; i++) {
144 addArr[4 * i] = numBlocks;
145 }
146 return IntVector.fromArray(intSpecies, addArr, 0);
147 }
148
149 private VectorShuffle<Integer> makeRearrangeShuffle(int order) {
150 int[] shuffleArr = new int[intSpecies.length()];
151 int start = order * 4;
152 for (int i = 0; i < shuffleArr.length; i++) {
153 shuffleArr[i] = (i % 4) + start;
154 }
155 return VectorShuffle.fromArray(intSpecies, shuffleArr, 0);
156 }
157
158 private VectorMask<Integer> makeRearrangeMask(int order) {
159 boolean[] maskArr = new boolean[intSpecies.length()];
160 int start = order * 4;
161 if (start < maskArr.length) {
162 for (int i = 0; i < 4; i++) {
163 maskArr[i + start] = true;
164 }
165 }
166
167 return VectorMask.fromValues(intSpecies, maskArr);
168 }
169
170 public void makeState(byte[] key, byte[] nonce, long counter,
171 int[] out) {
172
173 // first field is constants
174 for (int i = 0; i < 4; i++) {
175 for (int j = 0; j < numBlocks; j++) {
176 out[4*j + i] = STATE_CONSTANTS[i];
177 }
178 }
179
180 // second field is first part of key
181 int fieldStart = 4 * numBlocks;
182 for (int i = 0; i < 4; i++) {
183 int keyInt = 0;
184 for (int j = 0; j < 4; j++) {
185 keyInt += (0xFF & key[4 * i + j]) << 8 * j;
186 }
187 for (int j = 0; j < numBlocks; j++) {
188 out[fieldStart + j*4 + i] = keyInt;
189 }
190 }
191
192 // third field is second part of key
193 fieldStart = 8 * numBlocks;
194 for (int i = 0; i < 4; i++) {
195 int keyInt = 0;
196 for (int j = 0; j < 4; j++) {
197 keyInt += (0xFF & key[4 * (i + 4) + j]) << 8 * j;
198 }
199
200 for (int j = 0; j < numBlocks; j++) {
201 out[fieldStart + j*4 + i] = keyInt;
202 }
203 }
204
205 // fourth field is counter and nonce
206 fieldStart = 12 * numBlocks;
207 for (int j = 0; j < numBlocks; j++) {
208 out[fieldStart + j*4] = (int) (counter + j);
209 }
210
211 for (int i = 0; i < 3; i++) {
212 int nonceInt = 0;
213 for (int j = 0; j < 4; j++) {
214 nonceInt += (0xFF & nonce[4 * i + j]) << 8 * j;
215 }
216
217 for (int j = 0; j < numBlocks; j++) {
218 out[fieldStart + j*4 + 1 + i] = nonceInt;
219 }
220 }
221 }
222
223 public void chacha20(byte[] key, byte[] nonce, long counter,
224 MemorySegment in, MemorySegment out) {
225
226 makeState(key, nonce, counter, state);
227
228 int len = intSpecies.length();
229
230 IntVector sa = IntVector.fromArray(intSpecies, state, 0);
231 IntVector sb = IntVector.fromArray(intSpecies, state, len);
232 IntVector sc = IntVector.fromArray(intSpecies, state, 2 * len);
233 IntVector sd = IntVector.fromArray(intSpecies, state, 3 * len);
234
235 int stateLenBytes = state.length * 4;
236 int numStates = (((int) in.byteSize()) + stateLenBytes - 1) / stateLenBytes;
237 for (int j = 0; j < numStates; j++){
238
239 IntVector a = sa;
240 IntVector b = sb;
241 IntVector c = sc;
242 IntVector d = sd;
243
244 for (int i = 0; i < 10; i++) {
245 // first round
246 a = a.add(b);
247 d = d.lanewise(VectorOperators.XOR, a);
248 d = d.lanewise(VectorOperators.ROL, 16);
249
250 c = c.add(d);
251 b = b.lanewise(VectorOperators.XOR, c);
252 b = b.lanewise(VectorOperators.ROL,12);
253
254 a = a.add(b);
255 d = d.lanewise(VectorOperators.XOR, a);
256 d = d.lanewise(VectorOperators.ROL,8);
257
258 c = c.add(d);
259 b = b.lanewise(VectorOperators.XOR, c);
260 b = b.lanewise(VectorOperators.ROL,7);
261
262 // makeRotate
263 b = b.rearrange(rot1);
264 c = c.rearrange(rot2);
265 d = d.rearrange(rot3);
266
267 // second round
268 a = a.add(b);
269 d = d.lanewise(VectorOperators.XOR, a);
270 d = d.lanewise(VectorOperators.ROL,16);
271
272 c = c.add(d);
273 b = b.lanewise(VectorOperators.XOR, c);
274 b = b.lanewise(VectorOperators.ROL,12);
275
276 a = a.add(b);
277 d = d.lanewise(VectorOperators.XOR, a);
278 d = d.lanewise(VectorOperators.ROL,8);
279
280 c = c.add(d);
281 b = b.lanewise(VectorOperators.XOR, c);
282 b = b.lanewise(VectorOperators.ROL,7);
283
284 // makeRotate
285 b = b.rearrange(rot3);
286 c = c.rearrange(rot2);
287 d = d.rearrange(rot1);
288 }
289
290 a = a.add(sa);
291 b = b.add(sb);
292 c = c.add(sc);
293 d = d.add(sd);
294
295 // rearrange the vectors
296 if (intSpecies.length() == 4) {
297 // no rearrange needed
298 } else if (intSpecies.length() == 8) {
299 IntVector a_r =
300 a.rearrange(shuf0).blend(b.rearrange(shuf0), mask1);
301 IntVector b_r =
302 c.rearrange(shuf0).blend(d.rearrange(shuf0), mask1);
303 IntVector c_r =
304 a.rearrange(shuf1).blend(b.rearrange(shuf1), mask1);
305 IntVector d_r =
306 c.rearrange(shuf1).blend(d.rearrange(shuf1), mask1);
307
308 a = a_r;
309 b = b_r;
310 c = c_r;
311 d = d_r;
312 } else if (intSpecies.length() == 16) {
313 IntVector a_r = a;
314 a_r = a_r.blend(b.rearrange(shuf0), mask1);
315 a_r = a_r.blend(c.rearrange(shuf0), mask2);
316 a_r = a_r.blend(d.rearrange(shuf0), mask3);
317
318 IntVector b_r = b;
319 b_r = b_r.blend(a.rearrange(shuf1), mask0);
320 b_r = b_r.blend(c.rearrange(shuf1), mask2);
321 b_r = b_r.blend(d.rearrange(shuf1), mask3);
322
323 IntVector c_r = c;
324 c_r = c_r.blend(a.rearrange(shuf2), mask0);
325 c_r = c_r.blend(b.rearrange(shuf2), mask1);
326 c_r = c_r.blend(d.rearrange(shuf2), mask3);
327
328 IntVector d_r = d;
329 d_r = d_r.blend(a.rearrange(shuf3), mask0);
330 d_r = d_r.blend(b.rearrange(shuf3), mask1);
331 d_r = d_r.blend(c.rearrange(shuf3), mask2);
332
333 a = a_r;
334 b = b_r;
335 c = c_r;
336 d = d_r;
337 } else {
338 throw new RuntimeException("not supported");
339 }
340
341 // xor keystream with input
342 int inOff = stateLenBytes * j;
343 IntVector ina = IntVector.fromMemorySegment(intSpecies, in, inOff, ByteOrder.LITTLE_ENDIAN);
344 IntVector inb = IntVector.fromMemorySegment(intSpecies, in, inOff + 4L * len, ByteOrder.LITTLE_ENDIAN);
345 IntVector inc = IntVector.fromMemorySegment(intSpecies, in, inOff + 8L * len, ByteOrder.LITTLE_ENDIAN);
346 IntVector ind = IntVector.fromMemorySegment(intSpecies, in, inOff + 12L * len, ByteOrder.LITTLE_ENDIAN);
347
348 ina.lanewise(VectorOperators.XOR, a).intoMemorySegment(out, inOff, ByteOrder.LITTLE_ENDIAN);
349 inb.lanewise(VectorOperators.XOR, b).intoMemorySegment(out, inOff + 4L * len, ByteOrder.LITTLE_ENDIAN);
350 inc.lanewise(VectorOperators.XOR, c).intoMemorySegment(out, inOff + 8L * len, ByteOrder.LITTLE_ENDIAN);
351 ind.lanewise(VectorOperators.XOR, d).intoMemorySegment(out, inOff + 12L * len, ByteOrder.LITTLE_ENDIAN);
352
353 // increment counter
354 sd = sd.add(counterAdd);
355 }
356 }
357
358 public int implBlockSize() {
359 return numBlocks * 64;
360 }
361 }
362
363 private static byte[] hexStringToByteArray(String str) {
364 byte[] result = new byte[str.length() / 2];
365 for (int i = 0; i < result.length; i++) {
366 result[i] = (byte) Character.digit(str.charAt(2 * i), 16);
367 result[i] <<= 4;
368 result[i] += (byte) Character.digit(str.charAt(2 * i + 1), 16);
369 }
370 return result;
371 }
372
373 private static void runKAT(ChaChaVector cc20, String keyStr,
374 String nonceStr, long counter, String inStr, String outStr) {
375
376 byte[] key = hexStringToByteArray(keyStr);
377 byte[] nonce = hexStringToByteArray(nonceStr);
378 byte[] in = hexStringToByteArray(inStr);
379 byte[] expOut = hexStringToByteArray(outStr);
380
381 // implementation only works at multiples of some size
382 int blockSize = cc20.implBlockSize();
383
384 int length = blockSize * ((in.length + blockSize - 1) / blockSize);
385 in = Arrays.copyOf(in, length);
386 byte[] out = new byte[length];
387
388 cc20.chacha20(key, nonce, counter, MemorySegment.ofArray(in), MemorySegment.ofArray(out));
389
390 byte[] actOut = new byte[expOut.length];
391 System.arraycopy(out, 0, actOut, 0, expOut.length);
392
393 if (!Arrays.equals(out, 0, expOut.length, expOut, 0, expOut.length)) {
394 throw new RuntimeException("Incorrect result");
395 }
396 }
397
398 /*
399 * ChaCha20 Known Answer Tests to ensure that the implementation is correct.
400 */
401 private static void runKAT(ChaChaVector cc20) {
402 runKAT(cc20,
403 "0000000000000000000000000000000000000000000000000000000000000001",
404 "000000000000000000000002",
405 1,
406 "416e79207375626d697373696f6e20746f20746865204945544620696e74656e" +
407 "6465642062792074686520436f6e7472696275746f7220666f72207075626c69" +
408 "636174696f6e20617320616c6c206f722070617274206f6620616e2049455446" +
409 "20496e7465726e65742d4472616674206f722052464320616e6420616e792073" +
410 "746174656d656e74206d6164652077697468696e2074686520636f6e74657874" +
411 "206f6620616e204945544620616374697669747920697320636f6e7369646572" +
412 "656420616e20224945544620436f6e747269627574696f6e222e205375636820" +
413 "73746174656d656e747320696e636c756465206f72616c2073746174656d656e" +
414 "747320696e20494554462073657373696f6e732c2061732077656c6c20617320" +
415 "7772697474656e20616e6420656c656374726f6e696320636f6d6d756e696361" +
416 "74696f6e73206d61646520617420616e792074696d65206f7220706c6163652c" +
417 "207768696368206172652061646472657373656420746f",
418 "a3fbf07df3fa2fde4f376ca23e82737041605d9f4f4f57bd8cff2c1d4b7955ec" +
419 "2a97948bd3722915c8f3d337f7d370050e9e96d647b7c39f56e031ca5eb6250d" +
420 "4042e02785ececfa4b4bb5e8ead0440e20b6e8db09d881a7c6132f420e527950" +
421 "42bdfa7773d8a9051447b3291ce1411c680465552aa6c405b7764d5e87bea85a" +
422 "d00f8449ed8f72d0d662ab052691ca66424bc86d2df80ea41f43abf937d3259d" +
423 "c4b2d0dfb48a6c9139ddd7f76966e928e635553ba76c5c879d7b35d49eb2e62b" +
424 "0871cdac638939e25e8a1e0ef9d5280fa8ca328b351c3c765989cbcf3daa8b6c" +
425 "cc3aaf9f3979c92b3720fc88dc95ed84a1be059c6499b9fda236e7e818b04b0b" +
426 "c39c1e876b193bfe5569753f88128cc08aaa9b63d1a16f80ef2554d7189c411f" +
427 "5869ca52c5b83fa36ff216b9c1d30062bebcfd2dc5bce0911934fda79a86f6e6" +
428 "98ced759c3ff9b6477338f3da4f9cd8514ea9982ccafb341b2384dd902f3d1ab" +
429 "7ac61dd29c6f21ba5b862f3730e37cfdc4fd806c22f221"
430 );
431 }
432 }