1 /*
  2  * Copyright (c) 2017, 2023, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 
 26 package java.lang.runtime;
 27 
 28 import java.lang.Enum.EnumDesc;
 29 import java.lang.constant.ClassDesc;
 30 import java.lang.constant.ConstantDescs;
 31 import java.lang.constant.MethodTypeDesc;
 32 import java.lang.invoke.CallSite;
 33 import java.lang.invoke.ConstantCallSite;
 34 import java.lang.invoke.MethodHandle;
 35 import java.lang.invoke.MethodHandles;
 36 import java.lang.invoke.MethodType;
 37 import java.lang.reflect.AccessFlag;
 38 import java.util.ArrayList;
 39 import java.util.List;
 40 import java.util.Objects;
 41 import java.util.Optional;
 42 import java.util.function.BiPredicate;
 43 import java.util.stream.Stream;
 44 import jdk.internal.access.SharedSecrets;
 45 import java.lang.classfile.ClassFile;
 46 import java.lang.classfile.Label;
 47 import java.lang.classfile.instruction.SwitchCase;
 48 import jdk.internal.vm.annotation.Stable;
 49 
 50 import static java.lang.invoke.MethodHandles.Lookup.ClassOption.NESTMATE;
 51 import static java.lang.invoke.MethodHandles.Lookup.ClassOption.STRONG;
 52 import static java.util.Objects.requireNonNull;
 53 
 54 /**
 55  * Bootstrap methods for linking {@code invokedynamic} call sites that implement
 56  * the selection functionality of the {@code switch} statement.  The bootstraps
 57  * take additional static arguments corresponding to the {@code case} labels
 58  * of the {@code switch}, implicitly numbered sequentially from {@code [0..N)}.
 59  *
 60  * @since 21
 61  */
 62 public class SwitchBootstraps {
 63 
 64     private SwitchBootstraps() {}
 65 
 66     private static final Object SENTINEL = new Object();
 67     private static final MethodHandles.Lookup LOOKUP = MethodHandles.lookup();
 68 
 69     private static final MethodHandle NULL_CHECK;
 70     private static final MethodHandle IS_ZERO;
 71     private static final MethodHandle CHECK_INDEX;
 72     private static final MethodHandle MAPPED_ENUM_LOOKUP;
 73 
 74     private static final MethodTypeDesc TYPES_SWITCH_DESCRIPTOR =
 75             MethodTypeDesc.ofDescriptor("(Ljava/lang/Object;ILjava/util/function/BiPredicate;Ljava/util/List;)I");
 76 
 77     static {
 78         try {
 79             NULL_CHECK = LOOKUP.findStatic(Objects.class, "isNull",
 80                                            MethodType.methodType(boolean.class, Object.class));
 81             IS_ZERO = LOOKUP.findStatic(SwitchBootstraps.class, "isZero",
 82                                            MethodType.methodType(boolean.class, int.class));
 83             CHECK_INDEX = LOOKUP.findStatic(Objects.class, "checkIndex",
 84                                            MethodType.methodType(int.class, int.class, int.class));
 85             MAPPED_ENUM_LOOKUP = LOOKUP.findStatic(SwitchBootstraps.class, "mappedEnumLookup",
 86                                                    MethodType.methodType(int.class, Enum.class, MethodHandles.Lookup.class,
 87                                                                          Class.class, EnumDesc[].class, EnumMap.class));
 88         }
 89         catch (ReflectiveOperationException e) {
 90             throw new ExceptionInInitializerError(e);
 91         }
 92     }
 93 
 94     /**
 95      * Bootstrap method for linking an {@code invokedynamic} call site that
 96      * implements a {@code switch} on a target of a reference type.  The static
 97      * arguments are an array of case labels which must be non-null and of type
 98      * {@code String} or {@code Integer} or {@code Class} or {@code EnumDesc}.
 99      * <p>
100      * The type of the returned {@code CallSite}'s method handle will have
101      * a return type of {@code int}.   It has two parameters: the first argument
102      * will be an {@code Object} instance ({@code target}) and the second
103      * will be {@code int} ({@code restart}).
104      * <p>
105      * If the {@code target} is {@code null}, then the method of the call site
106      * returns {@literal -1}.
107      * <p>
108      * If the {@code target} is not {@code null}, then the method of the call site
109      * returns the index of the first element in the {@code labels} array starting from
110      * the {@code restart} index matching one of the following conditions:
111      * <ul>
112      *   <li>the element is of type {@code Class} that is assignable
113      *       from the target's class; or</li>
114      *   <li>the element is of type {@code String} or {@code Integer} and
115      *       equals to the target.</li>
116      *   <li>the element is of type {@code EnumDesc}, that describes a constant that is
117      *       equals to the target.</li>
118      * </ul>
119      * <p>
120      * If no element in the {@code labels} array matches the target, then
121      * the method of the call site return the length of the {@code labels} array.
122      * <p>
123      * The value of the {@code restart} index must be between {@code 0} (inclusive) and
124      * the length of the {@code labels} array (inclusive),
125      * both  or an {@link IndexOutOfBoundsException} is thrown.
126      *
127      * @param lookup Represents a lookup context with the accessibility
128      *               privileges of the caller.  When used with {@code invokedynamic},
129      *               this is stacked automatically by the VM.
130      * @param invocationName unused
131      * @param invocationType The invocation type of the {@code CallSite} with two parameters,
132      *                       a reference type, an {@code int}, and {@code int} as a return type.
133      * @param labels case labels - {@code String} and {@code Integer} constants
134      *               and {@code Class} and {@code EnumDesc} instances, in any combination
135      * @return a {@code CallSite} returning the first matching element as described above
136      *
137      * @throws NullPointerException if any argument is {@code null}
138      * @throws IllegalArgumentException if any element in the labels array is null, if the
139      * invocation type is not not a method type of first parameter of a reference type,
140      * second parameter of type {@code int} and with {@code int} as its return type,
141      * or if {@code labels} contains an element that is not of type {@code String},
142      * {@code Integer}, {@code Class} or {@code EnumDesc}.
143      * @jvms 4.4.6 The CONSTANT_NameAndType_info Structure
144      * @jvms 4.4.10 The CONSTANT_Dynamic_info and CONSTANT_InvokeDynamic_info Structures
145      */
146     public static CallSite typeSwitch(MethodHandles.Lookup lookup,
147                                       String invocationName,
148                                       MethodType invocationType,
149                                       Object... labels) {
150         if (invocationType.parameterCount() != 2
151             || (!invocationType.returnType().equals(int.class))
152             || invocationType.parameterType(0).isPrimitive()
153             || !invocationType.parameterType(1).equals(int.class))
154             throw new IllegalArgumentException("Illegal invocation type " + invocationType);
155         requireNonNull(labels);
156 
157         labels = labels.clone();
158         Stream.of(labels).forEach(SwitchBootstraps::verifyLabel);
159 
160         MethodHandle target = generateInnerClass(lookup, labels);
161 
162         target = withIndexCheck(target, labels.length);
163 
164         return new ConstantCallSite(target);
165     }
166 
167     private static void verifyLabel(Object label) {
168         if (label == null) {
169             throw new IllegalArgumentException("null label found");
170         }
171         Class<?> labelClass = label.getClass();
172         if (labelClass != Class.class &&
173             labelClass != String.class &&
174             labelClass != Integer.class &&
175             labelClass != EnumDesc.class) {
176             throw new IllegalArgumentException("label with illegal type found: " + label.getClass());
177         }
178     }
179 
180     private static boolean isZero(int value) {
181         return value == 0;
182     }
183 
184     /**
185      * Bootstrap method for linking an {@code invokedynamic} call site that
186      * implements a {@code switch} on a target of an enum type. The static
187      * arguments are used to encode the case labels associated to the switch
188      * construct, where each label can be encoded in two ways:
189      * <ul>
190      *   <li>as a {@code String} value, which represents the name of
191      *       the enum constant associated with the label</li>
192      *   <li>as a {@code Class} value, which represents the enum type
193      *       associated with a type test pattern</li>
194      * </ul>
195      * <p>
196      * The returned {@code CallSite}'s method handle will have
197      * a return type of {@code int} and accepts two parameters: the first argument
198      * will be an {@code Enum} instance ({@code target}) and the second
199      * will be {@code int} ({@code restart}).
200      * <p>
201      * If the {@code target} is {@code null}, then the method of the call site
202      * returns {@literal -1}.
203      * <p>
204      * If the {@code target} is not {@code null}, then the method of the call site
205      * returns the index of the first element in the {@code labels} array starting from
206      * the {@code restart} index matching one of the following conditions:
207      * <ul>
208      *   <li>the element is of type {@code Class} that is assignable
209      *       from the target's class; or</li>
210      *   <li>the element is of type {@code String} and equals to the target
211      *       enum constant's {@link Enum#name()}.</li>
212      * </ul>
213      * <p>
214      * If no element in the {@code labels} array matches the target, then
215      * the method of the call site return the length of the {@code labels} array.
216      * <p>
217      * The value of the {@code restart} index must be between {@code 0} (inclusive) and
218      * the length of the {@code labels} array (inclusive),
219      * both  or an {@link IndexOutOfBoundsException} is thrown.
220      *
221      * @param lookup Represents a lookup context with the accessibility
222      *               privileges of the caller. When used with {@code invokedynamic},
223      *               this is stacked automatically by the VM.
224      * @param invocationName unused
225      * @param invocationType The invocation type of the {@code CallSite} with two parameters,
226      *                       an enum type, an {@code int}, and {@code int} as a return type.
227      * @param labels case labels - {@code String} constants and {@code Class} instances,
228      *               in any combination
229      * @return a {@code CallSite} returning the first matching element as described above
230      *
231      * @throws NullPointerException if any argument is {@code null}
232      * @throws IllegalArgumentException if any element in the labels array is null, if the
233      * invocation type is not a method type whose first parameter type is an enum type,
234      * second parameter of type {@code int} and whose return type is {@code int},
235      * or if {@code labels} contains an element that is not of type {@code String} or
236      * {@code Class} of the target enum type.
237      * @jvms 4.4.6 The CONSTANT_NameAndType_info Structure
238      * @jvms 4.4.10 The CONSTANT_Dynamic_info and CONSTANT_InvokeDynamic_info Structures
239      */
240     public static CallSite enumSwitch(MethodHandles.Lookup lookup,
241                                       String invocationName,
242                                       MethodType invocationType,
243                                       Object... labels) {
244         if (invocationType.parameterCount() != 2
245             || (!invocationType.returnType().equals(int.class))
246             || invocationType.parameterType(0).isPrimitive()
247             || !invocationType.parameterType(0).isEnum()
248             || !invocationType.parameterType(1).equals(int.class))
249             throw new IllegalArgumentException("Illegal invocation type " + invocationType);
250         requireNonNull(labels);
251 
252         labels = labels.clone();
253 
254         Class<?> enumClass = invocationType.parameterType(0);
255         labels = Stream.of(labels).map(l -> convertEnumConstants(lookup, enumClass, l)).toArray();
256 
257         MethodHandle target;
258         boolean constantsOnly = Stream.of(labels).allMatch(l -> enumClass.isAssignableFrom(EnumDesc.class));
259 
260         if (labels.length > 0 && constantsOnly) {
261             //If all labels are enum constants, construct an optimized handle for repeat index 0:
262             //if (selector == null) return -1
263             //else if (idx == 0) return mappingArray[selector.ordinal()]; //mapping array created lazily
264             //else return "typeSwitch(labels)"
265             MethodHandle body =
266                     MethodHandles.guardWithTest(MethodHandles.dropArguments(NULL_CHECK, 0, int.class),
267                                                 MethodHandles.dropArguments(MethodHandles.constant(int.class, -1), 0, int.class, Object.class),
268                                                 MethodHandles.guardWithTest(MethodHandles.dropArguments(IS_ZERO, 1, Object.class),
269                                                                             generateInnerClass(lookup, labels),
270                                                                             MethodHandles.insertArguments(MAPPED_ENUM_LOOKUP, 1, lookup, enumClass, labels, new EnumMap())));
271             target = MethodHandles.permuteArguments(body, MethodType.methodType(int.class, Object.class, int.class), 1, 0);
272         } else {
273             target = generateInnerClass(lookup, labels);
274         }
275 
276         target = target.asType(invocationType);
277         target = withIndexCheck(target, labels.length);
278 
279         return new ConstantCallSite(target);
280     }
281 
282     private static <E extends Enum<E>> Object convertEnumConstants(MethodHandles.Lookup lookup, Class<?> enumClassTemplate, Object label) {
283         if (label == null) {
284             throw new IllegalArgumentException("null label found");
285         }
286         Class<?> labelClass = label.getClass();
287         if (labelClass == Class.class) {
288             if (label != enumClassTemplate) {
289                 throw new IllegalArgumentException("the Class label: " + label +
290                                                    ", expected the provided enum class: " + enumClassTemplate);
291             }
292             return label;
293         } else if (labelClass == String.class) {
294             return EnumDesc.of(enumClassTemplate.describeConstable().orElseThrow(), (String) label);
295         } else {
296             throw new IllegalArgumentException("label with illegal type found: " + labelClass +
297                                                ", expected label of type either String or Class");
298         }
299     }
300 
301     private static <T extends Enum<T>> int mappedEnumLookup(T value, MethodHandles.Lookup lookup, Class<T> enumClass, EnumDesc<?>[] labels, EnumMap enumMap) {
302         if (enumMap.map == null) {
303             T[] constants = SharedSecrets.getJavaLangAccess().getEnumConstantsShared(enumClass);
304             int[] map = new int[constants.length];
305             int ordinal = 0;
306 
307             for (T constant : constants) {
308                 map[ordinal] = labels.length;
309 
310                 for (int i = 0; i < labels.length; i++) {
311                     if (Objects.equals(labels[i].constantName(), constant.name())) {
312                         map[ordinal] = i;
313                         break;
314                     }
315                 }
316 
317                 ordinal++;
318             }
319         }
320         return enumMap.map[value.ordinal()];
321     }
322 
323     private static MethodHandle withIndexCheck(MethodHandle target, int labelsCount) {
324         MethodHandle checkIndex = MethodHandles.insertArguments(CHECK_INDEX, 1, labelsCount + 1);
325 
326         return MethodHandles.filterArguments(target, 1, checkIndex);
327     }
328 
329     private static final class ResolvedEnumLabels implements BiPredicate<Integer, Object> {
330 
331         private final MethodHandles.Lookup lookup;
332         private final EnumDesc<?>[] enumDescs;
333         @Stable
334         private Object[] resolvedEnum;
335 
336         public ResolvedEnumLabels(MethodHandles.Lookup lookup, EnumDesc<?>[] enumDescs) {
337             this.lookup = lookup;
338             this.enumDescs = enumDescs;
339             this.resolvedEnum = new Object[enumDescs.length];
340         }
341 
342         @Override
343         public boolean test(Integer labelIndex, Object value) {
344             Object result = resolvedEnum[labelIndex];
345 
346             if (result == null) {
347                 try {
348                     if (!(value instanceof Enum<?> enumValue)) {
349                         return false;
350                     }
351 
352                     EnumDesc<?> label = enumDescs[labelIndex];
353                     Class<?> clazz = label.constantType().resolveConstantDesc(lookup);
354 
355                     if (enumValue.getDeclaringClass() != clazz) {
356                         return false;
357                     }
358 
359                     result = label.resolveConstantDesc(lookup);
360                 } catch (IllegalArgumentException | ReflectiveOperationException ex) {
361                     result = SENTINEL;
362                 }
363 
364                 resolvedEnum[labelIndex] = result;
365             }
366 
367             return result == value;
368         }
369     }
370 
371     private static final class EnumMap {
372         @Stable
373         public int[] map;
374     }
375 
376     /*
377      * Construct test chains for labels inside switch, to handle switch repeats:
378      * switch (idx) {
379      *     case 0 -> if (selector matches label[0]) return 0;
380      *     case 1 -> if (selector matches label[1]) return 1;
381      *     ...
382      * }
383      */
384     @SuppressWarnings("removal")
385     private static MethodHandle generateInnerClass(MethodHandles.Lookup caller, Object[] labels) {
386         List<EnumDesc<?>> enumDescs = new ArrayList<>();
387         List<Class<?>> extraClassLabels = new ArrayList<>();
388 
389         byte[] classBytes = ClassFile.of().build(ClassDesc.of(typeSwitchClassName(caller.lookupClass())), clb -> {
390             clb.withFlags(AccessFlag.FINAL, AccessFlag.SUPER, AccessFlag.SYNTHETIC)
391                .withMethodBody("typeSwitch",
392                                TYPES_SWITCH_DESCRIPTOR,
393                                ClassFile.ACC_FINAL | ClassFile.ACC_PUBLIC | ClassFile.ACC_STATIC,
394                                cb -> {
395                     cb.aload(0);
396                     Label nonNullLabel = cb.newLabel();
397                     cb.if_nonnull(nonNullLabel);
398                     cb.iconst_m1();
399                     cb.ireturn();
400                     cb.labelBinding(nonNullLabel);
401                     if (labels.length == 0) {
402                         cb.constantInstruction(0)
403                           .ireturn();
404                         return ;
405                     }
406                     cb.iload(1);
407                     Label dflt = cb.newLabel();
408                     record Element(Label target, Label next, Object caseLabel) {}
409                     List<Element> cases = new ArrayList<>();
410                     List<SwitchCase> switchCases = new ArrayList<>();
411                     Object lastLabel = null;
412                     for (int idx = labels.length - 1; idx >= 0; idx--) {
413                         Object currentLabel = labels[idx];
414                         Label target = cb.newLabel();
415                         Label next;
416                         if (lastLabel == null) {
417                             next = dflt;
418                         } else if (lastLabel.equals(currentLabel)) {
419                             next = cases.getLast().next();
420                         } else {
421                             next = cases.getLast().target();
422                         }
423                         lastLabel = currentLabel;
424                         cases.add(new Element(target, next, currentLabel));
425                         switchCases.add(SwitchCase.of(idx, target));
426                     }
427                     cases = cases.reversed();
428                     switchCases = switchCases.reversed();
429                     cb.tableswitch(0, labels.length - 1, dflt, switchCases);
430                     for (int idx = 0; idx < cases.size(); idx++) {
431                         Element element = cases.get(idx);
432                         Label next = element.next();
433                         cb.labelBinding(element.target());
434                         if (element.caseLabel() instanceof Class<?> classLabel) {
435                             Optional<ClassDesc> classLabelConstableOpt = classLabel.describeConstable();
436                             if (classLabelConstableOpt.isPresent()) {
437                                 cb.aload(0);
438                                 cb.instanceof_(classLabelConstableOpt.orElseThrow());
439                                 cb.ifeq(next);
440                             } else {
441                                 cb.aload(3);
442                                 cb.constantInstruction(extraClassLabels.size());
443                                 cb.invokeinterface(ConstantDescs.CD_List,
444                                                    "get",
445                                                    MethodTypeDesc.of(ConstantDescs.CD_Object,
446                                                                      ConstantDescs.CD_int));
447                                 cb.checkcast(ConstantDescs.CD_Class);
448                                 cb.aload(0);
449                                 cb.invokevirtual(ConstantDescs.CD_Class,
450                                                  "isInstance",
451                                                  MethodTypeDesc.of(ConstantDescs.CD_boolean,
452                                                                    ConstantDescs.CD_Object));
453                                 cb.ifeq(next);
454                                 extraClassLabels.add(classLabel);
455                             }
456                         } else if (element.caseLabel() instanceof EnumDesc<?> enumLabel) {
457                             int enumIdx = enumDescs.size();
458                             enumDescs.add(enumLabel);
459                             cb.aload(2);
460                             cb.constantInstruction(enumIdx);
461                             cb.invokestatic(ConstantDescs.CD_Integer,
462                                             "valueOf",
463                                             MethodTypeDesc.of(ConstantDescs.CD_Integer,
464                                                               ConstantDescs.CD_int));
465                             cb.aload(0);
466                             cb.invokeinterface(BiPredicate.class.describeConstable().orElseThrow(),
467                                                "test",
468                                                MethodTypeDesc.of(ConstantDescs.CD_boolean,
469                                                                  ConstantDescs.CD_Object,
470                                                                  ConstantDescs.CD_Object));
471                             cb.ifeq(next);
472                         } else if (element.caseLabel() instanceof String stringLabel) {
473                             cb.ldc(stringLabel);
474                             cb.aload(0);
475                             cb.invokevirtual(ConstantDescs.CD_Object,
476                                              "equals",
477                                              MethodTypeDesc.of(ConstantDescs.CD_boolean,
478                                                                ConstantDescs.CD_Object));
479                             cb.ifeq(next);
480                         } else if (element.caseLabel() instanceof Integer integerLabel) {
481                             Label compare = cb.newLabel();
482                             Label notNumber = cb.newLabel();
483                             cb.aload(0);
484                             cb.instanceof_(ConstantDescs.CD_Number);
485                             cb.ifeq(notNumber);
486                             cb.aload(0);
487                             cb.checkcast(ConstantDescs.CD_Number);
488                             cb.invokevirtual(ConstantDescs.CD_Number,
489                                              "intValue",
490                                              MethodTypeDesc.of(ConstantDescs.CD_int));
491                             cb.goto_(compare);
492                             cb.labelBinding(notNumber);
493                             cb.aload(0);
494                             cb.instanceof_(ConstantDescs.CD_Character);
495                             cb.ifeq(next);
496                             cb.aload(0);
497                             cb.checkcast(ConstantDescs.CD_Character);
498                             cb.invokevirtual(ConstantDescs.CD_Character,
499                                              "charValue",
500                                              MethodTypeDesc.of(ConstantDescs.CD_char));
501                             cb.labelBinding(compare);
502                             cb.ldc(integerLabel);
503                             cb.if_icmpne(next);
504                         } else {
505                             throw new InternalError("Unsupported label type: " +
506                                                     element.caseLabel().getClass());
507                         }
508                         cb.constantInstruction(idx);
509                         cb.ireturn();
510                     }
511                     cb.labelBinding(dflt);
512                     cb.constantInstruction(cases.size());
513                     cb.ireturn();
514                 });
515         });
516 
517         try {
518             // this class is linked at the indy callsite; so define a hidden nestmate
519             MethodHandles.Lookup lookup;
520             lookup = caller.defineHiddenClass(classBytes, true, NESTMATE, STRONG);
521             MethodHandle typeSwitch = lookup.findStatic(lookup.lookupClass(),
522                                                         "typeSwitch",
523                                                         MethodType.methodType(int.class,
524                                                                               Object.class,
525                                                                               int.class,
526                                                                               BiPredicate.class,
527                                                                               List.class));
528             return MethodHandles.insertArguments(typeSwitch, 2, new ResolvedEnumLabels(caller, enumDescs.toArray(EnumDesc[]::new)),
529                                                                 List.copyOf(extraClassLabels));
530         } catch (Throwable t) {
531             throw new IllegalArgumentException(t);
532         }
533     }
534 
535     //based on src/java.base/share/classes/java/lang/invoke/InnerClassLambdaMetafactory.java:
536     private static String typeSwitchClassName(Class<?> targetClass) {
537         String name = targetClass.getName();
538         if (targetClass.isHidden()) {
539             // use the original class name
540             name = name.replace('/', '_');
541         }
542         return name + "$$TypeSwitch";
543     }
544 }