1 /*
   2  * Copyright (c) 2012, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.  Oracle designates this
   8  * particular file as subject to the "Classpath" exception as provided
   9  * by Oracle in the LICENSE file that accompanied this code.
  10  *
  11  * This code is distributed in the hope that it will be useful, but WITHOUT
  12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  14  * version 2 for more details (a copy is included in the LICENSE file that
  15  * accompanied this code).
  16  *
  17  * You should have received a copy of the GNU General Public License version
  18  * 2 along with this work; if not, write to the Free Software Foundation,
  19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  20  *
  21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  22  * or visit www.oracle.com if you need additional information or have any
  23  * questions.
  24  */
  25 
import java.io.Serializable;
import java.lang.invoke.CallSite;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.runtime.SwitchBootstraps;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import jdk.test.lib.RandomFactory;

import org.testng.annotations.Test;

import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertThrows;
import static org.testng.Assert.fail;
  46 
  47 /**
  48  * @test
  49  * @key randomness
  50  * @library /test/lib
  51  * @build jdk.test.lib.RandomFactory
  52  * @run testng SwitchBootstrapsTest
  53  */
  54 @Test
  55 public class SwitchBootstrapsTest {
  56     private final static Set<Class<?>> BOOLEAN_TYPES = Set.of(boolean.class, Boolean.class);
  57     private final static Set<Class<?>> ALL_INT_TYPES = Set.of(int.class, short.class, byte.class, char.class,
  58                                                               Integer.class, Short.class, Byte.class, Character.class);
  59     private final static Set<Class<?>> SIGNED_NON_BYTE_TYPES = Set.of(int.class, Integer.class, short.class, Short.class);
  60     private final static Set<Class<?>> CHAR_TYPES = Set.of(char.class, Character.class);
  61     private final static Set<Class<?>> BYTE_TYPES = Set.of(byte.class, Byte.class);
  62     private final static Set<Class<?>> SIGNED_TYPES
  63             = Set.of(int.class, short.class, byte.class,
  64                      Integer.class, Short.class, Byte.class);
  65 
  66     public static final MethodHandle BSM_BOOLEAN_SWITCH;
  67     public static final MethodHandle BSM_INT_SWITCH;
  68     public static final MethodHandle BSM_LONG_SWITCH;
  69     public static final MethodHandle BSM_FLOAT_SWITCH;
  70     public static final MethodHandle BSM_DOUBLE_SWITCH;
  71     public static final MethodHandle BSM_STRING_SWITCH;
  72     public static final MethodHandle BSM_ENUM_SWITCH;
  73     public static final MethodHandle BSM_TYPE_SWITCH;
  74 
  75     private final static Random random = RandomFactory.getRandom();
  76 
  77     static {
  78         try {
  79             BSM_BOOLEAN_SWITCH = MethodHandles.lookup().findStatic(SwitchBootstraps.class, "booleanSwitch",
  80                                                                    MethodType.methodType(CallSite.class, MethodHandles.Lookup.class, String.class, MethodType.class, boolean[].class));
  81             BSM_INT_SWITCH = MethodHandles.lookup().findStatic(SwitchBootstraps.class, "intSwitch",
  82                                                                MethodType.methodType(CallSite.class, MethodHandles.Lookup.class, String.class, MethodType.class, int[].class));
  83             BSM_LONG_SWITCH = MethodHandles.lookup().findStatic(SwitchBootstraps.class, "longSwitch",
  84                                                                 MethodType.methodType(CallSite.class, MethodHandles.Lookup.class, String.class, MethodType.class, long[].class));
  85             BSM_FLOAT_SWITCH = MethodHandles.lookup().findStatic(SwitchBootstraps.class, "floatSwitch",
  86                                                                  MethodType.methodType(CallSite.class, MethodHandles.Lookup.class, String.class, MethodType.class, float[].class));
  87             BSM_DOUBLE_SWITCH = MethodHandles.lookup().findStatic(SwitchBootstraps.class, "doubleSwitch",
  88                                                                   MethodType.methodType(CallSite.class, MethodHandles.Lookup.class, String.class, MethodType.class, double[].class));
  89             BSM_STRING_SWITCH = MethodHandles.lookup().findStatic(SwitchBootstraps.class, "stringSwitch",
  90                                                                   MethodType.methodType(CallSite.class, MethodHandles.Lookup.class, String.class, MethodType.class, String[].class));
  91             BSM_ENUM_SWITCH = MethodHandles.lookup().findStatic(SwitchBootstraps.class, "enumSwitch",
  92                                                                 MethodType.methodType(CallSite.class, MethodHandles.Lookup.class, String.class, MethodType.class, Class.class, String[].class));
  93             BSM_TYPE_SWITCH = MethodHandles.lookup().findStatic(SwitchBootstraps.class, "typeSwitch",
  94                                                                 MethodType.methodType(CallSite.class, MethodHandles.Lookup.class, String.class, MethodType.class, Class[].class));
  95         }
  96         catch (NoSuchMethodException | IllegalAccessException e) {
  97             throw new RuntimeException(e);
  98         }
  99     }
 100 
 101     private MethodType switchType(Class<?> target) {
 102         return MethodType.methodType(int.class, target);
 103     }
 104 
 105     private Object box(Class<?> clazz, int i) {
 106         if (clazz == Integer.class)
 107             return i;
 108         else if (clazz == Short.class)
 109             return (short) i;
 110         else if (clazz == Character.class)
 111             return (char) i;
 112         else if (clazz == Byte.class)
 113             return (byte) i;
 114         else
 115             throw new IllegalArgumentException(clazz.toString());
 116     }
 117 
 118     private void testBoolean(boolean... labels) throws Throwable {
 119         Map<Class<?>, MethodHandle> mhs
 120                 = Map.of(boolean.class, ((CallSite) BSM_BOOLEAN_SWITCH.invoke(MethodHandles.lookup(), "", switchType(boolean.class), labels)).dynamicInvoker(),
 121                          Boolean.class, ((CallSite) BSM_BOOLEAN_SWITCH.invoke(MethodHandles.lookup(), "", switchType(Boolean.class), labels)).dynamicInvoker());
 122 
 123         List<Boolean> labelList = new ArrayList<>();
 124         for (boolean label : labels)
 125             labelList.add(label);
 126 
 127         for (int i=0; i<labels.length; i++) {
 128             assertEquals(i, (int) mhs.get(boolean.class).invokeExact((boolean) labels[i]));
 129             assertEquals(i, (int) mhs.get(Boolean.class).invokeExact((Boolean) labels[i]));
 130         }
 131 
 132         boolean[] booleans = { false, true };
 133         for (boolean b : booleans) {
 134             if (!labelList.contains(b)) {
 135                 assertEquals(labels.length, mhs.get(boolean.class).invoke((boolean) b));
 136                 assertEquals(labels.length, mhs.get(Boolean.class).invoke((boolean) b));
 137             }
 138         }
 139 
 140         assertEquals(-1, (int) mhs.get(Boolean.class).invoke(null));
 141     }
 142 
 143     private void testInt(Set<Class<?>> targetTypes, int... labels) throws Throwable {
 144         Map<Class<?>, MethodHandle> mhs
 145                 = Map.of(char.class, ((CallSite) BSM_INT_SWITCH.invoke(MethodHandles.lookup(), "", switchType(char.class), labels)).dynamicInvoker(),
 146                          byte.class, ((CallSite) BSM_INT_SWITCH.invoke(MethodHandles.lookup(), "", switchType(byte.class), labels)).dynamicInvoker(),
 147                          short.class, ((CallSite) BSM_INT_SWITCH.invoke(MethodHandles.lookup(), "", switchType(short.class), labels)).dynamicInvoker(),
 148                          int.class, ((CallSite) BSM_INT_SWITCH.invoke(MethodHandles.lookup(), "", switchType(int.class), labels)).dynamicInvoker(),
 149                          Character.class, ((CallSite) BSM_INT_SWITCH.invoke(MethodHandles.lookup(), "", switchType(Character.class), labels)).dynamicInvoker(),
 150                          Byte.class, ((CallSite) BSM_INT_SWITCH.invoke(MethodHandles.lookup(), "", switchType(Byte.class), labels)).dynamicInvoker(),
 151                          Short.class, ((CallSite) BSM_INT_SWITCH.invoke(MethodHandles.lookup(), "", switchType(Short.class), labels)).dynamicInvoker(),
 152                          Integer.class, ((CallSite) BSM_INT_SWITCH.invoke(MethodHandles.lookup(), "", switchType(Integer.class), labels)).dynamicInvoker());
 153 
 154         List<Integer> labelList = IntStream.of(labels)
 155                                            .boxed()
 156                                            .collect(Collectors.toList());
 157 
 158         for (int i=0; i<labels.length; i++) {
 159             // test with invokeExact
 160             if (targetTypes.contains(char.class))
 161                 assertEquals(i, (int) mhs.get(char.class).invokeExact((char) labels[i]));
 162             if (targetTypes.contains(byte.class))
 163                 assertEquals(i, (int) mhs.get(byte.class).invokeExact((byte) labels[i]));
 164             if (targetTypes.contains(short.class))
 165                 assertEquals(i, (int) mhs.get(short.class).invokeExact((short) labels[i]));
 166             if (targetTypes.contains(int.class))
 167                 assertEquals(i, (int) mhs.get(int.class).invokeExact(labels[i]));
 168             if (targetTypes.contains(Integer.class))
 169                 assertEquals(i, (int) mhs.get(Integer.class).invokeExact((Integer) labels[i]));
 170             if (targetTypes.contains(Short.class))
 171                 assertEquals(i, (int) mhs.get(Short.class).invokeExact((Short) (short) labels[i]));
 172             if (targetTypes.contains(Byte.class))
 173                 assertEquals(i, (int) mhs.get(Byte.class).invokeExact((Byte) (byte) labels[i]));
 174             if (targetTypes.contains(Character.class))
 175                 assertEquals(i, (int) mhs.get(Character.class).invokeExact((Character) (char) labels[i]));
 176 
 177             // and with invoke
 178             assertEquals(i, (int) mhs.get(int.class).invoke(labels[i]));
 179             assertEquals(i, (int) mhs.get(Integer.class).invoke(labels[i]));
 180         }
 181 
 182         for (int i=-1000; i<1000; i++) {
 183             if (!labelList.contains(i)) {
 184                 assertEquals(labels.length, mhs.get(short.class).invoke((short) i));
 185                 assertEquals(labels.length, mhs.get(Short.class).invoke((short) i));
 186                 assertEquals(labels.length, mhs.get(int.class).invoke(i));
 187                 assertEquals(labels.length, mhs.get(Integer.class).invoke(i));
 188                 if (i >= 0) {
 189                     assertEquals(labels.length, mhs.get(char.class).invoke((char)i));
 190                     assertEquals(labels.length, mhs.get(Character.class).invoke((char)i));
 191                 }
 192                 if (i >= -128 && i <= 127) {
 193                     assertEquals(labels.length, mhs.get(byte.class).invoke((byte)i));
 194                     assertEquals(labels.length, mhs.get(Byte.class).invoke((byte)i));
 195                 }
 196             }
 197         }
 198 
 199         assertEquals(-1, (int) mhs.get(Integer.class).invoke(null));
 200         assertEquals(-1, (int) mhs.get(Short.class).invoke(null));
 201         assertEquals(-1, (int) mhs.get(Byte.class).invoke(null));
 202         assertEquals(-1, (int) mhs.get(Character.class).invoke(null));
 203     }
 204 
 205     private void testFloat(float... labels) throws Throwable {
 206         Map<Class<?>, MethodHandle> mhs
 207                 = Map.of(float.class, ((CallSite) BSM_FLOAT_SWITCH.invoke(MethodHandles.lookup(), "", switchType(float.class), labels)).dynamicInvoker(),
 208                          Float.class, ((CallSite) BSM_FLOAT_SWITCH.invoke(MethodHandles.lookup(), "", switchType(Float.class), labels)).dynamicInvoker());
 209 
 210         List<Float> labelList = new ArrayList<>();
 211         for (float label : labels)
 212             labelList.add(label);
 213 
 214         for (int i=0; i<labels.length; i++) {
 215             assertEquals(i, (int) mhs.get(float.class).invokeExact((float) labels[i]));
 216             assertEquals(i, (int) mhs.get(Float.class).invokeExact((Float) labels[i]));
 217         }
 218 
 219         float[] someFloats = { 1.0f, Float.MIN_VALUE, 3.14f };
 220         for (float f : someFloats) {
 221             if (!labelList.contains(f)) {
 222                 assertEquals(labels.length, mhs.get(float.class).invoke((float) f));
 223                 assertEquals(labels.length, mhs.get(Float.class).invoke((float) f));
 224             }
 225         }
 226 
 227         assertEquals(-1, (int) mhs.get(Float.class).invoke(null));
 228     }
 229 
 230     private void testDouble(double... labels) throws Throwable {
 231         Map<Class<?>, MethodHandle> mhs
 232                 = Map.of(double.class, ((CallSite) BSM_DOUBLE_SWITCH.invoke(MethodHandles.lookup(), "", switchType(double.class), labels)).dynamicInvoker(),
 233                          Double.class, ((CallSite) BSM_DOUBLE_SWITCH.invoke(MethodHandles.lookup(), "", switchType(Double.class), labels)).dynamicInvoker());
 234 
 235         var labelList = new ArrayList<Double>();
 236         for (double label : labels)
 237             labelList.add(label);
 238 
 239         for (int i=0; i<labels.length; i++) {
 240             assertEquals(i, (int) mhs.get(double.class).invokeExact((double) labels[i]));
 241             assertEquals(i, (int) mhs.get(Double.class).invokeExact((Double) labels[i]));
 242         }
 243 
 244         double[] someDoubles = { 1.0, Double.MIN_VALUE, 3.14 };
 245         for (double f : someDoubles) {
 246             if (!labelList.contains(f)) {
 247                 assertEquals(labels.length, mhs.get(double.class).invoke((double) f));
 248                 assertEquals(labels.length, mhs.get(Double.class).invoke((double) f));
 249             }
 250         }
 251 
 252         assertEquals(-1, (int) mhs.get(Double.class).invoke(null));
 253     }
 254 
 255     private void testLong(long... labels) throws Throwable {
 256         Map<Class<?>, MethodHandle> mhs
 257                 = Map.of(long.class, ((CallSite) BSM_LONG_SWITCH.invoke(MethodHandles.lookup(), "", switchType(long.class), labels)).dynamicInvoker(),
 258                          Long.class, ((CallSite) BSM_LONG_SWITCH.invoke(MethodHandles.lookup(), "", switchType(Long.class), labels)).dynamicInvoker());
 259 
 260         List<Long> labelList = new ArrayList<>();
 261         for (long label : labels)
 262             labelList.add(label);
 263 
 264         for (int i=0; i<labels.length; i++) {
 265             assertEquals(i, (int) mhs.get(long.class).invokeExact((long) labels[i]));
 266             assertEquals(i, (int) mhs.get(Long.class).invokeExact((Long) labels[i]));
 267         }
 268 
 269         long[] someLongs = { 1L, Long.MIN_VALUE, Long.MAX_VALUE };
 270         for (long l : someLongs) {
 271             if (!labelList.contains(l)) {
 272                 assertEquals(labels.length, mhs.get(long.class).invoke((long) l));
 273                 assertEquals(labels.length, mhs.get(Long.class).invoke((long) l));
 274             }
 275         }
 276 
 277         assertEquals(-1, (int) mhs.get(Long.class).invoke(null));
 278     }
 279 
 280     private void testString(String... targets) throws Throwable {
 281         MethodHandle indy = ((CallSite) BSM_STRING_SWITCH.invoke(MethodHandles.lookup(), "", switchType(String.class), targets)).dynamicInvoker();
 282         List<String> targetList = Stream.of(targets)
 283                                         .collect(Collectors.toList());
 284 
 285         for (int i=0; i<targets.length; i++) {
 286             String s = targets[i];
 287             int result = (int) indy.invoke(s);
 288             assertEquals((s == null) ? -1 : i, result);
 289         }
 290 
 291         for (String s : List.of("", "A", "AA", "AAA", "AAAA")) {
 292             if (!targetList.contains(s)) {
 293                 assertEquals(targets.length, indy.invoke(s));
 294             }
 295         }
 296         assertEquals(-1, (int) indy.invoke(null));
 297     }
 298 
 299     private<E extends Enum<E>> void testEnum(Class<E> enumClass, String... targets) throws Throwable {
 300         MethodHandle indy = ((CallSite) BSM_ENUM_SWITCH.invoke(MethodHandles.lookup(), "", switchType(Enum.class), enumClass, targets)).dynamicInvoker();
 301         List<E> targetList = Stream.of(targets)
 302                                    .map(s -> Enum.valueOf(enumClass, s))
 303                                    .collect(Collectors.toList());
 304 
 305         for (int i=0; i<targets.length; i++) {
 306             String s = targets[i];
 307             E e = Enum.valueOf(enumClass, s);
 308             int result = (int) indy.invoke(e);
 309             assertEquals((s == null) ? -1 : i, result);
 310         }
 311 
 312         for (E e : enumClass.getEnumConstants()) {
 313             int index = (int) indy.invoke(e);
 314             if (targetList.contains(e))
 315                 assertEquals(e.name(), targets[index]);
 316             else
 317                 assertEquals(targets.length, index);
 318         }
 319 
 320         assertEquals(-1, (int) indy.invoke(null));
 321     }
 322 
 323     public void testBoolean() throws Throwable {
 324         testBoolean(new boolean[0]);
 325         testBoolean(false);
 326         testBoolean(true);
 327         testBoolean(false, true);
 328     }
 329 
 330     public void testInt() throws Throwable {
 331         testInt(ALL_INT_TYPES, 8, 6, 7, 5, 3, 0, 9);
 332         testInt(ALL_INT_TYPES, 1, 2, 4, 8, 16);
 333         testInt(ALL_INT_TYPES, 5, 4, 3, 2, 1, 0);
 334         testInt(SIGNED_TYPES, 5, 4, 3, 2, 1, 0, -1);
 335         testInt(SIGNED_TYPES, -1);
 336         testInt(ALL_INT_TYPES, new int[] { });
 337 
 338         for (int i=0; i<5; i++) {
 339             int len = 50 + random.nextInt(800);
 340             int[] arr = IntStream.generate(() -> random.nextInt(10000) - 5000)
 341                                  .distinct()
 342                                  .limit(len)
 343                                  .toArray();
 344             testInt(SIGNED_NON_BYTE_TYPES, arr);
 345 
 346             arr = IntStream.generate(() -> random.nextInt(10000))
 347                     .distinct()
 348                     .limit(len)
 349                     .toArray();
 350             testInt(CHAR_TYPES, arr);
 351 
 352             arr = IntStream.generate(() -> random.nextInt(127) - 64)
 353                            .distinct()
 354                            .limit(120)
 355                            .toArray();
 356             testInt(BYTE_TYPES, arr);
 357         }
 358     }
 359 
 360     public void testLong() throws Throwable {
 361         testLong(1L, Long.MIN_VALUE, Long.MAX_VALUE);
 362         testLong(8L, 2L, 5L, 4L, 3L, 9L, 1L);
 363         testLong(new long[] { });
 364 
 365         // @@@ Random tests
 366         // @@@ More tests for weird values
 367     }
 368 
 369     public void testFloat() throws Throwable {
 370         testFloat(0.0f, -0.0f, -1.0f, 1.0f, 3.14f, Float.MIN_VALUE, Float.MAX_VALUE, Float.POSITIVE_INFINITY, Float.NEGATIVE_INFINITY);
 371         testFloat(new float[] { });
 372         testFloat(0.0f, 1.0f, 3.14f, Float.NaN);
 373 
 374         // @@@ Random tests
 375         // @@@ More tests for weird values
 376     }
 377 
 378     public void testDouble() throws Throwable {
 379         testDouble(0.0, -0.0, -1.0, 1.0, 3.14, Double.MIN_VALUE, Double.MAX_VALUE,
 380                    Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY);
 381         testDouble(new double[] { });
 382         testDouble(0.0f, 1.0f, 3.14f, Double.NaN);
 383 
 384         // @@@ Random tests
 385         // @@@ More tests for weird values
 386     }
 387 
 388     public void testString() throws Throwable {
 389         testString("a", "b", "c");
 390         testString("c", "b", "a");
 391         testString("cow", "pig", "horse", "orangutan", "elephant", "dog", "frog", "ant");
 392         testString("a", "b", "c", "A", "B", "C");
 393         testString("C", "B", "A", "c", "b", "a");
 394 
 395         // Tests with hash collisions; Ba/CB, Ca/DB
 396         testString("Ba", "CB");
 397         testString("Ba", "CB", "Ca", "DB");
 398 
 399         // Test with null
 400         try {
 401             testString("a", null, "c");
 402             fail("expected failure");
 403         }
 404         catch (IllegalArgumentException t) {
 405             // success
 406         }
 407     }
 408 
 409     enum E1 { A, B }
 410     enum E2 { C, D, E, F, G, H }
 411 
 412     public void testEnum() throws Throwable {
 413         testEnum(E1.class);
 414         testEnum(E1.class, "A");
 415         testEnum(E1.class, "A", "B");
 416         testEnum(E1.class, "B", "A");
 417         testEnum(E2.class, "C");
 418         testEnum(E2.class, "C", "D", "E", "F", "H");
 419         testEnum(E2.class, "H", "C", "G", "D", "F", "E");
 420 
 421         // Bad enum class
 422         try {
 423             testEnum((Class) String.class, "A");
 424             fail("expected failure");
 425         }
 426         catch (IllegalArgumentException t) {
 427             // success
 428         }
 429 
 430         // Bad enum constants
 431         try {
 432             testEnum(E1.class, "B", "A", "FILE_NOT_FOUND");
 433             fail("expected failure");
 434         }
 435         catch (IllegalArgumentException t) {
 436             // success
 437         }
 438 
 439         // Null enum constant
 440         try {
 441             testEnum(E1.class, "A", null, "B");
 442             fail("expected failure");
 443         }
 444         catch (IllegalArgumentException t) {
 445             // success
 446         }
 447     }
 448 
 449     private void testType(Object target, int result, Class... labels) throws Throwable {
 450         MethodHandle indy = ((CallSite) BSM_TYPE_SWITCH.invoke(MethodHandles.lookup(), "", switchType(Object.class), labels)).dynamicInvoker();
 451         assertEquals((int) indy.invoke(target), result);
 452         assertEquals(-1, (int) indy.invoke(null));
 453     }
 454 
 455     public void testTypes() throws Throwable {
 456         testType("", 0, String.class, Object.class);
 457         testType("", 0, Object.class);
 458         testType("", 1, Integer.class);
 459         testType("", 1, Integer.class, Serializable.class);
 460         testType(E1.A, 0, E1.class, Object.class);
 461         testType(E2.C, 1, E1.class, Object.class);
 462         testType(new Serializable() { }, 1, Comparable.class, Serializable.class);
 463 
 464         // test failures: duplicates, nulls, dominance inversion
 465     }
 466 }