1 /*
2 * Copyright (c) 2020, 2021, Microsoft Corporation. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation.
8 *
9 * This code is distributed in the hope that it will be useful, but WITHOUT
10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
12 * version 2 for more details (a copy is included in the LICENSE file that
13 * accompanied this code).
14 *
15 * You should have received a copy of the GNU General Public License version
16 * 2 along with this work; if not, write to the Free Software Foundation,
17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18 *
19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20 * or visit www.oracle.com if you need additional information or have any
21 * questions.
22 */
23
24 package org.openjdk.bench.jdk.incubator.vector.utf8;
25
26 import java.util.HashMap;
27 import java.util.Random;
28 import java.util.concurrent.TimeUnit;
29 import java.nio.Buffer;
30 import java.nio.ByteBuffer;
31 import java.nio.CharBuffer;
32 import java.nio.charset.Charset;
33 import java.nio.charset.CoderResult;
34
35 import org.openjdk.jmh.annotations.*;
36 import org.openjdk.jmh.infra.Blackhole;
37
38 import jdk.incubator.vector.*;
39
40 @BenchmarkMode(Mode.Throughput)
41 @OutputTimeUnit(TimeUnit.SECONDS)
42 @State(Scope.Thread)
43 @Fork(value = 1, jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
44 @Warmup(iterations = 5, time = 3)
45 @Measurement(iterations = 8, time = 2)
46 public class DecodeBench {
47
48 @Param({"32768", "8388608"})
49 private int dataSize;
50
51 @Param({"1", "2", "3", "4"})
52 private int maxBytes;
53
54 private ByteBuffer src;
55 private CharBuffer dst;
56 private String in;
57 private String out;
58
59 private static final VectorSpecies<Byte> B128 = ByteVector.SPECIES_128;
60 private static final VectorSpecies<Short> S128 = ShortVector.SPECIES_128;
61 private static final VectorSpecies<Short> S256 = ShortVector.SPECIES_256;
62
63 private static final HashMap<Long, DecoderLutEntry> lutTable = new HashMap<Long, DecoderLutEntry>();
64
65 private static class DecoderLutEntry {
66 public final VectorShuffle<Byte> shufAB; // shuffling mask to get lower two bytes of symbols
67 public final VectorShuffle<Byte> shufC; // shuffling mask to get third bytes of symbols
68 public final byte srcStep; // number of bytes processed in input buffer
69 public final byte dstStep; // number of symbols produced in output buffer (doubled)
70 public final Vector<Byte> headerMask; // mask of "111..10" bits required in each byte
71 public final Vector<Short> zeroBits;
72
73 public DecoderLutEntry(VectorShuffle<Byte> _shufAB, VectorShuffle<Byte> _shufC,
74 byte _srcStep, byte _dstStep,
75 Vector<Byte> _headerMask, Vector<Short> _zeroBits) {
76 shufAB = _shufAB;
77 shufC = _shufC;
78 srcStep = _srcStep;
79 dstStep = _dstStep;
80 headerMask = _headerMask;
81 zeroBits = _zeroBits;
82 }
83
84 // @Override
85 // public String toString() {
86 // return String.format("shufAB = %s, shufC = %s, srcStep = %d, dstStep = %d, headerMask = %s, zeroBits = %s",
87 // arrayToString(shufAB), arrayToString(shufC), srcStep, dstStep, arrayToString(headerMask), arrayToString(zeroBits));
88 // }
89 }
90
91 @Setup(Level.Trial)
92 public void setupLutTable() {
93 int[] sizes = new int[32];
94 computeLutRecursive(sizes, 0, 0); //10609 entries total
95
96 // for (var entry : lutTable.entrySet()) {
97 // System.out.println("" + entry.getKey() + " -> " + entry.getValue());
98 // }
99 }
100
101 static void computeLutRecursive(int[] sizes, int num, int total) {
102 if (total >= 16) {
103 computeLutEntry(sizes, num);
104 return;
105 }
106 for (int size = 1; size <= 3; size++) {
107 sizes[num] = size;
108 computeLutRecursive(sizes, num + 1, total + size);
109 }
110 }
111
112 static void computeLutEntry(int[] sizes, int num) {
113 //find maximal number of chars to decode
114 int cnt = num - 1;
115 int preSum = 0;
116 for (int i = 0; i < cnt; i++)
117 preSum += sizes[i];
118 assert preSum < 16;
119 // Note: generally, we can process a char only if the next byte is within XMM register
120 // However, if the last char takes 3 bytes and fits the register tightly, we can take it too
121 if (preSum == 13 && preSum + sizes[cnt] == 16)
122 preSum += sizes[cnt++];
123 //still cannot process more that 8 chars per register
124 while (cnt > 8)
125 preSum -= sizes[--cnt];
126
127 //generate bitmask
128 long mask = 0;
129 for (int i = 0, pos = 0; i < num; i++) {
130 for (int j = 0; j < sizes[i]; j++, pos++) {
131 // The first byte is not represented in the mask
132 if (j > 0) {
133 mask |= 1 << pos;
134 }
135 }
136 }
137 assert mask <= 0xFFFF;
138
139 //generate shuffle masks
140 byte[] shufAB = new byte[16];
141 byte[] shufC = new byte[16];
142 for (int i = 0; i < 16; i++)
143 shufAB[i] = shufC[i] = (byte)0xFF;
144 for (int i = 0, pos = 0; i < cnt; i++) {
145 int sz = sizes[i];
146 for (int j = sz-1; j >= 0; j--, pos++) {
147 if (j < 2)
148 shufAB[2 * i + j] = (byte)pos;
149 else
150 shufC[2 * i] = (byte)pos;
151 }
152 }
153
154 //generate header masks for validation
155 byte[] headerMask = new byte[16];
156 for (int i = 0, pos = 0; i < cnt; i++) {
157 int sz = sizes[i];
158 for (int j = 0; j < sz; j++, pos++) {
159 int bits;
160 if (j > 0) bits = 2;
161 else if (sz == 1) bits = 1;
162 else if (sz == 2) bits = 3;
163 else /*sz == 3*/ bits = 4;
164 headerMask[pos] = (byte)-(1 << (8 - bits));
165 }
166 }
167
168 //generate min symbols values for validation
169 short[] zeroBits = new short[8];
170 for (int i = 0; i < 8; i++) {
171 int sz = i < cnt ? sizes[i] : 1;
172 if (sz == 1) zeroBits[i] = (short)(0xFF80);
173 else if (sz == 2) zeroBits[i] = (short)(0xF800);
174 else /*sz == 3*/ zeroBits[i] = (short)(0x0000);
175 }
176
177 //store info into the lookup table
178 lutTable.put(mask, new DecoderLutEntry(ByteVector.fromArray(B128, shufAB, 0).toShuffle(),
179 ByteVector.fromArray(B128, shufC, 0).toShuffle(),
180 (byte)preSum, (byte)cnt,
181 ByteVector.fromArray(B128, headerMask, 0),
182 ShortVector.fromArray(S128, zeroBits, 0)));
183 }
184
185 @Setup(Level.Trial)
186 public void setup() {
187 in = randomString(dataSize, maxBytes);
188 src = ByteBuffer.wrap(in.getBytes());
189 dst = CharBuffer.allocate(in.length());
190 }
191
192 @Setup(Level.Invocation)
193 public void setupInvocation() {
194 src.clear();
195 dst.clear();
196 }
197
198 @TearDown(Level.Invocation)
199 public void tearDownInvocation() {
200 out = new String(dst.array());
201 if (!in.equals(out)) {
202 System.out.println("in = (" + in.length() + ") \"" + arrayToString(in.getBytes()) + "\"");
203 System.out.println("out = (" + out.length() + ") \"" + arrayToString(out.getBytes()) + "\"");
204 throw new RuntimeException("Incorrect result");
205 }
206 }
207
208 private static final Random RANDOM = new Random(0);
209 private static int randomInt(int min /* inclusive */, int max /* inclusive */) {
210 return RANDOM.nextInt(max - min + 1) + min;
211 }
212 private static String randomString(int dataSize, int maxBytes) {
213 ByteBuffer buf = ByteBuffer.allocate(dataSize);
214 for (int i = 0, size = randomInt(1, maxBytes); i + size - 1 < dataSize; i += size, size = randomInt(1, maxBytes)) {
215 int b1, b2, b3, b4;
216 switch (size) {
217 case 1: {
218 b1 = randomInt(0x00, 0x7F);
219 buf.put(i + 0, (byte)((0b0 << (8 - 1)) | b1));
220 break;
221 }
222 case 2: {
223 b1 = randomInt(0xC2, 0xDF);
224 b2 = randomInt(0x80, 0xBF);
225 buf.put(i + 0, (byte)((0b110 << (8 - 3)) | b1));
226 buf.put(i + 1, (byte)((0b10 << (8 - 2)) | b2));
227 break;
228 }
229 case 3: {
230 b1 = randomInt(0xE0, 0xEF);
231 switch (b1) {
232 case 0xE0:
233 b2 = randomInt(0xA0, 0xBF);
234 b3 = randomInt(0x80, 0xBF);
235 break;
236 default:
237 b2 = randomInt(0x80, 0xBF);
238 b3 = randomInt(0x80, 0xBF);
239 break;
240 }
241 buf.put(i + 0, (byte)((0b1110 << (8 - 4)) | b1));
242 buf.put(i + 1, (byte)((0b10 << (8 - 2)) | b2));
243 buf.put(i + 2, (byte)((0b10 << (8 - 2)) | b3));
244 break;
245 }
246 case 4: {
247 b1 = randomInt(0xF0, 0xF4);
248 switch (b1) {
249 case 0xF0:
250 b2 = randomInt(0x90, 0xBF);
251 b3 = randomInt(0x80, 0xBF);
252 b4 = randomInt(0x80, 0xBF);
253 break;
254 case 0xF4:
255 b2 = randomInt(0x80, 0x8F);
256 b3 = randomInt(0x80, 0xBF);
257 b4 = randomInt(0x80, 0xBF);
258 break;
259 default:
260 b2 = randomInt(0x80, 0xBF);
261 b3 = randomInt(0x80, 0xBF);
262 b4 = randomInt(0x80, 0xBF);
263 break;
264 }
265 buf.put(i + 0, (byte)((0b11110 << (8 - 5)) | b1));
266 buf.put(i + 1, (byte)((0b10 << (8 - 2)) | b2));
267 buf.put(i + 2, (byte)((0b10 << (8 - 2)) | b3));
268 buf.put(i + 3, (byte)((0b10 << (8 - 2)) | b4));
269 break;
270 }
271 default:
272 throw new RuntimeException("not supported");
273 }
274 }
275 return new String(buf.array(), Charset.forName("UTF-8"));
276 }
277
278 private static String arrayToString(byte[] array) {
279 StringBuilder sb = new StringBuilder();
280 sb.append("[");
281 for (int i = 0; i < array.length; ++i) {
282 if (i != 0) sb.append(",");
283 sb.append(String.format("%x", (byte)array[i]));
284 }
285 sb.append("]");
286 return sb.toString();
287 }
288
289 @Benchmark
290 public void decodeScalar() {
291 decodeArrayLoop(src, dst);
292 }
293
294 @Benchmark
295 public void decodeVector() {
296 decodeArrayVectorized(src, dst);
297 decodeArrayLoop(src, dst);
298 }
299
300 @Benchmark
301 public void decodeVectorASCII() {
302 decodeArrayVectorizedASCII(src, dst);
303 decodeArrayLoop(src, dst);
304 }
305
306 private static void decodeArrayVectorized(ByteBuffer src, CharBuffer dst) {
307 // Algorithm is largely inspired from https://dirtyhandscoding.github.io/posts/utf8lut-vectorized-utf-8-converter-introduction.html
308
309 byte[] sa = src.array();
310 int sp = src.arrayOffset() + src.position();
311 int sl = src.arrayOffset() + src.limit();
312
313 char[] da = dst.array();
314 int dp = dst.arrayOffset() + dst.position();
315 int dl = dst.arrayOffset() + dst.limit();
316
317 // Vectorized loop
318 while (sp + B128.length() < sl && dp + S128.length() < dl) {
319 var bytes = ByteVector.fromArray(B128, sa, sp);
320
321 /* Decode */
322
323 var continuationByteMask = bytes.lanewise(VectorOperators.AND, (byte)0xC0).compare(VectorOperators.EQ, (byte)0x80);
324 final DecoderLutEntry lookup = lutTable.get(continuationByteMask.toLong());
325 if (lookup == null) {
326 break;
327 }
328 // Shuffle the 1st and 2nd bytes
329 var Rab = bytes.rearrange(lookup.shufAB, lookup.shufAB.toVector().compare(VectorOperators.NE, -1)).reinterpretAsShorts();
330 // Shuffle the 3rd byte
331 var Rc = bytes.rearrange(lookup.shufC, lookup.shufC.toVector().compare(VectorOperators.NE, -1)).reinterpretAsShorts();
332 // Extract the bits from each byte
333 var sum = Rab.lanewise(VectorOperators.AND, (short)0x007F)
334 .add(Rab.lanewise(VectorOperators.AND, (short)0x3F00).lanewise(VectorOperators.LSHR, 2))
335 .add(Rc.lanewise(VectorOperators.LSHL, 12));
336
337 /* Validate */
338
339 var zeroBits = lookup.zeroBits;
340 if (sum.lanewise(VectorOperators.AND, zeroBits).compare(VectorOperators.NE, 0).anyTrue()) {
341 break;
342 }
343 // Check for surrogate code point
344 if (sum.lanewise(VectorOperators.SUB, (short)0x6000).compare(VectorOperators.GT, 0x77FF).anyTrue()) {
345 break;
346 }
347 var headerMask = lookup.headerMask;
348 if (bytes.lanewise(VectorOperators.AND, headerMask).compare(VectorOperators.NE, headerMask.lanewise(VectorOperators.LSHL, 1)).anyTrue()) {
349 break;
350 }
351
352 /* Advance */
353
354 ((ShortVector)sum).intoCharArray(da, dp);
355 sp += lookup.srcStep;
356 dp += lookup.dstStep;
357 }
358
359 updatePositions(src, sp, dst, dp);
360 }
361
362 private static void decodeArrayVectorizedASCII(ByteBuffer src, CharBuffer dst) {
363 byte[] sa = src.array();
364 int sp = src.arrayOffset() + src.position();
365 int sl = src.arrayOffset() + src.limit();
366
367 char[] da = dst.array();
368 int dp = dst.arrayOffset() + dst.position();
369 int dl = dst.arrayOffset() + dst.limit();
370
371 // Vectorized loop
372 for (; sp <= sl - B128.length() && dp <= dl - S256.length(); sp += B128.length(), dp += S256.length()) {
373 var bytes = ByteVector.fromArray(B128, sa, sp);
374
375 if (bytes.compare(VectorOperators.LT, (byte) 0x00).anyTrue())
376 break;
377
378 ((ShortVector) bytes.convertShape(VectorOperators.B2S, S256, 0)).intoCharArray(da, dp);
379 }
380
381 updatePositions(src, sp, dst, dp);
382 }
383
384 private static CoderResult decodeArrayLoop(ByteBuffer src, CharBuffer dst) {
385 // This method is optimized for ASCII input.
386 byte[] sa = src.array();
387 int sp = src.arrayOffset() + src.position();
388 int sl = src.arrayOffset() + src.limit();
389
390 char[] da = dst.array();
391 int dp = dst.arrayOffset() + dst.position();
392 int dl = dst.arrayOffset() + dst.limit();
393 int dlASCII = dp + Math.min(sl - sp, dl - dp);
394
395 // ASCII only loop
396 while (dp < dlASCII && sa[sp] >= 0)
397 da[dp++] = (char) sa[sp++];
398 while (sp < sl) {
399 int b1 = sa[sp];
400 if (b1 >= 0) {
401 // 1 byte, 7 bits: 0xxxxxxx
402 if (dp >= dl)
403 return xflow(src, sp, sl, dst, dp, 1);
404 da[dp++] = (char) b1;
405 sp++;
406 } else if ((b1 >> 5) == -2 && (b1 & 0x1e) != 0) {
407 // 2 bytes, 11 bits: 110xxxxx 10xxxxxx
408 // [C2..DF] [80..BF]
409 if (sl - sp < 2 || dp >= dl)
410 return xflow(src, sp, sl, dst, dp, 2);
411 int b2 = sa[sp + 1];
412 // Now we check the first byte of 2-byte sequence as
413 // if ((b1 >> 5) == -2 && (b1 & 0x1e) != 0)
414 // no longer need to check b1 against c1 & c0 for
415 // malformed as we did in previous version
416 // (b1 & 0x1e) == 0x0 || (b2 & 0xc0) != 0x80;
417 // only need to check the second byte b2.
418 if (isNotContinuation(b2))
419 return malformedForLength(src, sp, dst, dp, 1);
420 da[dp++] = (char) (((b1 << 6) ^ b2)
421 ^
422 (((byte) 0xC0 << 6) ^
423 ((byte) 0x80 << 0)));
424 sp += 2;
425 } else if ((b1 >> 4) == -2) {
426 // 3 bytes, 16 bits: 1110xxxx 10xxxxxx 10xxxxxx
427 int srcRemaining = sl - sp;
428 if (srcRemaining < 3 || dp >= dl) {
429 if (srcRemaining > 1 && isMalformed3_2(b1, sa[sp + 1]))
430 return malformedForLength(src, sp, dst, dp, 1);
431 return xflow(src, sp, sl, dst, dp, 3);
432 }
433 int b2 = sa[sp + 1];
434 int b3 = sa[sp + 2];
435 if (isMalformed3(b1, b2, b3))
436 return malformed(src, sp, dst, dp, 3);
437 char c = (char)
438 ((b1 << 12) ^
439 (b2 << 6) ^
440 (b3 ^
441 (((byte) 0xE0 << 12) ^
442 ((byte) 0x80 << 6) ^
443 ((byte) 0x80 << 0))));
444 if (Character.isSurrogate(c))
445 return malformedForLength(src, sp, dst, dp, 3);
446 da[dp++] = c;
447 sp += 3;
448 } else if ((b1 >> 3) == -2) {
449 // 4 bytes, 21 bits: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
450 int srcRemaining = sl - sp;
451 if (srcRemaining < 4 || dl - dp < 2) {
452 b1 &= 0xff;
453 if (b1 > 0xf4 ||
454 srcRemaining > 1 && isMalformed4_2(b1, sa[sp + 1] & 0xff))
455 return malformedForLength(src, sp, dst, dp, 1);
456 if (srcRemaining > 2 && isMalformed4_3(sa[sp + 2]))
457 return malformedForLength(src, sp, dst, dp, 2);
458 return xflow(src, sp, sl, dst, dp, 4);
459 }
460 int b2 = sa[sp + 1];
461 int b3 = sa[sp + 2];
462 int b4 = sa[sp + 3];
463 int uc = ((b1 << 18) ^
464 (b2 << 12) ^
465 (b3 << 6) ^
466 (b4 ^
467 (((byte) 0xF0 << 18) ^
468 ((byte) 0x80 << 12) ^
469 ((byte) 0x80 << 6) ^
470 ((byte) 0x80 << 0))));
471 if (isMalformed4(b2, b3, b4) ||
472 // shortest form check
473 !Character.isSupplementaryCodePoint(uc)) {
474 return malformed(src, sp, dst, dp, 4);
475 }
476 da[dp++] = Character.highSurrogate(uc);
477 da[dp++] = Character.lowSurrogate(uc);
478 sp += 4;
479 } else
480 return malformed(src, sp, dst, dp, 1);
481 }
482 return xflow(src, sp, sl, dst, dp, 0);
483 }
484
485 private static CoderResult xflow(Buffer src, int sp, int sl,
486 Buffer dst, int dp, int nb) {
487 updatePositions(src, sp, dst, dp);
488 return (nb == 0 || sl - sp < nb)
489 ? CoderResult.UNDERFLOW : CoderResult.OVERFLOW;
490 }
491
492 private static CoderResult malformedForLength(ByteBuffer src,
493 int sp,
494 CharBuffer dst,
495 int dp,
496 int malformedNB)
497 {
498 updatePositions(src, sp, dst, dp);
499 return CoderResult.malformedForLength(malformedNB);
500 }
501
502 private static CoderResult malformed(ByteBuffer src, int sp,
503 CharBuffer dst, int dp,
504 int nb)
505 {
506 src.position(sp - src.arrayOffset());
507 CoderResult cr = malformedN(src, sp, nb);
508 updatePositions(src, sp, dst, dp);
509 return cr;
510 }
511
512 private static CoderResult malformedN(ByteBuffer src, int sp,
513 int nb) {
514 switch (nb) {
515 case 1:
516 case 2: // always 1
517 return CoderResult.malformedForLength(1);
518 case 3:
519 int b1 = src.get();
520 int b2 = src.get(); // no need to lookup b3
521 return CoderResult.malformedForLength(
522 ((b1 == (byte)0xe0 && (b2 & 0xe0) == 0x80) ||
523 isNotContinuation(b2)) ? 1 : 2);
524 case 4: // we don't care the speed here
525 b1 = src.get() & 0xff;
526 b2 = src.get() & 0xff;
527 if (b1 > 0xf4 ||
528 (b1 == 0xf0 && (b2 < 0x90 || b2 > 0xbf)) ||
529 (b1 == 0xf4 && (b2 & 0xf0) != 0x80) ||
530 isNotContinuation(b2))
531 return CoderResult.malformedForLength(1);
532 if (isNotContinuation(src.get()))
533 return CoderResult.malformedForLength(2);
534 return CoderResult.malformedForLength(3);
535 default:
536 assert false;
537 return null;
538 }
539 }
540
541 private static boolean isNotContinuation(int b) {
542 return (b & 0xc0) != 0x80;
543 }
544
545 // [E0] [A0..BF] [80..BF]
546 // [E1..EF] [80..BF] [80..BF]
547 private static boolean isMalformed3(int b1, int b2, int b3) {
548 return (b1 == (byte)0xe0 && (b2 & 0xe0) == 0x80) ||
549 (b2 & 0xc0) != 0x80 || (b3 & 0xc0) != 0x80;
550 }
551
552 // only used when there is only one byte left in src buffer
553 private static boolean isMalformed3_2(int b1, int b2) {
554 return (b1 == (byte)0xe0 && (b2 & 0xe0) == 0x80) ||
555 (b2 & 0xc0) != 0x80;
556 }
557
558 // [F0] [90..BF] [80..BF] [80..BF]
559 // [F1..F3] [80..BF] [80..BF] [80..BF]
560 // [F4] [80..8F] [80..BF] [80..BF]
561 // only check 80-be range here, the [0xf0,0x80...] and [0xf4,0x90-...]
562 // will be checked by Character.isSupplementaryCodePoint(uc)
563 private static boolean isMalformed4(int b2, int b3, int b4) {
564 return (b2 & 0xc0) != 0x80 || (b3 & 0xc0) != 0x80 ||
565 (b4 & 0xc0) != 0x80;
566 }
567
568 // only used when there is less than 4 bytes left in src buffer.
569 // both b1 and b2 should be "& 0xff" before passed in.
570 private static boolean isMalformed4_2(int b1, int b2) {
571 return (b1 == 0xf0 && (b2 < 0x90 || b2 > 0xbf)) ||
572 (b1 == 0xf4 && (b2 & 0xf0) != 0x80) ||
573 (b2 & 0xc0) != 0x80;
574 }
575
576 // tests if b1 and b2 are malformed as the first 2 bytes of a
577 // legal`4-byte utf-8 byte sequence.
578 // only used when there is less than 4 bytes left in src buffer,
579 // after isMalformed4_2 has been invoked.
580 private static boolean isMalformed4_3(int b3) {
581 return (b3 & 0xc0) != 0x80;
582 }
583
584 private static void updatePositions(Buffer src, int sp,
585 Buffer dst, int dp) {
586 src.position(sp - src.arrayOffset());
587 dst.position(dp - dst.arrayOffset());
588 }
589 }