1 /*
  2  * Copyright (c) 2015, 2023, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandleInfo;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.invoke.VarHandle;
import java.lang.invoke.WrongMethodTypeException;
import java.lang.reflect.Method;
import java.nio.ReadOnlyBufferException;
import java.util.EnumMap;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Stream;

import static java.util.stream.Collectors.toList;
import static org.testng.Assert.*;
 40 
 41 abstract class VarHandleBaseTest {
 42     static final int ITERS = Integer.getInteger("iters", 1);
 43 
 44     // More resilience for Weak* tests. These operations may spuriously
 45     // fail, and so we do several attempts with delay on failure.
 46     // Be mindful of worst-case total time on test, which would be at
 47     // roughly (delay*attempts) milliseconds.
 48     //
 49     static final int WEAK_ATTEMPTS = Integer.getInteger("weakAttempts", 100);
 50     static final int WEAK_DELAY_MS = Math.max(1, Integer.getInteger("weakDelay", 1));
 51 
 52     interface ThrowingRunnable {
 53         void run() throws Throwable;
 54     }
 55 
 56     static void checkUOE(ThrowingRunnable r) {
 57         checkWithThrowable(UnsupportedOperationException.class, null, r);
 58     }
 59 
 60     static void checkUOE(Object message, ThrowingRunnable r) {
 61         checkWithThrowable(UnsupportedOperationException.class, message, r);
 62     }
 63 
 64     static void checkROBE(ThrowingRunnable r) {
 65         checkWithThrowable(ReadOnlyBufferException.class, null, r);
 66     }
 67 
 68     static void checkROBE(Object message, ThrowingRunnable r) {
 69         checkWithThrowable(ReadOnlyBufferException.class, message, r);
 70     }
 71 
 72     static void checkIOOBE(ThrowingRunnable r) {
 73         checkWithThrowable(IndexOutOfBoundsException.class, null, r);
 74     }
 75 
 76     static void checkIOOBE(Object message, ThrowingRunnable r) {
 77         checkWithThrowable(IndexOutOfBoundsException.class, message, r);
 78     }
 79 
 80     static void checkAIOOBE(ThrowingRunnable r) {
 81         checkWithThrowable(ArrayIndexOutOfBoundsException.class, null, r);
 82     }
 83 
 84     static void checkAIOOBE(Object message, ThrowingRunnable r) {
 85         checkWithThrowable(ArrayIndexOutOfBoundsException.class, message, r);
 86     }
 87 
 88     static void checkASE(ThrowingRunnable r) {
 89         checkWithThrowable(ArrayStoreException.class, null, r);
 90     }
 91 
 92     static void checkASE(Object message, ThrowingRunnable r) {
 93         checkWithThrowable(ArrayStoreException.class, message, r);
 94     }
 95 
 96     static void checkISE(ThrowingRunnable r) {
 97         checkWithThrowable(IllegalStateException.class, null, r);
 98     }
 99 
100     static void checkISE(Object message, ThrowingRunnable r) {
101         checkWithThrowable(IllegalStateException.class, message, r);
102     }
103 
104     static void checkIAE(ThrowingRunnable r) {
105         checkWithThrowable(IllegalAccessException.class, null, r);
106     }
107 
108     static void checkIAE(Object message, ThrowingRunnable r) {
109         checkWithThrowable(IllegalAccessException.class, message, r);
110     }
111 
112     static void checkWMTE(ThrowingRunnable r) {
113         checkWithThrowable(WrongMethodTypeException.class, null, r);
114     }
115 
116     static void checkWMTE(Object message, ThrowingRunnable r) {
117         checkWithThrowable(WrongMethodTypeException.class, message, r);
118     }
119 
120     static void checkCCE(ThrowingRunnable r) {
121         checkWithThrowable(ClassCastException.class, null, r);
122     }
123 
124     static void checkCCE(Object message, ThrowingRunnable r) {
125         checkWithThrowable(ClassCastException.class, message, r);
126     }
127 
128     static void checkNPE(ThrowingRunnable r) {
129         checkWithThrowable(NullPointerException.class, null, r);
130     }
131 
132     static void checkNPE(Object message, ThrowingRunnable r) {
133         checkWithThrowable(NullPointerException.class, message, r);
134     }
135 
136     static void checkWithThrowable(Class<? extends Throwable> re,
137                                    Object message,
138                                    ThrowingRunnable r) {
139         Throwable _e = null;
140         try {
141             r.run();
142         }
143         catch (Throwable e) {
144             _e = e;
145         }
146         message = message == null ? "" : message + ". ";
147         assertNotNull(_e, String.format("%sNo throwable thrown. Expected %s", message, re));
148         if (!re.isInstance(_e)) {
149             fail(String.format("%sIncorrect throwable thrown, %s. Expected %s", message, _e, re), _e);
150         }
151     }
152 
153 
154     enum TestAccessType {
155         GET,
156         SET,
157         COMPARE_AND_SET,
158         COMPARE_AND_EXCHANGE,
159         GET_AND_SET,
160         GET_AND_ADD,
161         GET_AND_BITWISE;
162     }
163 
164     enum TestAccessMode {
165         GET(TestAccessType.GET),
166         SET(TestAccessType.SET),
167         GET_VOLATILE(TestAccessType.GET),
168         SET_VOLATILE(TestAccessType.SET),
169         GET_ACQUIRE(TestAccessType.GET),
170         SET_RELEASE(TestAccessType.SET),
171         GET_OPAQUE(TestAccessType.GET),
172         SET_OPAQUE(TestAccessType.SET),
173         COMPARE_AND_SET(TestAccessType.COMPARE_AND_SET),
174         COMPARE_AND_EXCHANGE(TestAccessType.COMPARE_AND_EXCHANGE),
175         COMPARE_AND_EXCHANGE_ACQUIRE(TestAccessType.COMPARE_AND_EXCHANGE),
176         COMPARE_AND_EXCHANGE_RELEASE(TestAccessType.COMPARE_AND_EXCHANGE),
177         WEAK_COMPARE_AND_SET_PLAIN(TestAccessType.COMPARE_AND_SET),
178         WEAK_COMPARE_AND_SET(TestAccessType.COMPARE_AND_SET),
179         WEAK_COMPARE_AND_SET_ACQUIRE(TestAccessType.COMPARE_AND_SET),
180         WEAK_COMPARE_AND_SET_RELEASE(TestAccessType.COMPARE_AND_SET),
181         GET_AND_SET(TestAccessType.GET_AND_SET),
182         GET_AND_SET_ACQUIRE(TestAccessType.GET_AND_SET),
183         GET_AND_SET_RELEASE(TestAccessType.GET_AND_SET),
184         GET_AND_ADD(TestAccessType.GET_AND_ADD),
185         GET_AND_ADD_ACQUIRE(TestAccessType.GET_AND_ADD),
186         GET_AND_ADD_RELEASE(TestAccessType.GET_AND_ADD),
187         GET_AND_BITWISE_OR(TestAccessType.GET_AND_BITWISE),
188         GET_AND_BITWISE_OR_ACQUIRE(TestAccessType.GET_AND_BITWISE),
189         GET_AND_BITWISE_OR_RELEASE(TestAccessType.GET_AND_BITWISE),
190         GET_AND_BITWISE_AND(TestAccessType.GET_AND_BITWISE),
191         GET_AND_BITWISE_AND_ACQUIRE(TestAccessType.GET_AND_BITWISE),
192         GET_AND_BITWISE_AND_RELEASE(TestAccessType.GET_AND_BITWISE),
193         GET_AND_BITWISE_XOR(TestAccessType.GET_AND_BITWISE),
194         GET_AND_BITWISE_XOR_ACQUIRE(TestAccessType.GET_AND_BITWISE),
195         GET_AND_BITWISE_XOR_RELEASE(TestAccessType.GET_AND_BITWISE),
196         ;
197 
198         final TestAccessType at;
199         final boolean isPolyMorphicInReturnType;
200         final Class<?> returnType;
201 
202         TestAccessMode(TestAccessType at) {
203             this.at = at;
204 
205             try {
206                 VarHandle.AccessMode vh_am = toAccessMode();
207                 Method m = VarHandle.class.getMethod(vh_am.methodName(), Object[].class);
208                 this.returnType = m.getReturnType();
209                 isPolyMorphicInReturnType = returnType != Object.class;
210             }
211             catch (Exception e) {
212                 throw new Error(e);
213             }
214         }
215 
216         boolean isOfType(TestAccessType at) {
217             return this.at == at;
218         }
219 
220         VarHandle.AccessMode toAccessMode() {
221             return VarHandle.AccessMode.valueOf(name());
222         }
223     }
224 
225     static List<TestAccessMode> testAccessModes() {
226         return Stream.of(TestAccessMode.values()).collect(toList());
227     }
228 
229     static List<TestAccessMode> testAccessModesOfType(TestAccessType... ats) {
230         Stream<TestAccessMode> s = Stream.of(TestAccessMode.values());
231         return s.filter(e -> Stream.of(ats).anyMatch(e::isOfType))
232                 .collect(toList());
233     }
234 
235     static List<VarHandle.AccessMode> accessModes() {
236         return Stream.of(VarHandle.AccessMode.values()).collect(toList());
237     }
238 
239     static List<VarHandle.AccessMode> accessModesOfType(TestAccessType... ats) {
240         Stream<TestAccessMode> s = Stream.of(TestAccessMode.values());
241         return s.filter(e -> Stream.of(ats).anyMatch(e::isOfType))
242                 .map(TestAccessMode::toAccessMode)
243                 .collect(toList());
244     }
245 
246     static MethodHandle toMethodHandle(VarHandle vh, TestAccessMode tam, MethodType mt) {
247         return vh.toMethodHandle(tam.toAccessMode());
248     }
249 
250     static MethodHandle findVirtual(VarHandle vh, TestAccessMode tam, MethodType mt) {
251         MethodHandle mh;
252         try {
253             mh = MethodHandles.publicLookup().
254                     findVirtual(VarHandle.class,
255                                 tam.toAccessMode().methodName(),
256                                 mt);
257         } catch (Exception e) {
258             throw new RuntimeException(e);
259         }
260         return bind(vh, mh, mt);
261     }
262 
263     static MethodHandle varHandleInvoker(VarHandle vh, TestAccessMode tam, MethodType mt) {
264         MethodHandle mh = MethodHandles.varHandleInvoker(
265                 tam.toAccessMode(),
266                 mt);
267 
268         return bind(vh, mh, mt);
269     }
270 
271     static MethodHandle varHandleExactInvoker(VarHandle vh, TestAccessMode tam, MethodType mt) {
272         MethodHandle mh = MethodHandles.varHandleExactInvoker(
273                 tam.toAccessMode(),
274                 mt);
275 
276         return bind(vh, mh, mt);
277     }
278 
279     private static MethodHandle bind(VarHandle vh, MethodHandle mh, MethodType emt) {
280         assertEquals(mh.type(), emt.insertParameterTypes(0, VarHandle.class),
281                      "MethodHandle type differs from access mode type");
282 
283         MethodHandleInfo info = MethodHandles.lookup().revealDirect(mh);
284         assertEquals(info.getMethodType(), emt,
285                      "MethodHandleInfo method type differs from access mode type");
286 
287         return mh.bindTo(vh);
288     }
289 
290     private interface TriFunction<T, U, V, R> {
291         R apply(T t, U u, V v);
292     }
293 
294     enum VarHandleToMethodHandle {
295         VAR_HANDLE_TO_METHOD_HANDLE(
296                 "VarHandle.toMethodHandle",
297                 true,
298                 VarHandleBaseTest::toMethodHandle),
299         METHOD_HANDLES_LOOKUP_FIND_VIRTUAL(
300                 "Lookup.findVirtual",
301                 false,
302                 VarHandleBaseTest::findVirtual),
303         METHOD_HANDLES_VAR_HANDLE_INVOKER(
304                 "MethodHandles.varHandleInvoker",
305                 false,
306                 VarHandleBaseTest::varHandleInvoker),
307         METHOD_HANDLES_VAR_HANDLE_EXACT_INVOKER(
308                 "MethodHandles.varHandleExactInvoker",
309                 true,
310                 VarHandleBaseTest::varHandleExactInvoker);
311 
312         final String desc;
313         final boolean isExact;
314         final TriFunction<VarHandle, TestAccessMode, MethodType, MethodHandle> f;
315 
316         VarHandleToMethodHandle(String desc, boolean isExact,
317                                 TriFunction<VarHandle, TestAccessMode, MethodType, MethodHandle> f) {
318             this.desc = desc;
319             this.f = f;
320             this.isExact = isExact;
321         }
322 
323         MethodHandle apply(VarHandle vh, TestAccessMode am, MethodType mt) {
324             return f.apply(vh, am, mt);
325         }
326 
327         @Override
328         public String toString() {
329             return desc;
330         }
331     }
332 
333     static class Handles {
334         static class AccessModeAndType {
335             final TestAccessMode tam;
336             final MethodType t;
337 
338             public AccessModeAndType(TestAccessMode tam, MethodType t) {
339                 this.tam = tam;
340                 this.t = t;
341             }
342 
343             @Override
344             public boolean equals(Object o) {
345                 if (this == o) return true;
346                 if (o == null || getClass() != o.getClass()) return false;
347 
348                 AccessModeAndType x = (AccessModeAndType) o;
349 
350                 if (tam != x.tam) return false;
351                 if (t != null ? !t.equals(x.t) : x.t != null) return false;
352 
353                 return true;
354             }
355 
356             @Override
357             public int hashCode() {
358                 int result = tam != null ? tam.hashCode() : 0;
359                 result = 31 * result + (t != null ? t.hashCode() : 0);
360                 return result;
361             }
362         }
363 
364         final VarHandle vh;
365         final VarHandleToMethodHandle f;
366         final EnumMap<TestAccessMode, MethodType> amToType;
367         final Map<AccessModeAndType, MethodHandle> amToHandle;
368 
369         Handles(VarHandle vh, VarHandleToMethodHandle f) throws Exception {
370             this.vh = vh;
371             this.f = f;
372             this.amToHandle = new HashMap<>();
373 
374             amToType = new EnumMap<>(TestAccessMode.class);
375             for (TestAccessMode am : testAccessModes()) {
376                 amToType.put(am, vh.accessModeType(am.toAccessMode()));
377             }
378         }
379 
380         MethodHandle get(TestAccessMode am) {
381             return get(am, amToType.get(am));
382         }
383 
384         MethodHandle get(TestAccessMode am, MethodType mt) {
385             AccessModeAndType amt = new AccessModeAndType(am, mt);
386             return amToHandle.computeIfAbsent(
387                     amt, k -> f.apply(vh, am, mt));
388         }
389 
390         Class<? extends Throwable> getWMTEOOrOther(Class<? extends Throwable> c) {
391             return f.isExact ? WrongMethodTypeException.class : c;
392         }
393 
394         void checkWMTEOrCCE(ThrowingRunnable r) {
395             checkWithThrowable(getWMTEOOrOther(ClassCastException.class), null, r);
396         }
397 
398     }
399 
400     interface AccessTestAction<T> {
401         void action(T t) throws Throwable;
402     }
403 
404     static abstract class AccessTestCase<T> {
405         final String desc;
406         final AccessTestAction<T> ata;
407         final boolean loop;
408 
409         AccessTestCase(String desc, AccessTestAction<T> ata, boolean loop) {
410             this.desc = desc;
411             this.ata = ata;
412             this.loop = loop;
413         }
414 
415         boolean requiresLoop() {
416             return loop;
417         }
418 
419         abstract T get() throws Exception;
420 
421         void testAccess(T t) throws Throwable {
422             ata.action(t);
423         }
424 
425         @Override
426         public String toString() {
427             return desc;
428         }
429     }
430 
431     static class VarHandleAccessTestCase extends AccessTestCase<VarHandle> {
432         final VarHandle vh;
433 
434         VarHandleAccessTestCase(String desc, VarHandle vh, AccessTestAction<VarHandle> ata) {
435             this(desc, vh, ata, true);
436         }
437 
438         VarHandleAccessTestCase(String desc, VarHandle vh, AccessTestAction<VarHandle> ata, boolean loop) {
439             super("VarHandle -> " + desc, ata, loop);
440             this.vh = vh;
441         }
442 
443         @Override
444         VarHandle get() {
445             return vh;
446         }
447 
448         public String toString() {
449             return super.toString() + ", vh:" + vh;
450         }
451     }
452 
453     static class MethodHandleAccessTestCase extends AccessTestCase<Handles> {
454         final VarHandle vh;
455         final VarHandleToMethodHandle f;
456 
457         MethodHandleAccessTestCase(String desc, VarHandle vh, VarHandleToMethodHandle f, AccessTestAction<Handles> ata) {
458             this(desc, vh, f, ata, true);
459         }
460 
461         MethodHandleAccessTestCase(String desc, VarHandle vh, VarHandleToMethodHandle f, AccessTestAction<Handles> ata, boolean loop) {
462             super("VarHandle -> " + f.toString() + " -> " + desc, ata, loop);
463             this.vh = vh;
464             this.f = f;
465         }
466 
467         @Override
468         Handles get() throws Exception {
469             return new Handles(vh, f);
470         }
471 
472         public String toString() {
473             return super.toString() + ", vh:" + vh + ", f: " + f;
474         }
475     }
476 
477     static void testTypes(VarHandle vh) {
478         List<Class<?>> pts = vh.coordinateTypes();
479 
480         for (TestAccessMode accessMode : testAccessModes()) {
481             MethodType amt = vh.accessModeType(accessMode.toAccessMode());
482 
483             assertEquals(amt.parameterList().subList(0, pts.size()), pts);
484         }
485 
486         for (TestAccessMode testAccessMode : testAccessModesOfType(TestAccessType.GET)) {
487             MethodType mt = vh.accessModeType(testAccessMode.toAccessMode());
488             assertEquals(mt.returnType(), vh.varType());
489             assertEquals(mt.parameterList(), pts);
490         }
491 
492         for (TestAccessMode testAccessMode : testAccessModesOfType(TestAccessType.SET)) {
493             MethodType mt = vh.accessModeType(testAccessMode.toAccessMode());
494             assertEquals(mt.returnType(), void.class);
495             assertEquals(mt.parameterType(mt.parameterCount() - 1), vh.varType());
496         }
497 
498         for (TestAccessMode testAccessMode : testAccessModesOfType(TestAccessType.COMPARE_AND_SET)) {
499             MethodType mt = vh.accessModeType(testAccessMode.toAccessMode());
500             assertEquals(mt.returnType(), boolean.class);
501             assertEquals(mt.parameterType(mt.parameterCount() - 1), vh.varType());
502             assertEquals(mt.parameterType(mt.parameterCount() - 2), vh.varType());
503         }
504 
505         for (TestAccessMode testAccessMode : testAccessModesOfType(TestAccessType.COMPARE_AND_EXCHANGE)) {
506             MethodType mt = vh.accessModeType(testAccessMode.toAccessMode());
507             assertEquals(mt.returnType(), vh.varType());
508             assertEquals(mt.parameterType(mt.parameterCount() - 1), vh.varType());
509             assertEquals(mt.parameterType(mt.parameterCount() - 2), vh.varType());
510         }
511 
512         for (TestAccessMode testAccessMode : testAccessModesOfType(TestAccessType.GET_AND_SET, TestAccessType.GET_AND_ADD)) {
513             MethodType mt = vh.accessModeType(testAccessMode.toAccessMode());
514             assertEquals(mt.returnType(), vh.varType());
515             assertEquals(mt.parameterType(mt.parameterCount() - 1), vh.varType());
516         }
517     }
518 
519     static void weakDelay() {
520         try {
521             if (WEAK_DELAY_MS > 0) {
522                 Thread.sleep(WEAK_DELAY_MS);
523             }
524         } catch (InterruptedException ie) {
525             // Do nothing.
526         }
527     }
528 }