1 /*
  2  * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 
 26 import static org.junit.jupiter.api.Assertions.*;
 27 
 28 import jdk.internal.value.ValueClass;
 29 import jdk.internal.vm.annotation.LooselyConsistentValue;
 30 import jdk.internal.vm.annotation.NullRestricted;
 31 import jdk.internal.vm.annotation.Strict;
 32 import org.junit.jupiter.params.provider.Arguments;
 33 import org.junit.jupiter.params.provider.MethodSource;
 34 import org.junit.jupiter.params.ParameterizedTest;
 35 
 36 import java.lang.invoke.MethodHandle;
 37 import java.lang.invoke.MethodHandles;
 38 import java.lang.invoke.VarHandle;
 39 import java.lang.invoke.VarHandle.AccessMode;
 40 import java.lang.reflect.Field;
 41 import java.lang.reflect.Modifier;
 42 import java.util.ArrayList;
 43 import java.util.List;
 44 import java.util.Objects;
 45 import java.util.function.BiFunction;
 46 import java.util.function.Function;
 47 
 48 /*
 49  * @test
 50  * @summary Test atomic access modes on var handles for flattened values
 51  * @enablePreview
 52  * @modules java.base/jdk.internal.value java.base/jdk.internal.vm.annotation
 53  * @run junit/othervm -XX:-UseArrayFlattening -XX:-UseNullableValueFlattening FlatVarHandleTest
 54  * @run junit/othervm -XX:+UseArrayFlattening -XX:+UseNullableValueFlattening FlatVarHandleTest
 55  */
 56 public class FlatVarHandleTest {
 57 
 58     interface Pointable { }
 59 
 60     @FunctionalInterface
 61     interface TriFunction<A, B, C, R> {
 62 
 63         R apply(A a, B b, C c);
 64 
 65         default <V> TriFunction<A, B, C, V> andThen(
 66                 Function<? super R, ? extends V> after) {
 67             Objects.requireNonNull(after);
 68             return (A a, B b, C c) -> after.apply(apply(a, b, c));
 69         }
 70     }
 71 
 72     @LooselyConsistentValue
 73     static value class WeakPoint implements Pointable {
 74         short x,y;
 75         WeakPoint(int i, int j) { x = (short)i; y = (short)j; }
 76 
 77         static WeakPoint[] makePoints(int len, BiFunction<Class<?>, Integer, Object[]> arrayFactory) {
 78             WeakPoint[] array = (WeakPoint[])arrayFactory.apply(WeakPoint.class, len);
 79             for (int i = 0; i < len; ++i) {
 80                 array[i] = new WeakPoint(i, i);
 81             }
 82             return array;
 83         }
 84 
 85         static WeakPoint[] makePoints(int len, Object initval, TriFunction<Class<?>, Integer, Object, Object[]> arrayFactory) {
 86             WeakPoint[] array = (WeakPoint[])arrayFactory.apply(WeakPoint.class, len, initval);
 87             for (int i = 0; i < len; ++i) {
 88                 array[i] = new WeakPoint(i, i);
 89             }
 90             return array;
 91         }
 92     }
 93 
 94     static class WeakPointHolder {
 95         WeakPoint p_i = new WeakPoint(0, 0);
 96         static WeakPoint p_s = new WeakPoint(0, 0);
 97         @Strict
 98         @NullRestricted
 99         WeakPoint p_i_nr = new WeakPoint(0, 0);
100         @Strict
101         @NullRestricted
102         static WeakPoint p_s_nr = new WeakPoint(0, 0);
103     }
104 
105     static value class StrongPoint implements Pointable {
106         short x,y;
107         StrongPoint(int i, int j) { x = (short)i; y = (short)j; }
108 
109         static StrongPoint[] makePoints(int len, BiFunction<Class<?>, Integer, Object[]> arrayFactory) {
110             StrongPoint[] array = (StrongPoint[])arrayFactory.apply(StrongPoint.class, len);
111             for (int i = 0; i < len; ++i) {
112                 array[i] = new StrongPoint(i, i);
113             }
114             return array;
115         }
116 
117         static StrongPoint[] makePoints(int len, Object initval, TriFunction<Class<?>, Integer, Object, Object[]> arrayFactory) {
118             StrongPoint[] array = (StrongPoint[])arrayFactory.apply(StrongPoint.class, len, initval);
119             for (int i = 0; i < len; ++i) {
120                 array[i] = new StrongPoint(i, i);
121             }
122             return array;
123         }
124     }
125 
126     static class StrongPointHolder {
127         StrongPoint p_i = new StrongPoint(0, 0);
128         static StrongPoint p_s = new StrongPoint(0, 0);
129     }
130 
131     private static List<Arguments> fieldAccessProvider() {
132         try {
133             List<Field> fields = List.of(
134                     WeakPointHolder.class.getDeclaredField("p_s"),
135                     WeakPointHolder.class.getDeclaredField("p_i"),
136                     WeakPointHolder.class.getDeclaredField("p_s_nr"),
137                     WeakPointHolder.class.getDeclaredField("p_i_nr"),
138                     StrongPointHolder.class.getDeclaredField("p_s"),
139                     StrongPointHolder.class.getDeclaredField("p_i"));
140             List<Arguments> arguments = new ArrayList<>();
141             for (AccessMode accessMode : AccessMode.values()) {
142                 for (Field field : fields) {
143                     boolean isStatic = (field.getModifiers() & Modifier.STATIC) != 0;
144                     boolean isWeak = field.getDeclaringClass().equals(WeakPointHolder.class);
145                     Object holder = null;
146                     if (!isStatic) {
147                         holder = isWeak ? new WeakPointHolder() : new StrongPointHolder();
148                     }
149                     BiFunction<Integer, Integer, Object> factory = isWeak ?
150                             (i1, i2) -> new WeakPoint(i1, i2) :
151                             (i1, i2) -> new StrongPoint(i1, i2);
152                     boolean allowsNonPlainAccess = (field.getModifiers() & Modifier.VOLATILE) != 0 ||
153                             !ValueClass.isNullRestrictedField(field) ||
154                             !isWeak;
155                     arguments.add(Arguments.of(accessMode, holder, factory, field, allowsNonPlainAccess));
156                 }
157             }
158             return arguments;
159         } catch (ReflectiveOperationException ex) {
160             throw new IllegalStateException(ex);
161         }
162     }
163 
164     /*
165      * Verify that atomic access modes are not supported on flat fields.
166      */
167     @ParameterizedTest
168     @MethodSource("fieldAccessProvider")
169     public void testFieldAccess(AccessMode accessMode, Object holder, BiFunction<Integer, Integer, Object> factory,
170                                 Field field, boolean allowsNonPlainAccess) throws Throwable {
171         VarHandle varHandle = MethodHandles.lookup().unreflectVarHandle(field);
172         if (varHandle.isAccessModeSupported(accessMode)) {
173             assertTrue(isPlain(accessMode) || (allowsNonPlainAccess && !isBitwise(accessMode) && !isNumeric(accessMode)));
174             MethodHandle methodHandle = varHandle.toMethodHandle(accessMode);
175             List<Object> arguments = new ArrayList<>();
176             if (holder != null) {
177                 arguments.add(holder); // receiver
178             }
179             for (int i = arguments.size(); i < methodHandle.type().parameterCount(); i++) {
180                 arguments.add(factory.apply(i, i)); // add extra setter param
181             }
182             methodHandle.invokeWithArguments(arguments.toArray());
183         } else {
184             assertTrue(!allowsNonPlainAccess || isBitwise(accessMode) || isNumeric(accessMode));
185         }
186     }
187 
188     private static List<Arguments> arrayAccessProvider() {
189         List<Object[]> arrayObjects = List.of(
190                 WeakPoint.makePoints(10, ValueClass::newNullableAtomicArray),
191                 WeakPoint.makePoints(10, new WeakPoint(0, 0), ValueClass::newNullRestrictedNonAtomicArray),
192                 WeakPoint.makePoints(10, new WeakPoint(0, 0), ValueClass::newNullRestrictedAtomicArray),
193                 new WeakPoint[10],
194                 StrongPoint.makePoints(10, ValueClass::newNullableAtomicArray),
195                 StrongPoint.makePoints(10, new StrongPoint(0, 0), ValueClass::newNullRestrictedNonAtomicArray),
196                 StrongPoint.makePoints(10, new StrongPoint(0, 0), ValueClass::newNullRestrictedAtomicArray),
197                 new StrongPoint[10]);
198 
199         List<Arguments> arguments = new ArrayList<>();
200         for (AccessMode accessMode : AccessMode.values()) {
201             if (accessMode.ordinal() != 2) continue;
202             for (Object[] arrayObject : arrayObjects) {
203                 boolean isWeak = arrayObject.getClass().getComponentType().equals(WeakPoint.class);
204                 List<Class<?>> arrayTypes = List.of(
205                         isWeak ? WeakPoint[].class : StrongPoint[].class, Pointable[].class, Object[].class);
206                 for (Class<?> arrayType : arrayTypes) {
207                     BiFunction<Integer, Integer, Object> factory = isWeak ?
208                             (i1, i2) -> new WeakPoint(i1, i2) :
209                             (i1, i2) -> new StrongPoint((short)(int)i1, (short)(int)i2);
210                     boolean allowsNonPlainAccess = !ValueClass.isNullRestrictedArray(arrayObject) ||
211                             ValueClass.isAtomicArray(arrayObject) ||
212                             !isWeak;
213                     arguments.add(Arguments.of(accessMode, arrayObject, factory, arrayType, allowsNonPlainAccess));
214                 }
215             }
216         }
217         return arguments;
218     }
219 
220     /*
221      * Verify that atomic access modes are not supported on flat array instances.
222      */
223     @ParameterizedTest
224     @MethodSource("arrayAccessProvider")
225     public void testArrayAccess(AccessMode accessMode, Object[] arrayObject, BiFunction<Integer, Integer, Object> factory,
226                                 Class<?> arrayType, boolean allowsNonPlainAccess) throws Throwable {
227         VarHandle varHandle = MethodHandles.arrayElementVarHandle(arrayType);
228         if (varHandle.isAccessModeSupported(accessMode)) {
229             assertTrue(!isBitwise(accessMode) && !isNumeric(accessMode));
230             MethodHandle methodHandle = varHandle.toMethodHandle(accessMode);
231             List<Object> arguments = new ArrayList<>();
232             arguments.add(arrayObject); // array receiver
233             arguments.add(0); // index
234             for (int i = 2; i < methodHandle.type().parameterCount(); i++) {
235                 arguments.add(factory.apply(i, i)); // add extra setter param
236             }
237             try {
238                 methodHandle.invokeWithArguments(arguments.toArray());
239             } catch (IllegalArgumentException ex) {
240                 assertFalse(allowsNonPlainAccess);
241             }
242         } else {
243             assertTrue(isBitwise(accessMode) || isNumeric(accessMode));
244         }
245     }
246 
247     boolean isBitwise(AccessMode accessMode) {
248         return switch (accessMode) {
249             case GET_AND_BITWISE_AND, GET_AND_BITWISE_AND_ACQUIRE,
250                  GET_AND_BITWISE_AND_RELEASE, GET_AND_BITWISE_OR,
251                  GET_AND_BITWISE_OR_ACQUIRE, GET_AND_BITWISE_OR_RELEASE,
252                  GET_AND_BITWISE_XOR, GET_AND_BITWISE_XOR_ACQUIRE,
253                  GET_AND_BITWISE_XOR_RELEASE -> true;
254             default -> false;
255         };
256     }
257 
258     boolean isNumeric(AccessMode accessMode) {
259         return switch (accessMode) {
260             case GET_AND_ADD, GET_AND_ADD_ACQUIRE, GET_AND_ADD_RELEASE -> true;
261             default -> false;
262         };
263     }
264 
265     boolean isPlain(AccessMode accessMode) {
266         return switch (accessMode) {
267             case GET, SET -> true;
268             default -> false;
269         };
270     }
271 }