/*
 * Copyright (c) 2020, 2021, Microsoft Corporation. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
 */

package org.openjdk.bench.jdk.incubator.vector.utf8;

import java.util.HashMap;
import java.util.Random;
import java.util.concurrent.TimeUnit;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.CharBuffer;
import java.nio.charset.Charset;
import java.nio.charset.CoderResult;
import java.nio.charset.StandardCharsets;

import org.openjdk.jmh.annotations.*;
import org.openjdk.jmh.infra.Blackhole;

import jdk.incubator.vector.*;

/**
 * JMH benchmark comparing three ways of decoding a UTF-8 byte array into a
 * char array: a scalar loop, a lookup-table driven 128-bit vectorized loop,
 * and a vectorized loop that only handles ASCII. Both vectorized variants
 * fall back to the scalar loop for whatever input they leave unprocessed.
 */
@BenchmarkMode(Mode.Throughput)
@OutputTimeUnit(TimeUnit.SECONDS)
@State(Scope.Thread)
@Fork(value = 1, jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
@Warmup(iterations = 5, time = 3)
@Measurement(iterations = 8, time = 2)
public class DecodeBench {

    @Param({"32768", "8388608"})
    private int dataSize;

    @Param({"1", "2", "3", "4"})
    private int maxBytes;

    private ByteBuffer src;
    private CharBuffer dst;
    private String in;
    private String out;

    private static final VectorSpecies<Byte> B128 = ByteVector.SPECIES_128;
    private static final VectorSpecies<Short> S128 = ShortVector.SPECIES_128;
    private static final VectorSpecies<Short> S256 = ShortVector.SPECIES_256;

    // Keyed by the 16-bit "is continuation byte" mask of a 16-byte input window.
    private static final HashMap<Long, DecoderLutEntry> lutTable = new HashMap<>();

    /** Precomputed decoding recipe for one layout of characters inside a 16-byte window. */
    private static class DecoderLutEntry {
        public final VectorShuffle<Byte> shufAB;    // shuffling mask to get lower two bytes of symbols
        public final VectorShuffle<Byte> shufC;     // shuffling mask to get third bytes of symbols
        public final byte srcStep;                  // number of bytes processed in input buffer
        public final byte dstStep;                  // number of chars produced in output buffer
        public final Vector<Byte> headerMask;       // mask of "111..10" bits required in each byte
        public final Vector<Short> zeroBits;        // bits that must be zero in each decoded char

        public DecoderLutEntry(VectorShuffle<Byte> _shufAB, VectorShuffle<Byte> _shufC,
                               byte _srcStep, byte _dstStep,
                               Vector<Byte> _headerMask, Vector<Short> _zeroBits) {
            shufAB = _shufAB;
            shufC = _shufC;
            srcStep = _srcStep;
            dstStep = _dstStep;
            headerMask = _headerMask;
            zeroBits = _zeroBits;
        }
    }

    /**
     * Fills the lookup table by enumerating every sequence of 1-, 2- and
     * 3-byte characters that covers a 16-byte window.
     */
    @Setup(Level.Trial)
    public void setupLutTable() {
        int[] sizes = new int[32];
        computeLutRecursive(sizes, 0, 0); //10609 entries total
    }

    /**
     * Enumerates character-size sequences. Recursion stops as soon as the
     * accumulated byte count reaches or exceeds 16, at which point one
     * table entry is generated for the sequence.
     *
     * @param sizes scratch array holding the size (1..3) of each character so far
     * @param num   number of characters currently stored in {@code sizes}
     * @param total sum of the first {@code num} sizes
     */
    static void computeLutRecursive(int[] sizes, int num, int total) {
        if (total >= 16) {
            computeLutEntry(sizes, num);
            return;
        }
        for (int size = 1; size <= 3; size++) {
            sizes[num] = size;
            computeLutRecursive(sizes, num + 1, total + size);
        }
    }

    /**
     * Builds and stores the table entry for one character-size sequence.
     *
     * @param sizes sizes (1..3) of the characters in the sequence
     * @param num   number of characters in the sequence; their sizes sum to 16..18
     */
    static void computeLutEntry(int[] sizes, int num) {
        //find maximal number of chars to decode
        int cnt = num - 1;
        int preSum = 0;
        for (int i = 0; i < cnt; i++)
            preSum += sizes[i];
        assert preSum < 16;
        // Note: generally, we can process a char only if the next byte is within XMM register
        // However, if the last char takes 3 bytes and fits the register tightly, we can take it too
        if (preSum == 13 && preSum + sizes[cnt] == 16)
            preSum += sizes[cnt++];
        //still cannot process more that 8 chars per register
        while (cnt > 8)
            preSum -= sizes[--cnt];

        //generate bitmask
        long mask = 0;
        for (int i = 0, pos = 0; i < num; i++) {
            for (int j = 0; j < sizes[i]; j++, pos++) {
                // The first byte is not represented in the mask.
                // Only the 16 bytes of the window take part in the key: the lookup
                // uses a 16-lane mask, so bits for bytes past the window would create
                // keys that can never match. Sequences that differ only in a last char
                // starting at position 14 or 15 map to the same key, and produce
                // identical entries because that char is never decoded.
                if (j > 0 && pos < 16) {
                    mask |= 1L << pos;
                }
            }
        }
        assert mask <= 0xFFFF;

        //generate shuffle masks
        byte[] shufAB = new byte[16];
        byte[] shufC = new byte[16];
        for (int i = 0; i < 16; i++)
            shufAB[i] = shufC[i] = (byte)0xFF;
        for (int i = 0, pos = 0; i < cnt; i++) {
            int sz = sizes[i];
            // j counts down so that j == 0 is the last byte of the char,
            // j == 1 the one before it and j == 2 the lead byte of a 3-byte char
            for (int j = sz-1; j >= 0; j--, pos++) {
                if (j < 2)
                    shufAB[2 * i + j] = (byte)pos;
                else
                    shufC[2 * i] = (byte)pos;
            }
        }

        //generate header masks for validation
        byte[] headerMask = new byte[16];
        for (int i = 0, pos = 0; i < cnt; i++) {
            int sz = sizes[i];
            for (int j = 0; j < sz; j++, pos++) {
                // number of leading bits that form the fixed header of this byte
                int bits;
                if (j > 0) bits = 2;
                else if (sz == 1) bits = 1;
                else if (sz == 2) bits = 3;
                else /*sz == 3*/ bits = 4;
                headerMask[pos] = (byte)-(1 << (8 - bits));
            }
        }

        //generate per-char masks of bits that must be zero, for validation
        // NOTE(review): this only bounds the decoded value from above; it does not
        // appear to reject overlong encodings (e.g. C0 80) - confirm whether intended.
        short[] zeroBits = new short[8];
        for (int i = 0; i < 8; i++) {
            int sz = i < cnt ? sizes[i] : 1;
            if (sz == 1) zeroBits[i] = (short)(0xFF80);
            else if (sz == 2) zeroBits[i] = (short)(0xF800);
            else /*sz == 3*/ zeroBits[i] = (short)(0x0000);
        }

        //store info into the lookup table
        lutTable.put(mask, new DecoderLutEntry(ByteVector.fromArray(B128, shufAB, 0).toShuffle(),
                                               ByteVector.fromArray(B128, shufC, 0).toShuffle(),
                                               (byte)preSum, (byte)cnt,
                                               ByteVector.fromArray(B128, headerMask, 0),
                                               ShortVector.fromArray(S128, zeroBits, 0)));
    }

    /** Generates the random input string and wraps its UTF-8 bytes and an output buffer. */
    @Setup(Level.Trial)
    public void setup() {
        in = randomString(dataSize, maxBytes);
        // Encode explicitly as UTF-8: the platform default charset is not
        // guaranteed to be UTF-8 and the decoders below only understand UTF-8.
        src = ByteBuffer.wrap(in.getBytes(StandardCharsets.UTF_8));
        dst = CharBuffer.allocate(in.length());
    }

    /** Rewinds both buffers before each benchmark invocation. */
    @Setup(Level.Invocation)
    public void setupInvocation() {
        src.clear();
        dst.clear();
    }

    /** Verifies after each invocation that the decoded chars equal the original string. */
    @TearDown(Level.Invocation)
    public void tearDownInvocation() {
        out = new String(dst.array());
        if (!in.equals(out)) {
            System.out.println("in  = (" + in.length() + ") \"" + arrayToString(in.getBytes(StandardCharsets.UTF_8)) + "\"");
            System.out.println("out = (" + out.length() + ") \"" + arrayToString(out.getBytes(StandardCharsets.UTF_8)) + "\"");
            throw new RuntimeException("Incorrect result");
        }
    }

    // Fixed seed so that every run decodes the same data.
    private static final Random RANDOM = new Random(0);

    private static int randomInt(int min /* inclusive */, int max /* inclusive */) {
        return RANDOM.nextInt(max - min + 1) + min;
    }

    /**
     * Builds a random string by filling a {@code dataSize}-byte buffer with
     * byte sequences of 1 to {@code maxBytes} bytes each and decoding the
     * buffer as UTF-8. Trailing bytes that no sequence fits into stay zero.
     *
     * @param dataSize size of the byte buffer to fill
     * @param maxBytes maximum sequence length to generate, 1..4
     * @return the buffer content decoded as UTF-8
     */
    private static String randomString(int dataSize, int maxBytes) {
        ByteBuffer buf = ByteBuffer.allocate(dataSize);
        for (int i = 0, size = randomInt(1, maxBytes); i + size - 1 < dataSize; i += size, size = randomInt(1, maxBytes)) {
            int b1, b2, b3, b4;
            switch (size) {
            case 1: {
                b1 = randomInt(0x00, 0x7F);
                buf.put(i + 0, (byte)((0b0 << (8 - 1)) | b1));
                break;
            }
            case 2: {
                b1 = randomInt(0xC2, 0xDF);
                b2 = randomInt(0x80, 0xBF);
                buf.put(i + 0, (byte)((0b110 << (8 - 3)) | b1));
                buf.put(i + 1, (byte)((0b10  << (8 - 2)) | b2));
                break;
            }
            case 3: {
                b1 = randomInt(0xE0, 0xEF);
                switch (b1) {
                case 0xE0:
                    b2 = randomInt(0xA0, 0xBF);
                    b3 = randomInt(0x80, 0xBF);
                    break;
                default:
                    b2 = randomInt(0x80, 0xBF);
                    b3 = randomInt(0x80, 0xBF);
                    break;
                }
                buf.put(i + 0, (byte)((0b1110 << (8 - 4)) | b1));
                buf.put(i + 1, (byte)((0b10   << (8 - 2)) | b2));
                buf.put(i + 2, (byte)((0b10   << (8 - 2)) | b3));
                break;
            }
            case 4: {
                b1 = randomInt(0xF0, 0xF4);
                switch (b1) {
                case 0xF0:
                    b2 = randomInt(0x90, 0xBF);
                    b3 = randomInt(0x80, 0xBF);
                    b4 = randomInt(0x80, 0xBF);
                    break;
                case 0xF4:
                    b2 = randomInt(0x80, 0x8F);
                    b3 = randomInt(0x80, 0xBF);
                    b4 = randomInt(0x80, 0xBF);
                    break;
                default:
                    b2 = randomInt(0x80, 0xBF);
                    b3 = randomInt(0x80, 0xBF);
                    b4 = randomInt(0x80, 0xBF);
                    break;
                }
                buf.put(i + 0, (byte)((0b11110 << (8 - 5)) | b1));
                buf.put(i + 1, (byte)((0b10    << (8 - 2)) | b2));
                buf.put(i + 2, (byte)((0b10    << (8 - 2)) | b3));
                buf.put(i + 3, (byte)((0b10    << (8 - 2)) | b4));
                break;
            }
            default:
                throw new RuntimeException("not supported");
            }
        }
        return new String(buf.array(), StandardCharsets.UTF_8);
    }

    /** Formats a byte array as a bracketed, comma-separated list of hex values. */
    private static String arrayToString(byte[] array) {
        StringBuilder sb = new StringBuilder();
        sb.append("[");
        for (int i = 0; i < array.length; ++i) {
            if (i != 0) sb.append(",");
            sb.append(String.format("%x", (byte)array[i]));
        }
        sb.append("]");
        return sb.toString();
    }

    @Benchmark
    public void decodeScalar() {
        decodeArrayLoop(src, dst);
    }

    @Benchmark
    public void decodeVector() {
        decodeArrayVectorized(src, dst);
        // the scalar loop finishes whatever the vectorized loop left over
        decodeArrayLoop(src, dst);
    }

    @Benchmark
    public void decodeVectorASCII() {
        decodeArrayVectorizedASCII(src, dst);
        // the scalar loop finishes whatever the vectorized loop left over
        decodeArrayLoop(src, dst);
    }

    /**
     * Decodes as much of {@code src} as possible 16 bytes at a time using the
     * lookup table, then stores the reached positions back into both buffers.
     * The loop stops at the first window it cannot handle (no table entry or
     * failed validation) or when fewer than a full vector of input or output
     * remains.
     */
    private static void decodeArrayVectorized(ByteBuffer src, CharBuffer dst) {
        // Algorithm is largely inspired from https://dirtyhandscoding.github.io/posts/utf8lut-vectorized-utf-8-converter-introduction.html

        byte[] sa = src.array();
        int sp = src.arrayOffset() + src.position();
        int sl = src.arrayOffset() + src.limit();

        char[] da = dst.array();
        int dp = dst.arrayOffset() + dst.position();
        int dl = dst.arrayOffset() + dst.limit();

        // Vectorized loop
        while (sp + B128.length() < sl && dp + S128.length() < dl) {
            var bytes = ByteVector.fromArray(B128, sa, sp);

            /* Decode */

            var continuationByteMask = bytes.lanewise(VectorOperators.AND, (byte)0xC0).compare(VectorOperators.EQ, (byte)0x80);
            final DecoderLutEntry lookup = lutTable.get(continuationByteMask.toLong());
            if (lookup == null) {
                break;
            }
            // Shuffle the 1st and 2nd bytes
            var Rab = bytes.rearrange(lookup.shufAB, lookup.shufAB.toVector().compare(VectorOperators.NE, -1)).reinterpretAsShorts();
            // Shuffle the 3rd byte
            var Rc  = bytes.rearrange(lookup.shufC, lookup.shufC.toVector().compare(VectorOperators.NE, -1)).reinterpretAsShorts();
            // Extract the bits from each byte
            var sum = Rab.lanewise(VectorOperators.AND, (short)0x007F)
                        .add(Rab.lanewise(VectorOperators.AND, (short)0x3F00).lanewise(VectorOperators.LSHR, 2))
                        .add(Rc.lanewise(VectorOperators.LSHL, 12));

            /* Validate */

            var zeroBits = lookup.zeroBits;
            if (sum.lanewise(VectorOperators.AND, zeroBits).compare(VectorOperators.NE, 0).anyTrue()) {
                break;
            }
            // Check for surrogate code point
            if (sum.lanewise(VectorOperators.SUB, (short)0x6000).compare(VectorOperators.GT, 0x77FF).anyTrue()) {
                break;
            }
            // Each byte, masked by its header mask, must equal that mask shifted left by one
            var headerMask = lookup.headerMask;
            if (bytes.lanewise(VectorOperators.AND, headerMask).compare(VectorOperators.NE, headerMask.lanewise(VectorOperators.LSHL, 1)).anyTrue()) {
                break;
            }

            /* Advance */

            ((ShortVector)sum).intoCharArray(da, dp);
            sp += lookup.srcStep;
            dp += lookup.dstStep;
        }

        updatePositions(src, sp, dst, dp);
    }

    /**
     * Decodes 16 bytes at a time for as long as every byte in the window is
     * non-negative (ASCII), then stores the reached positions back into both
     * buffers.
     */
    private static void decodeArrayVectorizedASCII(ByteBuffer src, CharBuffer dst) {
        byte[] sa = src.array();
        int sp = src.arrayOffset() + src.position();
        int sl = src.arrayOffset() + src.limit();

        char[] da = dst.array();
        int dp = dst.arrayOffset() + dst.position();
        int dl = dst.arrayOffset() + dst.limit();

        // Vectorized loop
        for (; sp <= sl - B128.length() && dp <= dl - S256.length(); sp += B128.length(), dp += S256.length()) {
            var bytes = ByteVector.fromArray(B128, sa, sp);

            if (bytes.compare(VectorOperators.LT, (byte) 0x00).anyTrue())
                break;

            ((ShortVector) bytes.convertShape(VectorOperators.B2S, S256, 0)).intoCharArray(da, dp);
        }

        updatePositions(src, sp, dst, dp);
    }

    /**
     * Scalar UTF-8 decoder working directly on the buffers' backing arrays.
     *
     * @return {@code UNDERFLOW} or {@code OVERFLOW} when decoding stops for lack
     *         of input or output, or a malformed-input result for invalid bytes;
     *         buffer positions are updated in every case
     */
    private static CoderResult decodeArrayLoop(ByteBuffer src, CharBuffer dst) {
        // This method is optimized for ASCII input.
        byte[] sa = src.array();
        int sp = src.arrayOffset() + src.position();
        int sl = src.arrayOffset() + src.limit();

        char[] da = dst.array();
        int dp = dst.arrayOffset() + dst.position();
        int dl = dst.arrayOffset() + dst.limit();
        int dlASCII = dp + Math.min(sl - sp, dl - dp);

        // ASCII only loop
        while (dp < dlASCII && sa[sp] >= 0)
            da[dp++] = (char) sa[sp++];
        while (sp < sl) {
            int b1 = sa[sp];
            if (b1 >= 0) {
                // 1 byte, 7 bits: 0xxxxxxx
                if (dp >= dl)
                    return xflow(src, sp, sl, dst, dp, 1);
                da[dp++] = (char) b1;
                sp++;
            } else if ((b1 >> 5) == -2 && (b1 & 0x1e) != 0) {
                // 2 bytes, 11 bits: 110xxxxx 10xxxxxx
                //                   [C2..DF] [80..BF]
                if (sl - sp < 2 || dp >= dl)
                    return xflow(src, sp, sl, dst, dp, 2);
                int b2 = sa[sp + 1];
                // Now we check the first byte of 2-byte sequence as
                //     if ((b1 >> 5) == -2 && (b1 & 0x1e) != 0)
                // no longer need to check b1 against c1 & c0 for
                // malformed as we did in previous version
                //   (b1 & 0x1e) == 0x0 || (b2 & 0xc0) != 0x80;
                // only need to check the second byte b2.
                if (isNotContinuation(b2))
                    return malformedForLength(src, sp, dst, dp, 1);
                da[dp++] = (char) (((b1 << 6) ^ b2)
                                   ^
                                   (((byte) 0xC0 << 6) ^
                                    ((byte) 0x80 << 0)));
                sp += 2;
            } else if ((b1 >> 4) == -2) {
                // 3 bytes, 16 bits: 1110xxxx 10xxxxxx 10xxxxxx
                int srcRemaining = sl - sp;
                if (srcRemaining < 3 || dp >= dl) {
                    if (srcRemaining > 1 && isMalformed3_2(b1, sa[sp + 1]))
                        return malformedForLength(src, sp, dst, dp, 1);
                    return xflow(src, sp, sl, dst, dp, 3);
                }
                int b2 = sa[sp + 1];
                int b3 = sa[sp + 2];
                if (isMalformed3(b1, b2, b3))
                    return malformed(src, sp, dst, dp, 3);
                char c = (char)
                    ((b1 << 12) ^
                     (b2 <<  6) ^
                     (b3 ^
                      (((byte) 0xE0 << 12) ^
                       ((byte) 0x80 <<  6) ^
                       ((byte) 0x80 <<  0))));
                if (Character.isSurrogate(c))
                    return malformedForLength(src, sp, dst, dp, 3);
                da[dp++] = c;
                sp += 3;
            } else if ((b1 >> 3) == -2) {
                // 4 bytes, 21 bits: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
                int srcRemaining = sl - sp;
                if (srcRemaining < 4 || dl - dp < 2) {
                    b1 &= 0xff;
                    if (b1 > 0xf4 ||
                        srcRemaining > 1 && isMalformed4_2(b1, sa[sp + 1] & 0xff))
                        return malformedForLength(src, sp, dst, dp, 1);
                    if (srcRemaining > 2 && isMalformed4_3(sa[sp + 2]))
                        return malformedForLength(src, sp, dst, dp, 2);
                    return xflow(src, sp, sl, dst, dp, 4);
                }
                int b2 = sa[sp + 1];
                int b3 = sa[sp + 2];
                int b4 = sa[sp + 3];
                int uc = ((b1 << 18) ^
                          (b2 << 12) ^
                          (b3 <<  6) ^
                          (b4 ^
                           (((byte) 0xF0 << 18) ^
                            ((byte) 0x80 << 12) ^
                            ((byte) 0x80 <<  6) ^
                            ((byte) 0x80 <<  0))));
                if (isMalformed4(b2, b3, b4) ||
                    // shortest form check
                    !Character.isSupplementaryCodePoint(uc)) {
                    return malformed(src, sp, dst, dp, 4);
                }
                da[dp++] = Character.highSurrogate(uc);
                da[dp++] = Character.lowSurrogate(uc);
                sp += 4;
            } else
                return malformed(src, sp, dst, dp, 1);
        }
        return xflow(src, sp, sl, dst, dp, 0);
    }

    /**
     * Stores the positions and reports UNDERFLOW when {@code nb} is zero or
     * fewer than {@code nb} source bytes remain, OVERFLOW otherwise.
     */
    private static CoderResult xflow(Buffer src, int sp, int sl,
                                     Buffer dst, int dp, int nb) {
        updatePositions(src, sp, dst, dp);
        return (nb == 0 || sl - sp < nb)
               ? CoderResult.UNDERFLOW : CoderResult.OVERFLOW;
    }

    /** Stores the positions and reports malformed input of the given length. */
    private static CoderResult malformedForLength(ByteBuffer src,
                                                  int sp,
                                                  CharBuffer dst,
                                                  int dp,
                                                  int malformedNB)
    {
        updatePositions(src, sp, dst, dp);
        return CoderResult.malformedForLength(malformedNB);
    }

    /**
     * Works out the malformed length of an {@code nb}-byte sequence starting at
     * {@code sp}, then stores the positions.
     */
    private static CoderResult malformed(ByteBuffer src, int sp,
                                         CharBuffer dst, int dp,
                                         int nb)
    {
        // malformedN reads through the buffer, so position it at the bad sequence first
        src.position(sp - src.arrayOffset());
        CoderResult cr = malformedN(src, sp, nb);
        updatePositions(src, sp, dst, dp);
        return cr;
    }

    /**
     * Determines how many leading bytes of an {@code nb}-byte sequence are
     * malformed, reading the bytes from the current position of {@code src}.
     */
    private static CoderResult malformedN(ByteBuffer src, int sp,
                                          int nb) {
        switch (nb) {
        case 1:
        case 2:                    // always 1
            return CoderResult.malformedForLength(1);
        case 3:
            int b1 = src.get();
            int b2 = src.get();    // no need to lookup b3
            return CoderResult.malformedForLength(
                ((b1 == (byte)0xe0 && (b2 & 0xe0) == 0x80) ||
                 isNotContinuation(b2)) ? 1 : 2);
        case 4:  // we don't care the speed here
            b1 = src.get() & 0xff;
            b2 = src.get() & 0xff;
            if (b1 > 0xf4 ||
                (b1 == 0xf0 && (b2 < 0x90 || b2 > 0xbf)) ||
                (b1 == 0xf4 && (b2 & 0xf0) != 0x80) ||
                isNotContinuation(b2))
                return CoderResult.malformedForLength(1);
            if (isNotContinuation(src.get()))
                return CoderResult.malformedForLength(2);
            return CoderResult.malformedForLength(3);
        default:
            assert false;
            return null;
        }
    }

    private static boolean isNotContinuation(int b) {
        return (b & 0xc0) != 0x80;
    }

    //  [E0]     [A0..BF] [80..BF]
    //  [E1..EF] [80..BF] [80..BF]
    private static boolean isMalformed3(int b1, int b2, int b3) {
        return (b1 == (byte)0xe0 && (b2 & 0xe0) == 0x80) ||
               (b2 & 0xc0) != 0x80 || (b3 & 0xc0) != 0x80;
    }

    // only used when fewer than 3 bytes are left in src buffer
    // or the dst buffer is full
    private static boolean isMalformed3_2(int b1, int b2) {
        return (b1 == (byte)0xe0 && (b2 & 0xe0) == 0x80) ||
               (b2 & 0xc0) != 0x80;
    }

    //  [F0]     [90..BF] [80..BF] [80..BF]
    //  [F1..F3] [80..BF] [80..BF] [80..BF]
    //  [F4]     [80..8F] [80..BF] [80..BF]
    //  only check 80-bf range here, the [0xf0,0x80...] and [0xf4,0x90-...]
    //  will be checked by Character.isSupplementaryCodePoint(uc)
    private static boolean isMalformed4(int b2, int b3, int b4) {
        return (b2 & 0xc0) != 0x80 || (b3 & 0xc0) != 0x80 ||
               (b4 & 0xc0) != 0x80;
    }

    // tests if b1 and b2 are malformed as the first 2 bytes of a
    // legal 4-byte utf-8 byte sequence.
    // only used when there is less than 4 bytes left in src buffer
    // or not enough room in the dst buffer.
    // both b1 and b2 should be "& 0xff" before passed in.
    private static boolean isMalformed4_2(int b1, int b2) {
        return (b1 == 0xf0 && (b2 < 0x90 || b2 > 0xbf)) ||
               (b1 == 0xf4 && (b2 & 0xf0) != 0x80) ||
               (b2 & 0xc0) != 0x80;
    }

    // tests if b3 is malformed as the third byte of a 4-byte utf-8
    // byte sequence.
    // only used when there is less than 4 bytes left in src buffer
    // or not enough room in the dst buffer, after isMalformed4_2
    // has been invoked.
    private static boolean isMalformed4_3(int b3) {
        return (b3 & 0xc0) != 0x80;
    }

    /** Converts backing-array indices back to buffer positions and stores them. */
    private static void updatePositions(Buffer src, int sp,
                                        Buffer dst, int dp) {
        src.position(sp - src.arrayOffset());
        dst.position(dp - dst.arrayOffset());
    }
}