1 /*
  2  * Copyright (c) 2017, 2023, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 
 26 package java.lang.runtime;
 27 
 28 import java.lang.Enum.EnumDesc;
 29 import java.lang.constant.ClassDesc;
 30 import java.lang.constant.ConstantDescs;
 31 import java.lang.constant.MethodTypeDesc;
 32 import java.lang.invoke.CallSite;
 33 import java.lang.invoke.ConstantCallSite;
 34 import java.lang.invoke.MethodHandle;
 35 import java.lang.invoke.MethodHandles;
 36 import java.lang.invoke.MethodType;
 37 import java.lang.reflect.AccessFlag;
 38 import java.util.ArrayList;
 39 import java.util.List;
 40 import java.util.Objects;
 41 import java.util.Optional;
 42 import java.util.function.BiPredicate;
 43 import java.util.stream.Stream;
 44 import jdk.internal.access.SharedSecrets;
 45 import java.lang.classfile.ClassFile;
 46 import java.lang.classfile.Label;
 47 import java.lang.classfile.instruction.SwitchCase;
 48 import jdk.internal.vm.annotation.Stable;
 49 
 50 import static java.lang.invoke.MethodHandles.Lookup.ClassOption.NESTMATE;
 51 import static java.lang.invoke.MethodHandles.Lookup.ClassOption.STRONG;
 52 import static java.util.Objects.requireNonNull;
 53 import jdk.internal.misc.PreviewFeatures;
 54 
 55 /**
 56  * Bootstrap methods for linking {@code invokedynamic} call sites that implement
 57  * the selection functionality of the {@code switch} statement.  The bootstraps
 58  * take additional static arguments corresponding to the {@code case} labels
 59  * of the {@code switch}, implicitly numbered sequentially from {@code [0..N)}.
 60  *
 61  * @since 21
 62  */
 63 public class SwitchBootstraps {
 64 
 65     private SwitchBootstraps() {}
 66 
 67     private static final Object SENTINEL = new Object();
 68     private static final MethodHandles.Lookup LOOKUP = MethodHandles.lookup();
 69 
 70     private static final MethodHandle NULL_CHECK;
 71     private static final MethodHandle IS_ZERO;
 72     private static final MethodHandle CHECK_INDEX;
 73     private static final MethodHandle MAPPED_ENUM_LOOKUP;
 74 
 75     private static final MethodTypeDesc TYPES_SWITCH_DESCRIPTOR =
 76             MethodTypeDesc.ofDescriptor("(Ljava/lang/Object;ILjava/util/function/BiPredicate;Ljava/util/List;)I");
 77 
 78     static {
 79         try {
 80             NULL_CHECK = LOOKUP.findStatic(Objects.class, "isNull",
 81                                            MethodType.methodType(boolean.class, Object.class));
 82             IS_ZERO = LOOKUP.findStatic(SwitchBootstraps.class, "isZero",
 83                                            MethodType.methodType(boolean.class, int.class));
 84             CHECK_INDEX = LOOKUP.findStatic(Objects.class, "checkIndex",
 85                                            MethodType.methodType(int.class, int.class, int.class));
 86             MAPPED_ENUM_LOOKUP = LOOKUP.findStatic(SwitchBootstraps.class, "mappedEnumLookup",
 87                                                    MethodType.methodType(int.class, Enum.class, MethodHandles.Lookup.class,
 88                                                                          Class.class, EnumDesc[].class, EnumMap.class));
 89         }
 90         catch (ReflectiveOperationException e) {
 91             throw new ExceptionInInitializerError(e);
 92         }
 93     }
 94 
 95     /**
 96      * Bootstrap method for linking an {@code invokedynamic} call site that
 97      * implements a {@code switch} on a target of a reference type.  The static
 98      * arguments are an array of case labels which must be non-null and of type
 99      * {@code String} or {@code Integer} or {@code Class} or {@code EnumDesc}.
100      * <p>
101      * The type of the returned {@code CallSite}'s method handle will have
102      * a return type of {@code int}.   It has two parameters: the first argument
103      * will be an {@code Object} instance ({@code target}) and the second
104      * will be {@code int} ({@code restart}).
105      * <p>
106      * If the {@code target} is {@code null}, then the method of the call site
107      * returns {@literal -1}.
108      * <p>
109      * If the {@code target} is not {@code null}, then the method of the call site
110      * returns the index of the first element in the {@code labels} array starting from
111      * the {@code restart} index matching one of the following conditions:
112      * <ul>
113      *   <li>the element is of type {@code Class} that is assignable
114      *       from the target's class; or</li>
115      *   <li>the element is of type {@code String} or {@code Integer} and
116      *       equals to the target.</li>
117      *   <li>the element is of type {@code EnumDesc}, that describes a constant that is
118      *       equals to the target.</li>
119      * </ul>
120      * <p>
121      * If no element in the {@code labels} array matches the target, then
122      * the method of the call site return the length of the {@code labels} array.
123      * <p>
124      * The value of the {@code restart} index must be between {@code 0} (inclusive) and
125      * the length of the {@code labels} array (inclusive),
126      * both  or an {@link IndexOutOfBoundsException} is thrown.
127      *
128      * @param lookup Represents a lookup context with the accessibility
129      *               privileges of the caller.  When used with {@code invokedynamic},
130      *               this is stacked automatically by the VM.
131      * @param invocationName unused
132      * @param invocationType The invocation type of the {@code CallSite} with two parameters,
133      *                       a reference type, an {@code int}, and {@code int} as a return type.
134      * @param labels case labels - {@code String} and {@code Integer} constants
135      *               and {@code Class} and {@code EnumDesc} instances, in any combination
136      * @return a {@code CallSite} returning the first matching element as described above
137      *
138      * @throws NullPointerException if any argument is {@code null}
139      * @throws IllegalArgumentException if any element in the labels array is null, if the
140      * invocation type is not not a method type of first parameter of a reference type,
141      * second parameter of type {@code int} and with {@code int} as its return type,
142      * or if {@code labels} contains an element that is not of type {@code String},
143      * {@code Integer}, {@code Class} or {@code EnumDesc}.
144      * @jvms 4.4.6 The CONSTANT_NameAndType_info Structure
145      * @jvms 4.4.10 The CONSTANT_Dynamic_info and CONSTANT_InvokeDynamic_info Structures
146      */
147     public static CallSite typeSwitch(MethodHandles.Lookup lookup,
148                                       String invocationName,
149                                       MethodType invocationType,
150                                       Object... labels) {
151         if (invocationType.parameterCount() != 2
152             || (!invocationType.returnType().equals(int.class))
153             || invocationType.parameterType(0).isPrimitive()
154             || !invocationType.parameterType(1).equals(int.class))
155             throw new IllegalArgumentException("Illegal invocation type " + invocationType);
156         requireNonNull(labels);
157 
158         labels = labels.clone();
159         Stream.of(labels).forEach(SwitchBootstraps::verifyLabel);
160 
161         MethodHandle target = generateInnerClass(lookup, labels);
162 
163         target = withIndexCheck(target, labels.length);
164 
165         return new ConstantCallSite(target);
166     }
167 
168     private static void verifyLabel(Object label) {
169         if (label == null) {
170             throw new IllegalArgumentException("null label found");
171         }
172         Class<?> labelClass = label.getClass();
173         if (labelClass != Class.class &&
174             labelClass != String.class &&
175             labelClass != Integer.class &&
176             labelClass != EnumDesc.class) {
177             throw new IllegalArgumentException("label with illegal type found: " + label.getClass());
178         }
179     }
180 
181     private static boolean isZero(int value) {
182         return value == 0;
183     }
184 
185     /**
186      * Bootstrap method for linking an {@code invokedynamic} call site that
187      * implements a {@code switch} on a target of an enum type. The static
188      * arguments are used to encode the case labels associated to the switch
189      * construct, where each label can be encoded in two ways:
190      * <ul>
191      *   <li>as a {@code String} value, which represents the name of
192      *       the enum constant associated with the label</li>
193      *   <li>as a {@code Class} value, which represents the enum type
194      *       associated with a type test pattern</li>
195      * </ul>
196      * <p>
197      * The returned {@code CallSite}'s method handle will have
198      * a return type of {@code int} and accepts two parameters: the first argument
199      * will be an {@code Enum} instance ({@code target}) and the second
200      * will be {@code int} ({@code restart}).
201      * <p>
202      * If the {@code target} is {@code null}, then the method of the call site
203      * returns {@literal -1}.
204      * <p>
205      * If the {@code target} is not {@code null}, then the method of the call site
206      * returns the index of the first element in the {@code labels} array starting from
207      * the {@code restart} index matching one of the following conditions:
208      * <ul>
209      *   <li>the element is of type {@code Class} that is assignable
210      *       from the target's class; or</li>
211      *   <li>the element is of type {@code String} and equals to the target
212      *       enum constant's {@link Enum#name()}.</li>
213      * </ul>
214      * <p>
215      * If no element in the {@code labels} array matches the target, then
216      * the method of the call site return the length of the {@code labels} array.
217      * <p>
218      * The value of the {@code restart} index must be between {@code 0} (inclusive) and
219      * the length of the {@code labels} array (inclusive),
220      * both  or an {@link IndexOutOfBoundsException} is thrown.
221      *
222      * @param lookup Represents a lookup context with the accessibility
223      *               privileges of the caller. When used with {@code invokedynamic},
224      *               this is stacked automatically by the VM.
225      * @param invocationName unused
226      * @param invocationType The invocation type of the {@code CallSite} with two parameters,
227      *                       an enum type, an {@code int}, and {@code int} as a return type.
228      * @param labels case labels - {@code String} constants and {@code Class} instances,
229      *               in any combination
230      * @return a {@code CallSite} returning the first matching element as described above
231      *
232      * @throws NullPointerException if any argument is {@code null}
233      * @throws IllegalArgumentException if any element in the labels array is null, if the
234      * invocation type is not a method type whose first parameter type is an enum type,
235      * second parameter of type {@code int} and whose return type is {@code int},
236      * or if {@code labels} contains an element that is not of type {@code String} or
237      * {@code Class} of the target enum type.
238      * @jvms 4.4.6 The CONSTANT_NameAndType_info Structure
239      * @jvms 4.4.10 The CONSTANT_Dynamic_info and CONSTANT_InvokeDynamic_info Structures
240      */
241     public static CallSite enumSwitch(MethodHandles.Lookup lookup,
242                                       String invocationName,
243                                       MethodType invocationType,
244                                       Object... labels) {
245         if (invocationType.parameterCount() != 2
246             || (!invocationType.returnType().equals(int.class))
247             || invocationType.parameterType(0).isPrimitive()
248             || !invocationType.parameterType(0).isEnum()
249             || !invocationType.parameterType(1).equals(int.class))
250             throw new IllegalArgumentException("Illegal invocation type " + invocationType);
251         requireNonNull(labels);
252 
253         labels = labels.clone();
254 
255         Class<?> enumClass = invocationType.parameterType(0);
256         labels = Stream.of(labels).map(l -> convertEnumConstants(lookup, enumClass, l)).toArray();
257 
258         MethodHandle target;
259         boolean constantsOnly = Stream.of(labels).allMatch(l -> enumClass.isAssignableFrom(EnumDesc.class));
260 
261         if (labels.length > 0 && constantsOnly) {
262             //If all labels are enum constants, construct an optimized handle for repeat index 0:
263             //if (selector == null) return -1
264             //else if (idx == 0) return mappingArray[selector.ordinal()]; //mapping array created lazily
265             //else return "typeSwitch(labels)"
266             MethodHandle body =
267                     MethodHandles.guardWithTest(MethodHandles.dropArguments(NULL_CHECK, 0, int.class),
268                                                 MethodHandles.dropArguments(MethodHandles.constant(int.class, -1), 0, int.class, Object.class),
269                                                 MethodHandles.guardWithTest(MethodHandles.dropArguments(IS_ZERO, 1, Object.class),
270                                                                             generateInnerClass(lookup, labels),
271                                                                             MethodHandles.insertArguments(MAPPED_ENUM_LOOKUP, 1, lookup, enumClass, labels, new EnumMap())));
272             target = MethodHandles.permuteArguments(body, MethodType.methodType(int.class, Object.class, int.class), 1, 0);
273         } else {
274             target = generateInnerClass(lookup, labels);
275         }
276 
277         target = target.asType(invocationType);
278         target = withIndexCheck(target, labels.length);
279 
280         return new ConstantCallSite(target);
281     }
282 
283     private static <E extends Enum<E>> Object convertEnumConstants(MethodHandles.Lookup lookup, Class<?> enumClassTemplate, Object label) {
284         if (label == null) {
285             throw new IllegalArgumentException("null label found");
286         }
287         Class<?> labelClass = label.getClass();
288         if (labelClass == Class.class) {
289             if (label != enumClassTemplate) {
290                 throw new IllegalArgumentException("the Class label: " + label +
291                                                    ", expected the provided enum class: " + enumClassTemplate);
292             }
293             return label;
294         } else if (labelClass == String.class) {
295             return EnumDesc.of(enumClassTemplate.describeConstable().orElseThrow(), (String) label);
296         } else {
297             throw new IllegalArgumentException("label with illegal type found: " + labelClass +
298                                                ", expected label of type either String or Class");
299         }
300     }
301 
302     private static <T extends Enum<T>> int mappedEnumLookup(T value, MethodHandles.Lookup lookup, Class<T> enumClass, EnumDesc<?>[] labels, EnumMap enumMap) {
303         if (enumMap.map == null) {
304             T[] constants = SharedSecrets.getJavaLangAccess().getEnumConstantsShared(enumClass);
305             int[] map = new int[constants.length];
306             int ordinal = 0;
307 
308             for (T constant : constants) {
309                 map[ordinal] = labels.length;
310 
311                 for (int i = 0; i < labels.length; i++) {
312                     if (Objects.equals(labels[i].constantName(), constant.name())) {
313                         map[ordinal] = i;
314                         break;
315                     }
316                 }
317 
318                 ordinal++;
319             }
320         }
321         return enumMap.map[value.ordinal()];
322     }
323 
324     private static MethodHandle withIndexCheck(MethodHandle target, int labelsCount) {
325         MethodHandle checkIndex = MethodHandles.insertArguments(CHECK_INDEX, 1, labelsCount + 1);
326 
327         return MethodHandles.filterArguments(target, 1, checkIndex);
328     }
329 
330     private static final class ResolvedEnumLabels implements BiPredicate<Integer, Object> {
331 
332         private final MethodHandles.Lookup lookup;
333         private final EnumDesc<?>[] enumDescs;
334         @Stable
335         private Object[] resolvedEnum;
336 
337         public ResolvedEnumLabels(MethodHandles.Lookup lookup, EnumDesc<?>[] enumDescs) {
338             this.lookup = lookup;
339             this.enumDescs = enumDescs;
340             this.resolvedEnum = new Object[enumDescs.length];
341         }
342 
343         @Override
344         public boolean test(Integer labelIndex, Object value) {
345             Object result = resolvedEnum[labelIndex];
346 
347             if (result == null) {
348                 try {
349                     if (!(value instanceof Enum<?> enumValue)) {
350                         return false;
351                     }
352 
353                     EnumDesc<?> label = enumDescs[labelIndex];
354                     Class<?> clazz = label.constantType().resolveConstantDesc(lookup);
355 
356                     if (enumValue.getDeclaringClass() != clazz) {
357                         return false;
358                     }
359 
360                     result = label.resolveConstantDesc(lookup);
361                 } catch (IllegalArgumentException | ReflectiveOperationException ex) {
362                     result = SENTINEL;
363                 }
364 
365                 resolvedEnum[labelIndex] = result;
366             }
367 
368             return result == value;
369         }
370     }
371 
372     private static final class EnumMap {
373         @Stable
374         public int[] map;
375     }
376 
377     /*
378      * Construct test chains for labels inside switch, to handle switch repeats:
379      * switch (idx) {
380      *     case 0 -> if (selector matches label[0]) return 0;
381      *     case 1 -> if (selector matches label[1]) return 1;
382      *     ...
383      * }
384      */
385     @SuppressWarnings("removal")
386     private static MethodHandle generateInnerClass(MethodHandles.Lookup caller, Object[] labels) {
387         List<EnumDesc<?>> enumDescs = new ArrayList<>();
388         List<Class<?>> extraClassLabels = new ArrayList<>();
389 
390         byte[] classBytes = ClassFile.of().build(ClassDesc.of(typeSwitchClassName(caller.lookupClass())), clb -> {
391             clb.withFlags(AccessFlag.FINAL, (PreviewFeatures.isEnabled())  ? AccessFlag.IDENTITY : AccessFlag.SUPER, AccessFlag.SYNTHETIC)
392                .withMethodBody("typeSwitch",
393                                TYPES_SWITCH_DESCRIPTOR,
394                                ClassFile.ACC_FINAL | ClassFile.ACC_PUBLIC | ClassFile.ACC_STATIC,
395                                cb -> {
396                     cb.aload(0);
397                     Label nonNullLabel = cb.newLabel();
398                     cb.if_nonnull(nonNullLabel);
399                     cb.iconst_m1();
400                     cb.ireturn();
401                     cb.labelBinding(nonNullLabel);
402                     if (labels.length == 0) {
403                         cb.constantInstruction(0)
404                           .ireturn();
405                         return ;
406                     }
407                     cb.iload(1);
408                     Label dflt = cb.newLabel();
409                     record Element(Label target, Label next, Object caseLabel) {}
410                     List<Element> cases = new ArrayList<>();
411                     List<SwitchCase> switchCases = new ArrayList<>();
412                     Object lastLabel = null;
413                     for (int idx = labels.length - 1; idx >= 0; idx--) {
414                         Object currentLabel = labels[idx];
415                         Label target = cb.newLabel();
416                         Label next;
417                         if (lastLabel == null) {
418                             next = dflt;
419                         } else if (lastLabel.equals(currentLabel)) {
420                             next = cases.getLast().next();
421                         } else {
422                             next = cases.getLast().target();
423                         }
424                         lastLabel = currentLabel;
425                         cases.add(new Element(target, next, currentLabel));
426                         switchCases.add(SwitchCase.of(idx, target));
427                     }
428                     cases = cases.reversed();
429                     switchCases = switchCases.reversed();
430                     cb.tableswitch(0, labels.length - 1, dflt, switchCases);
431                     for (int idx = 0; idx < cases.size(); idx++) {
432                         Element element = cases.get(idx);
433                         Label next = element.next();
434                         cb.labelBinding(element.target());
435                         if (element.caseLabel() instanceof Class<?> classLabel) {
436                             Optional<ClassDesc> classLabelConstableOpt = classLabel.describeConstable();
437                             if (classLabelConstableOpt.isPresent()) {
438                                 cb.aload(0);
439                                 cb.instanceof_(classLabelConstableOpt.orElseThrow());
440                                 cb.ifeq(next);
441                             } else {
442                                 cb.aload(3);
443                                 cb.constantInstruction(extraClassLabels.size());
444                                 cb.invokeinterface(ConstantDescs.CD_List,
445                                                    "get",
446                                                    MethodTypeDesc.of(ConstantDescs.CD_Object,
447                                                                      ConstantDescs.CD_int));
448                                 cb.checkcast(ConstantDescs.CD_Class);
449                                 cb.aload(0);
450                                 cb.invokevirtual(ConstantDescs.CD_Class,
451                                                  "isInstance",
452                                                  MethodTypeDesc.of(ConstantDescs.CD_boolean,
453                                                                    ConstantDescs.CD_Object));
454                                 cb.ifeq(next);
455                                 extraClassLabels.add(classLabel);
456                             }
457                         } else if (element.caseLabel() instanceof EnumDesc<?> enumLabel) {
458                             int enumIdx = enumDescs.size();
459                             enumDescs.add(enumLabel);
460                             cb.aload(2);
461                             cb.constantInstruction(enumIdx);
462                             cb.invokestatic(ConstantDescs.CD_Integer,
463                                             "valueOf",
464                                             MethodTypeDesc.of(ConstantDescs.CD_Integer,
465                                                               ConstantDescs.CD_int));
466                             cb.aload(0);
467                             cb.invokeinterface(BiPredicate.class.describeConstable().orElseThrow(),
468                                                "test",
469                                                MethodTypeDesc.of(ConstantDescs.CD_boolean,
470                                                                  ConstantDescs.CD_Object,
471                                                                  ConstantDescs.CD_Object));
472                             cb.ifeq(next);
473                         } else if (element.caseLabel() instanceof String stringLabel) {
474                             cb.ldc(stringLabel);
475                             cb.aload(0);
476                             cb.invokevirtual(ConstantDescs.CD_Object,
477                                              "equals",
478                                              MethodTypeDesc.of(ConstantDescs.CD_boolean,
479                                                                ConstantDescs.CD_Object));
480                             cb.ifeq(next);
481                         } else if (element.caseLabel() instanceof Integer integerLabel) {
482                             Label compare = cb.newLabel();
483                             Label notNumber = cb.newLabel();
484                             cb.aload(0);
485                             cb.instanceof_(ConstantDescs.CD_Number);
486                             cb.ifeq(notNumber);
487                             cb.aload(0);
488                             cb.checkcast(ConstantDescs.CD_Number);
489                             cb.invokevirtual(ConstantDescs.CD_Number,
490                                              "intValue",
491                                              MethodTypeDesc.of(ConstantDescs.CD_int));
492                             cb.goto_(compare);
493                             cb.labelBinding(notNumber);
494                             cb.aload(0);
495                             cb.instanceof_(ConstantDescs.CD_Character);
496                             cb.ifeq(next);
497                             cb.aload(0);
498                             cb.checkcast(ConstantDescs.CD_Character);
499                             cb.invokevirtual(ConstantDescs.CD_Character,
500                                              "charValue",
501                                              MethodTypeDesc.of(ConstantDescs.CD_char));
502                             cb.labelBinding(compare);
503                             cb.ldc(integerLabel);
504                             cb.if_icmpne(next);
505                         } else {
506                             throw new InternalError("Unsupported label type: " +
507                                                     element.caseLabel().getClass());
508                         }
509                         cb.constantInstruction(idx);
510                         cb.ireturn();
511                     }
512                     cb.labelBinding(dflt);
513                     cb.constantInstruction(cases.size());
514                     cb.ireturn();
515                 });
516         });
517 
518         try {
519             // this class is linked at the indy callsite; so define a hidden nestmate
520             MethodHandles.Lookup lookup;
521             lookup = caller.defineHiddenClass(classBytes, true, NESTMATE, STRONG);
522             MethodHandle typeSwitch = lookup.findStatic(lookup.lookupClass(),
523                                                         "typeSwitch",
524                                                         MethodType.methodType(int.class,
525                                                                               Object.class,
526                                                                               int.class,
527                                                                               BiPredicate.class,
528                                                                               List.class));
529             return MethodHandles.insertArguments(typeSwitch, 2, new ResolvedEnumLabels(caller, enumDescs.toArray(EnumDesc[]::new)),
530                                                                 List.copyOf(extraClassLabels));
531         } catch (Throwable t) {
532             throw new IllegalArgumentException(t);
533         }
534     }
535 
536     //based on src/java.base/share/classes/java/lang/invoke/InnerClassLambdaMetafactory.java:
537     private static String typeSwitchClassName(Class<?> targetClass) {
538         String name = targetClass.getName();
539         if (targetClass.isHidden()) {
540             // use the original class name
541             name = name.replace('/', '_');
542         }
543         return name + "$$TypeSwitch";
544     }
545 }