1 /*
  2  * Copyright (c) 2025, 2026, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 
 26 import static org.junit.jupiter.api.Assertions.*;
 27 
 28 import jdk.internal.value.ValueClass;
 29 import jdk.internal.vm.annotation.LooselyConsistentValue;
 30 import jdk.internal.vm.annotation.NullRestricted;
 31 import org.junit.jupiter.params.provider.Arguments;
 32 import org.junit.jupiter.params.provider.MethodSource;
 33 import org.junit.jupiter.params.ParameterizedTest;
 34 
 35 import java.lang.invoke.MethodHandle;
 36 import java.lang.invoke.MethodHandles;
 37 import java.lang.invoke.VarHandle;
 38 import java.lang.invoke.VarHandle.AccessMode;
 39 import java.lang.reflect.Field;
 40 import java.lang.reflect.Modifier;
 41 import java.util.ArrayList;
 42 import java.util.List;
 43 import java.util.Objects;
 44 import java.util.function.BiFunction;
 45 import java.util.function.Function;
 46 
 47 /*
 48  * @test
 49  * @summary Test atomic access modes on var handles for flattened values
 50  * @enablePreview
 51  * @modules java.base/jdk.internal.value java.base/jdk.internal.vm.annotation
 52  * @run junit/othervm -XX:-UseArrayFlattening -XX:-UseNullableValueFlattening FlatVarHandleTest
 53  * @run junit/othervm -XX:+UseArrayFlattening -XX:+UseNullableValueFlattening FlatVarHandleTest
 54  */
 55 public class FlatVarHandleTest {
 56 
 57     interface Pointable { }
 58 
 59     @FunctionalInterface
 60     interface TriFunction<A, B, C, R> {
 61 
 62         R apply(A a, B b, C c);
 63 
 64         default <V> TriFunction<A, B, C, V> andThen(
 65                 Function<? super R, ? extends V> after) {
 66             Objects.requireNonNull(after);
 67             return (A a, B b, C c) -> after.apply(apply(a, b, c));
 68         }
 69     }
 70 
 71     @LooselyConsistentValue
 72     static value class WeakPoint implements Pointable {
 73         short x,y;
 74         WeakPoint(int i, int j) { x = (short)i; y = (short)j; }
 75 
 76         static WeakPoint[] makePoints(int len, BiFunction<Class<?>, Integer, Object[]> arrayFactory) {
 77             WeakPoint[] array = (WeakPoint[])arrayFactory.apply(WeakPoint.class, len);
 78             for (int i = 0; i < len; ++i) {
 79                 array[i] = new WeakPoint(i, i);
 80             }
 81             return array;
 82         }
 83 
 84         static WeakPoint[] makePoints(int len, Object initval, TriFunction<Class<?>, Integer, Object, Object[]> arrayFactory) {
 85             WeakPoint[] array = (WeakPoint[])arrayFactory.apply(WeakPoint.class, len, initval);
 86             for (int i = 0; i < len; ++i) {
 87                 array[i] = new WeakPoint(i, i);
 88             }
 89             return array;
 90         }
 91     }
 92 
 93     static class WeakPointHolder {
 94         WeakPoint p_i = new WeakPoint(0, 0);
 95         static WeakPoint p_s = new WeakPoint(0, 0);
 96         @NullRestricted
 97         WeakPoint p_i_nr;
 98         @NullRestricted
 99         static WeakPoint p_s_nr = new WeakPoint(0, 0);
100 
101         WeakPointHolder() {
102             p_i_nr = new WeakPoint(0, 0);
103             super();
104         }
105     }
106 
107     static value class StrongPoint implements Pointable {
108         short x,y;
109         StrongPoint(int i, int j) { x = (short)i; y = (short)j; }
110 
111         static StrongPoint[] makePoints(int len, BiFunction<Class<?>, Integer, Object[]> arrayFactory) {
112             StrongPoint[] array = (StrongPoint[])arrayFactory.apply(StrongPoint.class, len);
113             for (int i = 0; i < len; ++i) {
114                 array[i] = new StrongPoint(i, i);
115             }
116             return array;
117         }
118 
119         static StrongPoint[] makePoints(int len, Object initval, TriFunction<Class<?>, Integer, Object, Object[]> arrayFactory) {
120             StrongPoint[] array = (StrongPoint[])arrayFactory.apply(StrongPoint.class, len, initval);
121             for (int i = 0; i < len; ++i) {
122                 array[i] = new StrongPoint(i, i);
123             }
124             return array;
125         }
126     }
127 
128     static class StrongPointHolder {
129         StrongPoint p_i = new StrongPoint(0, 0);
130         static StrongPoint p_s = new StrongPoint(0, 0);
131     }
132 
133     private static List<Arguments> fieldAccessProvider() {
134         try {
135             List<Field> fields = List.of(
136                     WeakPointHolder.class.getDeclaredField("p_s"),
137                     WeakPointHolder.class.getDeclaredField("p_i"),
138                     WeakPointHolder.class.getDeclaredField("p_s_nr"),
139                     WeakPointHolder.class.getDeclaredField("p_i_nr"),
140                     StrongPointHolder.class.getDeclaredField("p_s"),
141                     StrongPointHolder.class.getDeclaredField("p_i"));
142             List<Arguments> arguments = new ArrayList<>();
143             for (AccessMode accessMode : AccessMode.values()) {
144                 for (Field field : fields) {
145                     boolean isStatic = (field.getModifiers() & Modifier.STATIC) != 0;
146                     boolean isWeak = field.getDeclaringClass().equals(WeakPointHolder.class);
147                     Object holder = null;
148                     if (!isStatic) {
149                         holder = isWeak ? new WeakPointHolder() : new StrongPointHolder();
150                     }
151                     BiFunction<Integer, Integer, Object> factory = isWeak ?
152                             (i1, i2) -> new WeakPoint(i1, i2) :
153                             (i1, i2) -> new StrongPoint(i1, i2);
154                     boolean allowsNonPlainAccess = (field.getModifiers() & Modifier.VOLATILE) != 0 ||
155                             !ValueClass.isNullRestrictedField(field) ||
156                             !isWeak;
157                     arguments.add(Arguments.of(accessMode, holder, factory, field, allowsNonPlainAccess));
158                 }
159             }
160             return arguments;
161         } catch (ReflectiveOperationException ex) {
162             throw new IllegalStateException(ex);
163         }
164     }
165 
166     /*
167      * Verify that atomic access modes are not supported on flat fields.
168      */
169     @ParameterizedTest
170     @MethodSource("fieldAccessProvider")
171     public void testFieldAccess(AccessMode accessMode, Object holder, BiFunction<Integer, Integer, Object> factory,
172                                 Field field, boolean allowsNonPlainAccess) throws Throwable {
173         VarHandle varHandle = MethodHandles.lookup().unreflectVarHandle(field);
174         if (varHandle.isAccessModeSupported(accessMode)) {
175             assertTrue(isPlain(accessMode) || (allowsNonPlainAccess && !isBitwise(accessMode) && !isNumeric(accessMode)));
176             MethodHandle methodHandle = varHandle.toMethodHandle(accessMode);
177             List<Object> arguments = new ArrayList<>();
178             if (holder != null) {
179                 arguments.add(holder); // receiver
180             }
181             for (int i = arguments.size(); i < methodHandle.type().parameterCount(); i++) {
182                 arguments.add(factory.apply(i, i)); // add extra setter param
183             }
184             methodHandle.invokeWithArguments(arguments.toArray());
185         } else {
186             assertTrue(!allowsNonPlainAccess || isBitwise(accessMode) || isNumeric(accessMode));
187         }
188     }
189 
190     private static List<Arguments> arrayAccessProvider() {
191         List<Object[]> arrayObjects = List.of(
192                 WeakPoint.makePoints(10, ValueClass::newNullableAtomicArray),
193                 WeakPoint.makePoints(10, new WeakPoint(0, 0), ValueClass::newNullRestrictedNonAtomicArray),
194                 WeakPoint.makePoints(10, new WeakPoint(0, 0), ValueClass::newNullRestrictedAtomicArray),
195                 new WeakPoint[10],
196                 StrongPoint.makePoints(10, ValueClass::newNullableAtomicArray),
197                 StrongPoint.makePoints(10, new StrongPoint(0, 0), ValueClass::newNullRestrictedNonAtomicArray),
198                 StrongPoint.makePoints(10, new StrongPoint(0, 0), ValueClass::newNullRestrictedAtomicArray),
199                 new StrongPoint[10]);
200 
201         List<Arguments> arguments = new ArrayList<>();
202         for (AccessMode accessMode : AccessMode.values()) {
203             if (accessMode.ordinal() != 2) continue;
204             for (Object[] arrayObject : arrayObjects) {
205                 boolean isWeak = arrayObject.getClass().getComponentType().equals(WeakPoint.class);
206                 List<Class<?>> arrayTypes = List.of(
207                         isWeak ? WeakPoint[].class : StrongPoint[].class, Pointable[].class, Object[].class);
208                 for (Class<?> arrayType : arrayTypes) {
209                     BiFunction<Integer, Integer, Object> factory = isWeak ?
210                             (i1, i2) -> new WeakPoint(i1, i2) :
211                             (i1, i2) -> new StrongPoint((short)(int)i1, (short)(int)i2);
212                     boolean allowsNonPlainAccess = !ValueClass.isNullRestrictedArray(arrayObject) ||
213                             ValueClass.isAtomicArray(arrayObject) ||
214                             !isWeak;
215                     arguments.add(Arguments.of(accessMode, arrayObject, factory, arrayType, allowsNonPlainAccess));
216                 }
217             }
218         }
219         return arguments;
220     }
221 
222     /*
223      * Verify that atomic access modes are not supported on flat array instances.
224      */
225     @ParameterizedTest
226     @MethodSource("arrayAccessProvider")
227     public void testArrayAccess(AccessMode accessMode, Object[] arrayObject, BiFunction<Integer, Integer, Object> factory,
228                                 Class<?> arrayType, boolean allowsNonPlainAccess) throws Throwable {
229         VarHandle varHandle = MethodHandles.arrayElementVarHandle(arrayType);
230         if (varHandle.isAccessModeSupported(accessMode)) {
231             assertTrue(!isBitwise(accessMode) && !isNumeric(accessMode));
232             MethodHandle methodHandle = varHandle.toMethodHandle(accessMode);
233             List<Object> arguments = new ArrayList<>();
234             arguments.add(arrayObject); // array receiver
235             arguments.add(0); // index
236             for (int i = 2; i < methodHandle.type().parameterCount(); i++) {
237                 arguments.add(factory.apply(i, i)); // add extra setter param
238             }
239             try {
240                 methodHandle.invokeWithArguments(arguments.toArray());
241             } catch (IllegalArgumentException ex) {
242                 assertFalse(allowsNonPlainAccess);
243             }
244         } else {
245             assertTrue(isBitwise(accessMode) || isNumeric(accessMode));
246         }
247     }
248 
249     boolean isBitwise(AccessMode accessMode) {
250         return switch (accessMode) {
251             case GET_AND_BITWISE_AND, GET_AND_BITWISE_AND_ACQUIRE,
252                  GET_AND_BITWISE_AND_RELEASE, GET_AND_BITWISE_OR,
253                  GET_AND_BITWISE_OR_ACQUIRE, GET_AND_BITWISE_OR_RELEASE,
254                  GET_AND_BITWISE_XOR, GET_AND_BITWISE_XOR_ACQUIRE,
255                  GET_AND_BITWISE_XOR_RELEASE -> true;
256             default -> false;
257         };
258     }
259 
260     boolean isNumeric(AccessMode accessMode) {
261         return switch (accessMode) {
262             case GET_AND_ADD, GET_AND_ADD_ACQUIRE, GET_AND_ADD_RELEASE -> true;
263             default -> false;
264         };
265     }
266 
267     boolean isPlain(AccessMode accessMode) {
268         return switch (accessMode) {
269             case GET, SET -> true;
270             default -> false;
271         };
272     }
273 }