1 /* 2 * Copyright (c) 2019, 2022, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. 8 * 9 * This code is distributed in the hope that it will be useful, but WITHOUT 10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 12 * version 2 for more details (a copy is included in the LICENSE file that 13 * accompanied this code). 14 * 15 * You should have received a copy of the GNU General Public License version 16 * 2 along with this work; if not, write to the Free Software Foundation, 17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 18 * 19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 20 * or visit www.oracle.com if you need additional information or have any 21 * questions. 
*/

package org.openjdk.bench.jdk.incubator.vector.crypto;

import org.openjdk.jmh.annotations.*;
import java.lang.foreign.MemorySegment;
import jdk.incubator.vector.*;

import java.nio.ByteOrder;
import java.util.Arrays;
import java.util.HexFormat;

/**
 * JMH throughput benchmark for a ChaCha20 implementation written against the
 * Vector API, measured with 128-, 256- and 512-bit vector shapes.
 *
 * <p>Every {@link ChaChaVector} instance is checked against a known-answer test
 * when it is created, so an incorrect implementation fails the benchmark instead
 * of silently producing a number.
 */
@State(Scope.Thread)
@BenchmarkMode(Mode.Throughput)
@Fork(value = 1, jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
@Warmup(iterations = 3, time = 3)
@Measurement(iterations = 8, time = 2)
public class ChaChaBench {

    /** Buffer size in bytes; must be a multiple of every engine's {@code implBlockSize()}. */
    @Param({"16384", "65536"})
    private int dataSize;

    // One engine per vector shape, each validated by runKAT() at construction.
    private ChaChaVector cc20_S128 = makeCC20(VectorShape.S_128_BIT);
    private ChaChaVector cc20_S256 = makeCC20(VectorShape.S_256_BIT);
    private ChaChaVector cc20_S512 = makeCC20(VectorShape.S_512_BIT);

    private MemorySegment in;
    private MemorySegment out;

    // Fixed all-zero key/nonce and counter 0: only throughput is measured.
    private byte[] key = new byte[32];
    private byte[] nonce = new byte[12];
    private long counter = 0;

    /**
     * Creates a ChaCha20 engine for {@code shape} and verifies it with a known-answer test.
     *
     * @throws RuntimeException if the engine produces an incorrect keystream
     */
    private static ChaChaVector makeCC20(VectorShape shape) {
        ChaChaVector cc20 = new ChaChaVector(shape);
        runKAT(cc20);
        return cc20;
    }

    /** Allocates zero-filled, heap-backed input and output buffers of {@code dataSize} bytes. */
    @Setup
    public void setup() {
        in = MemorySegment.ofArray(new byte[dataSize]);
        out = MemorySegment.ofArray(new byte[dataSize]);
    }

    /** Encrypts {@code in} into {@code out} using 128-bit vectors (one block per iteration). */
    @Benchmark
    public void encrypt128() {
        cc20_S128.chacha20(key, nonce, counter, in, out);
    }

    /** Encrypts {@code in} into {@code out} using 256-bit vectors (two blocks per iteration). */
    @Benchmark
    public void encrypt256() {
        cc20_S256.chacha20(key, nonce, counter, in, out);
    }

    /** Encrypts {@code in} into {@code out} using 512-bit vectors (four blocks per iteration). */
    @Benchmark
    public void encrypt512() {
        cc20_S512.chacha20(key, nonce, counter, in, out);
    }

    /**
     * ChaCha20 engine that processes {@code numBlocks = lanes / 4} 64-byte blocks per
     * loop iteration.
     *
     * <p>Layout: the 16-word state is held as four row vectors. Vector {@code r} holds
     * row {@code r} (words {@code 4r..4r+3}) of every block, with block {@code j} in lanes
     * {@code [4j, 4j + 4)}. Before the keystream is XORed with the input, the rows are
     * transposed back into memory order.
     *
     * <p>Instances reuse an internal scratch {@code state} array, so a single instance
     * must not be used by several threads concurrently.
     */
    private static class ChaChaVector {

        /** The four ChaCha constant words: "expand 32-byte k" read as little-endian ints. */
        private static final int[] STATE_CONSTANTS =
                new int[]{0x61707865, 0x3320646e, 0x79622d32, 0x6b206574};

        private final VectorSpecies<Integer> intSpecies;
        private final int numBlocks;

        // Per-4-lane-group left rotations used to (un)diagonalize rows b, c, d.
        private final VectorShuffle<Integer> rot1;
        private final VectorShuffle<Integer> rot2;
        private final VectorShuffle<Integer> rot3;

        // Added to row d after every iteration to advance each block's counter.
        private final IntVector counterAdd;

        // Broadcast block k's row into every 4-lane group (used for the output transpose).
        private final VectorShuffle<Integer> shuf0;
        private final VectorShuffle<Integer> shuf1;
        private final VectorShuffle<Integer> shuf2;
        private final VectorShuffle<Integer> shuf3;

        // Select lanes [4k, 4k + 4) (used for the output transpose).
        private final VectorMask<Integer> mask0;
        private final VectorMask<Integer> mask1;
        private final VectorMask<Integer> mask2;
        private final VectorMask<Integer> mask3;

        // Scratch buffer for the initial state of numBlocks blocks (16 words each).
        private final int[] state;

        /**
         * Builds an engine for the given vector shape.
         *
         * @throws IllegalArgumentException if {@code shape} does not have 4, 8 or 16 int
         *         lanes, which are the only widths the output transpose in
         *         {@link #chacha20} handles
         */
        public ChaChaVector(VectorShape shape) {
            this.intSpecies = VectorSpecies.of(int.class, shape);
            int lanes = intSpecies.length();
            if (lanes != 4 && lanes != 8 && lanes != 16) {
                throw new IllegalArgumentException("unsupported vector shape: " + shape
                        + " (" + lanes + " int lanes)");
            }
            this.numBlocks = lanes / 4;

            this.rot1 = makeRotate(1);
            this.rot2 = makeRotate(2);
            this.rot3 = makeRotate(3);

            this.counterAdd = makeCounterAdd();

            this.shuf0 = makeRearrangeShuffle(0);
            this.shuf1 = makeRearrangeShuffle(1);
            this.shuf2 = makeRearrangeShuffle(2);
            this.shuf3 = makeRearrangeShuffle(3);

            this.mask0 = makeRearrangeMask(0);
            this.mask1 = makeRearrangeMask(1);
            this.mask2 = makeRearrangeMask(2);
            this.mask3 = makeRearrangeMask(3);

            this.state = new int[numBlocks * 16];
        }

        /**
         * Builds a shuffle that rotates each 4-lane group left by {@code amount}.
         * Lane {@code i} of a group takes lane {@code (i + amount) mod 4} of the same group.
         */
        private VectorShuffle<Integer> makeRotate(int amount) {
            int[] shuffleArr = new int[intSpecies.length()];

            for (int i = 0; i < intSpecies.length(); i++) {
                int offset = (i / 4) * 4;
                shuffleArr[i] = offset + ((i + amount) % 4);
            }

            return VectorShuffle.fromValues(intSpecies, shuffleArr);
        }

        /**
         * Builds the per-iteration counter increment. The vector holds {@code numBlocks}
         * in lane 0 of every 4-lane group (the counter word of row d) and 0 elsewhere.
         */
        private IntVector makeCounterAdd() {
            int[] addArr = new int[intSpecies.length()];
            for (int i = 0; i < numBlocks; i++) {
                addArr[4 * i] = numBlocks;
            }
            return IntVector.fromArray(intSpecies, addArr, 0);
        }

        /**
         * Builds a shuffle that copies lanes {@code [4*order, 4*order + 4)} into every
         * 4-lane group.
         *
         * <p>For {@code order >= numBlocks} the indices exceed the vector length. Such
         * shuffles are created but never applied by {@link #chacha20}.
         */
        private VectorShuffle<Integer> makeRearrangeShuffle(int order) {
            int[] shuffleArr = new int[intSpecies.length()];
            int start = order * 4;
            for (int i = 0; i < shuffleArr.length; i++) {
                shuffleArr[i] = (i % 4) + start;
            }
            return VectorShuffle.fromArray(intSpecies, shuffleArr, 0);
        }

        /**
         * Builds a mask selecting lanes {@code [4*order, 4*order + 4)}. The mask is
         * all-false when that range lies beyond the vector.
         */
        private VectorMask<Integer> makeRearrangeMask(int order) {
            boolean[] maskArr = new boolean[intSpecies.length()];
            int start = order * 4;
            if (start < maskArr.length) {
                for (int i = 0; i < 4; i++) {
                    maskArr[i + start] = true;
                }
            }

            return VectorMask.fromValues(intSpecies, maskArr);
        }

        /** Reads four bytes starting at {@code off} as a little-endian int. */
        private static int littleEndianInt(byte[] bytes, int off) {
            return (bytes[off] & 0xFF)
                    | (bytes[off + 1] & 0xFF) << 8
                    | (bytes[off + 2] & 0xFF) << 16
                    | (bytes[off + 3] & 0xFF) << 24;
        }

        /**
         * Fills {@code out} with the initial state of {@code numBlocks} consecutive blocks,
         * in the interleaved row layout described on the class.
         *
         * @param key     key bytes; bytes 0..31 are read
         * @param nonce   nonce bytes; bytes 0..11 are read
         * @param counter block counter of the first block. Block {@code j} gets
         *                {@code (int) (counter + j)}, so only the low 32 bits are used.
         * @param out     destination, at least {@code 16 * numBlocks} ints long
         */
        public void makeState(byte[] key, byte[] nonce, long counter, int[] out) {
            final int rowStride = 4 * numBlocks;
            final int row1 = rowStride;
            final int row2 = 2 * rowStride;
            final int row3 = 3 * rowStride;

            // Row 0: constants. Row 1: key words 0..3. Row 2: key words 4..7.
            for (int i = 0; i < 4; i++) {
                int keyLo = littleEndianInt(key, 4 * i);
                int keyHi = littleEndianInt(key, 4 * (i + 4));
                for (int j = 0; j < numBlocks; j++) {
                    out[4 * j + i] = STATE_CONSTANTS[i];
                    out[row1 + 4 * j + i] = keyLo;
                    out[row2 + 4 * j + i] = keyHi;
                }
            }

            // Row 3: per-block counter word, followed by the three nonce words.
            for (int j = 0; j < numBlocks; j++) {
                out[row3 + 4 * j] = (int) (counter + j);
            }
            for (int i = 0; i < 3; i++) {
                int nonceWord = littleEndianInt(nonce, 4 * i);
                for (int j = 0; j < numBlocks; j++) {
                    out[row3 + 4 * j + 1 + i] = nonceWord;
                }
            }
        }

        /**
         * XORs {@code in} with the ChaCha20 keystream for ({@code key}, {@code nonce},
         * {@code counter}) and writes the result to {@code out}.
         *
         * <p>Processes {@link #implBlockSize()} bytes per iteration. The per-block counter
         * is a 32-bit lane value and wraps silently.
         *
         * @throws IllegalArgumentException if {@code in.byteSize()} is not a multiple of
         *         {@link #implBlockSize()}, or {@code out} is smaller than {@code in}
         */
        public void chacha20(byte[] key, byte[] nonce, long counter,
                             MemorySegment in, MemorySegment out) {

            final int stateLenBytes = implBlockSize();
            final long inLen = in.byteSize();
            // Validate up front so a bad length fails before any output is written,
            // instead of with an out-of-bounds access partway through the loop.
            if (inLen % stateLenBytes != 0) {
                throw new IllegalArgumentException("input length " + inLen
                        + " is not a multiple of " + stateLenBytes + " bytes");
            }
            if (out.byteSize() < inLen) {
                throw new IllegalArgumentException("output length " + out.byteSize()
                        + " is smaller than input length " + inLen);
            }

            makeState(key, nonce, counter, state);

            int len = intSpecies.length();

            IntVector sa = IntVector.fromArray(intSpecies, state, 0);
            IntVector sb = IntVector.fromArray(intSpecies, state, len);
            IntVector sc = IntVector.fromArray(intSpecies, state, 2 * len);
            IntVector sd = IntVector.fromArray(intSpecies, state, 3 * len);

            // Byte offsets are long: segment sizes may exceed Integer.MAX_VALUE.
            for (long inOff = 0; inOff < inLen; inOff += stateLenBytes) {

                IntVector a = sa;
                IntVector b = sb;
                IntVector c = sc;
                IntVector d = sd;

                // 10 iterations of (column round + diagonal round) = 20 rounds.
                for (int i = 0; i < 10; i++) {
                    // column round
                    a = a.add(b);
                    d = d.lanewise(VectorOperators.XOR, a);
                    d = d.lanewise(VectorOperators.ROL, 16);

                    c = c.add(d);
                    b = b.lanewise(VectorOperators.XOR, c);
                    b = b.lanewise(VectorOperators.ROL, 12);

                    a = a.add(b);
                    d = d.lanewise(VectorOperators.XOR, a);
                    d = d.lanewise(VectorOperators.ROL, 8);

                    c = c.add(d);
                    b = b.lanewise(VectorOperators.XOR, c);
                    b = b.lanewise(VectorOperators.ROL, 7);

                    // diagonalize: rotate rows b, c, d left by 1, 2, 3 lanes
                    b = b.rearrange(rot1);
                    c = c.rearrange(rot2);
                    d = d.rearrange(rot3);

                    // diagonal round
                    a = a.add(b);
                    d = d.lanewise(VectorOperators.XOR, a);
                    d = d.lanewise(VectorOperators.ROL, 16);

                    c = c.add(d);
                    b = b.lanewise(VectorOperators.XOR, c);
                    b = b.lanewise(VectorOperators.ROL, 12);

                    a = a.add(b);
                    d = d.lanewise(VectorOperators.XOR, a);
                    d = d.lanewise(VectorOperators.ROL, 8);

                    c = c.add(d);
                    b = b.lanewise(VectorOperators.XOR, c);
                    b = b.lanewise(VectorOperators.ROL, 7);

                    // undo the diagonalization
                    b = b.rearrange(rot3);
                    c = c.rearrange(rot2);
                    d = d.rearrange(rot1);
                }

                a = a.add(sa);
                b = b.add(sb);
                c = c.add(sc);
                d = d.add(sd);

                // Transpose from "one row of every block per vector" to
                // "consecutive 64-byte blocks in memory order".
                switch (len) {
                    case 4 -> {
                        // A single block per vector: rows are already in memory order.
                    }
                    case 8 -> {
                        IntVector a_r =
                            a.rearrange(shuf0).blend(b.rearrange(shuf0), mask1);
                        IntVector b_r =
                            c.rearrange(shuf0).blend(d.rearrange(shuf0), mask1);
                        IntVector c_r =
                            a.rearrange(shuf1).blend(b.rearrange(shuf1), mask1);
                        IntVector d_r =
                            c.rearrange(shuf1).blend(d.rearrange(shuf1), mask1);

                        a = a_r;
                        b = b_r;
                        c = c_r;
                        d = d_r;
                    }
                    case 16 -> {
                        IntVector a_r = a;
                        a_r = a_r.blend(b.rearrange(shuf0), mask1);
                        a_r = a_r.blend(c.rearrange(shuf0), mask2);
                        a_r = a_r.blend(d.rearrange(shuf0), mask3);

                        IntVector b_r = b;
                        b_r = b_r.blend(a.rearrange(shuf1), mask0);
                        b_r = b_r.blend(c.rearrange(shuf1), mask2);
                        b_r = b_r.blend(d.rearrange(shuf1), mask3);

                        IntVector c_r = c;
                        c_r = c_r.blend(a.rearrange(shuf2), mask0);
                        c_r = c_r.blend(b.rearrange(shuf2), mask1);
                        c_r = c_r.blend(d.rearrange(shuf2), mask3);

                        IntVector d_r = d;
                        d_r = d_r.blend(a.rearrange(shuf3), mask0);
                        d_r = d_r.blend(b.rearrange(shuf3), mask1);
                        d_r = d_r.blend(c.rearrange(shuf3), mask2);

                        a = a_r;
                        b = b_r;
                        c = c_r;
                        d = d_r;
                    }
                    default -> throw new RuntimeException("not supported");
                }

                // XOR the keystream with the input.
                IntVector ina = IntVector.fromMemorySegment(intSpecies, in, inOff, ByteOrder.LITTLE_ENDIAN);
                IntVector inb = IntVector.fromMemorySegment(intSpecies, in, inOff + 4L * len, ByteOrder.LITTLE_ENDIAN);
                IntVector inc = IntVector.fromMemorySegment(intSpecies, in, inOff + 8L * len, ByteOrder.LITTLE_ENDIAN);
                IntVector ind = IntVector.fromMemorySegment(intSpecies, in, inOff + 12L * len, ByteOrder.LITTLE_ENDIAN);

                ina.lanewise(VectorOperators.XOR, a).intoMemorySegment(out, inOff, ByteOrder.LITTLE_ENDIAN);
                inb.lanewise(VectorOperators.XOR, b).intoMemorySegment(out, inOff + 4L * len, ByteOrder.LITTLE_ENDIAN);
                inc.lanewise(VectorOperators.XOR, c).intoMemorySegment(out, inOff + 8L * len, ByteOrder.LITTLE_ENDIAN);
                ind.lanewise(VectorOperators.XOR, d).intoMemorySegment(out, inOff + 12L * len, ByteOrder.LITTLE_ENDIAN);

                // Advance every block's counter by numBlocks.
                sd = sd.add(counterAdd);
            }
        }

        /** Number of bytes processed per loop iteration: {@code numBlocks} 64-byte blocks. */
        public int implBlockSize() {
            return numBlocks * 64;
        }
    }

    /**
     * Decodes a hex string into bytes.
     *
     * @throws IllegalArgumentException if {@code str} has odd length or contains a
     *         non-hex character
     */
    private static byte[] hexStringToByteArray(String str) {
        return HexFormat.of().parseHex(str);
    }

    /**
     * Runs one known-answer test against {@code cc20}.
     *
     * <p>The input is zero-padded up to a multiple of {@code implBlockSize()}, because
     * the implementation only accepts whole blocks. Only the first
     * {@code outStr.length() / 2} output bytes are compared.
     *
     * @throws RuntimeException if the output differs from {@code outStr}
     */
    private static void runKAT(ChaChaVector cc20, String keyStr,
            String nonceStr, long counter, String inStr, String outStr) {

        byte[] key = hexStringToByteArray(keyStr);
        byte[] nonce = hexStringToByteArray(nonceStr);
        byte[] in = hexStringToByteArray(inStr);
        byte[] expOut = hexStringToByteArray(outStr);

        // implementation only works at multiples of some size
        int blockSize = cc20.implBlockSize();

        int length = blockSize * ((in.length + blockSize - 1) / blockSize);
        in = Arrays.copyOf(in, length);
        byte[] out = new byte[length];

        cc20.chacha20(key, nonce, counter, MemorySegment.ofArray(in), MemorySegment.ofArray(out));

        int mismatch = Arrays.mismatch(out, 0, expOut.length, expOut, 0, expOut.length);
        if (mismatch >= 0) {
            throw new RuntimeException("Incorrect result (block size " + blockSize
                    + "): first mismatch at byte " + mismatch);
        }
    }

    /*
     * ChaCha20 Known Answer Tests to ensure that the implementation is correct.
     * The vector below appears to be RFC 8439 Appendix A.2, Test Vector #2 (confirm).
     * The 375-byte input spans several blocks for every supported vector shape.
     */
    private static void runKAT(ChaChaVector cc20) {
        runKAT(cc20,
            "0000000000000000000000000000000000000000000000000000000000000001",
            "000000000000000000000002",
            1,
            "416e79207375626d697373696f6e20746f20746865204945544620696e74656e" +
            "6465642062792074686520436f6e7472696275746f7220666f72207075626c69" +
            "636174696f6e20617320616c6c206f722070617274206f6620616e2049455446" +
            "20496e7465726e65742d4472616674206f722052464320616e6420616e792073" +
            "746174656d656e74206d6164652077697468696e2074686520636f6e74657874" +
            "206f6620616e204945544620616374697669747920697320636f6e7369646572" +
            "656420616e20224945544620436f6e747269627574696f6e222e205375636820" +
            "73746174656d656e747320696e636c756465206f72616c2073746174656d656e" +
            "747320696e20494554462073657373696f6e732c2061732077656c6c20617320" +
            "7772697474656e20616e6420656c656374726f6e696320636f6d6d756e696361" +
            "74696f6e73206d61646520617420616e792074696d65206f7220706c6163652c" +
            "207768696368206172652061646472657373656420746f",
            "a3fbf07df3fa2fde4f376ca23e82737041605d9f4f4f57bd8cff2c1d4b7955ec" +
            "2a97948bd3722915c8f3d337f7d370050e9e96d647b7c39f56e031ca5eb6250d" +
            "4042e02785ececfa4b4bb5e8ead0440e20b6e8db09d881a7c6132f420e527950" +
            "42bdfa7773d8a9051447b3291ce1411c680465552aa6c405b7764d5e87bea85a" +
            "d00f8449ed8f72d0d662ab052691ca66424bc86d2df80ea41f43abf937d3259d" +
            "c4b2d0dfb48a6c9139ddd7f76966e928e635553ba76c5c879d7b35d49eb2e62b" +
            "0871cdac638939e25e8a1e0ef9d5280fa8ca328b351c3c765989cbcf3daa8b6c" +
            "cc3aaf9f3979c92b3720fc88dc95ed84a1be059c6499b9fda236e7e818b04b0b" +
            "c39c1e876b193bfe5569753f88128cc08aaa9b63d1a16f80ef2554d7189c411f" +
            "5869ca52c5b83fa36ff216b9c1d30062bebcfd2dc5bce0911934fda79a86f6e6" +
            "98ced759c3ff9b6477338f3da4f9cd8514ea9982ccafb341b2384dd902f3d1ab" +
            "7ac61dd29c6f21ba5b862f3730e37cfdc4fd806c22f221"
        );
    }
}