1 /*
  2  * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 
 26 import static org.junit.jupiter.api.Assertions.*;
 27 
 28 import jdk.internal.value.ValueClass;
 29 import jdk.internal.vm.annotation.LooselyConsistentValue;
 30 import jdk.internal.vm.annotation.NullRestricted;
 31 import jdk.internal.vm.annotation.Strict;
 32 import org.junit.jupiter.params.provider.Arguments;
 33 import org.junit.jupiter.params.provider.MethodSource;
 34 import org.junit.jupiter.params.ParameterizedTest;
 35 
 36 import java.lang.invoke.MethodHandle;
 37 import java.lang.invoke.MethodHandles;
 38 import java.lang.invoke.VarHandle;
 39 import java.lang.invoke.VarHandle.AccessMode;
 40 import java.lang.reflect.Field;
 41 import java.lang.reflect.Modifier;
 42 import java.util.ArrayList;
 43 import java.util.List;
 44 import java.util.Objects;
 45 import java.util.function.BiFunction;
 46 import java.util.function.Function;
 47 
 48 // TODO: 8353180: Remove requires != Xcomp
 49 
 50 /*
 51  * @test
 52  * @requires vm.compMode != "Xcomp"
 53  * @summary Test atomic access modes on var handles for flattened values
 54  * @enablePreview
 55  * @modules java.base/jdk.internal.value java.base/jdk.internal.vm.annotation
 56  * @run junit/othervm -XX:-UseArrayFlattening -XX:-UseNullableValueFlattening FlatVarHandleTest
 57  * @run junit/othervm -XX:+UseArrayFlattening -XX:+UseNullableValueFlattening FlatVarHandleTest
 58  */
 59 public class FlatVarHandleTest {
 60 
 61     interface Pointable { }
 62 
 63     @FunctionalInterface
 64     interface TriFunction<A, B, C, R> {
 65 
 66         R apply(A a, B b, C c);
 67 
 68         default <V> TriFunction<A, B, C, V> andThen(
 69                 Function<? super R, ? extends V> after) {
 70             Objects.requireNonNull(after);
 71             return (A a, B b, C c) -> after.apply(apply(a, b, c));
 72         }
 73     }
 74 
 75     @LooselyConsistentValue
 76     static value class WeakPoint implements Pointable {
 77         short x,y;
 78         WeakPoint(int i, int j) { x = (short)i; y = (short)j; }
 79 
 80         static WeakPoint[] makePoints(int len, BiFunction<Class<?>, Integer, Object[]> arrayFactory) {
 81             WeakPoint[] array = (WeakPoint[])arrayFactory.apply(WeakPoint.class, len);
 82             for (int i = 0; i < len; ++i) {
 83                 array[i] = new WeakPoint(i, i);
 84             }
 85             return array;
 86         }
 87 
 88         static WeakPoint[] makePoints(int len, Object initval, TriFunction<Class<?>, Integer, Object, Object[]> arrayFactory) {
 89             WeakPoint[] array = (WeakPoint[])arrayFactory.apply(WeakPoint.class, len, initval);
 90             for (int i = 0; i < len; ++i) {
 91                 array[i] = new WeakPoint(i, i);
 92             }
 93             return array;
 94         }
 95     }
 96 
 97     static class WeakPointHolder {
 98         WeakPoint p_i = new WeakPoint(0, 0);
 99         static WeakPoint p_s = new WeakPoint(0, 0);
100         @Strict
101         @NullRestricted
102         WeakPoint p_i_nr = new WeakPoint(0, 0);
103         @Strict
104         @NullRestricted
105         static WeakPoint p_s_nr = new WeakPoint(0, 0);
106     }
107 
108     static value class StrongPoint implements Pointable {
109         short x,y;
110         StrongPoint(int i, int j) { x = (short)i; y = (short)j; }
111 
112         static StrongPoint[] makePoints(int len, BiFunction<Class<?>, Integer, Object[]> arrayFactory) {
113             StrongPoint[] array = (StrongPoint[])arrayFactory.apply(StrongPoint.class, len);
114             for (int i = 0; i < len; ++i) {
115                 array[i] = new StrongPoint(i, i);
116             }
117             return array;
118         }
119 
120         static StrongPoint[] makePoints(int len, Object initval, TriFunction<Class<?>, Integer, Object, Object[]> arrayFactory) {
121             StrongPoint[] array = (StrongPoint[])arrayFactory.apply(StrongPoint.class, len, initval);
122             for (int i = 0; i < len; ++i) {
123                 array[i] = new StrongPoint(i, i);
124             }
125             return array;
126         }
127     }
128 
129     static class StrongPointHolder {
130         StrongPoint p_i = new StrongPoint(0, 0);
131         static StrongPoint p_s = new StrongPoint(0, 0);
132     }
133 
134     private static List<Arguments> fieldAccessProvider() {
135         try {
136             List<Field> fields = List.of(
137                     WeakPointHolder.class.getDeclaredField("p_s"),
138                     WeakPointHolder.class.getDeclaredField("p_i"),
139                     WeakPointHolder.class.getDeclaredField("p_s_nr"),
140                     WeakPointHolder.class.getDeclaredField("p_i_nr"),
141                     StrongPointHolder.class.getDeclaredField("p_s"),
142                     StrongPointHolder.class.getDeclaredField("p_i"));
143             List<Arguments> arguments = new ArrayList<>();
144             for (AccessMode accessMode : AccessMode.values()) {
145                 for (Field field : fields) {
146                     boolean isStatic = (field.getModifiers() & Modifier.STATIC) != 0;
147                     boolean isWeak = field.getDeclaringClass().equals(WeakPointHolder.class);
148                     Object holder = null;
149                     if (!isStatic) {
150                         holder = isWeak ? new WeakPointHolder() : new StrongPointHolder();
151                     }
152                     BiFunction<Integer, Integer, Object> factory = isWeak ?
153                             (i1, i2) -> new WeakPoint(i1, i2) :
154                             (i1, i2) -> new StrongPoint(i1, i2);
155                     boolean allowsNonPlainAccess = (field.getModifiers() & Modifier.VOLATILE) != 0 ||
156                             !ValueClass.isNullRestrictedField(field) ||
157                             !isWeak;
158                     arguments.add(Arguments.of(accessMode, holder, factory, field, allowsNonPlainAccess));
159                 }
160             }
161             return arguments;
162         } catch (ReflectiveOperationException ex) {
163             throw new IllegalStateException(ex);
164         }
165     }
166 
167     /*
168      * Verify that atomic access modes are not supported on flat fields.
169      */
170     @ParameterizedTest
171     @MethodSource("fieldAccessProvider")
172     public void testFieldAccess(AccessMode accessMode, Object holder, BiFunction<Integer, Integer, Object> factory,
173                                 Field field, boolean allowsNonPlainAccess) throws Throwable {
174         VarHandle varHandle = MethodHandles.lookup().unreflectVarHandle(field);
175         if (varHandle.isAccessModeSupported(accessMode)) {
176             assertTrue(isPlain(accessMode) || (allowsNonPlainAccess && !isBitwise(accessMode) && !isNumeric(accessMode)));
177             MethodHandle methodHandle = varHandle.toMethodHandle(accessMode);
178             List<Object> arguments = new ArrayList<>();
179             if (holder != null) {
180                 arguments.add(holder); // receiver
181             }
182             for (int i = arguments.size(); i < methodHandle.type().parameterCount(); i++) {
183                 arguments.add(factory.apply(i, i)); // add extra setter param
184             }
185             methodHandle.invokeWithArguments(arguments.toArray());
186         } else {
187             assertTrue(!allowsNonPlainAccess || isBitwise(accessMode) || isNumeric(accessMode));
188         }
189     }
190 
191     private static List<Arguments> arrayAccessProvider() {
192         List<Object[]> arrayObjects = List.of(
193                 WeakPoint.makePoints(10, ValueClass::newNullableAtomicArray),
194                 WeakPoint.makePoints(10, new WeakPoint(0, 0), ValueClass::newNullRestrictedNonAtomicArray),
195                 WeakPoint.makePoints(10, new WeakPoint(0, 0), ValueClass::newNullRestrictedAtomicArray),
196                 new WeakPoint[10],
197                 StrongPoint.makePoints(10, ValueClass::newNullableAtomicArray),
198                 StrongPoint.makePoints(10, new StrongPoint(0, 0), ValueClass::newNullRestrictedNonAtomicArray),
199                 StrongPoint.makePoints(10, new StrongPoint(0, 0), ValueClass::newNullRestrictedAtomicArray),
200                 new StrongPoint[10]);
201 
202         List<Arguments> arguments = new ArrayList<>();
203         for (AccessMode accessMode : AccessMode.values()) {
204             if (accessMode.ordinal() != 2) continue;
205             for (Object[] arrayObject : arrayObjects) {
206                 boolean isWeak = arrayObject.getClass().getComponentType().equals(WeakPoint.class);
207                 List<Class<?>> arrayTypes = List.of(
208                         isWeak ? WeakPoint[].class : StrongPoint[].class, Pointable[].class, Object[].class);
209                 for (Class<?> arrayType : arrayTypes) {
210                     BiFunction<Integer, Integer, Object> factory = isWeak ?
211                             (i1, i2) -> new WeakPoint(i1, i2) :
212                             (i1, i2) -> new StrongPoint((short)(int)i1, (short)(int)i2);
213                     boolean allowsNonPlainAccess = !ValueClass.isNullRestrictedArray(arrayObject) ||
214                             ValueClass.isAtomicArray(arrayObject) ||
215                             !isWeak;
216                     arguments.add(Arguments.of(accessMode, arrayObject, factory, arrayType, allowsNonPlainAccess));
217                 }
218             }
219         }
220         return arguments;
221     }
222 
223     /*
224      * Verify that atomic access modes are not supported on flat array instances.
225      */
226     @ParameterizedTest
227     @MethodSource("arrayAccessProvider")
228     public void testArrayAccess(AccessMode accessMode, Object[] arrayObject, BiFunction<Integer, Integer, Object> factory,
229                                 Class<?> arrayType, boolean allowsNonPlainAccess) throws Throwable {
230         VarHandle varHandle = MethodHandles.arrayElementVarHandle(arrayType);
231         if (varHandle.isAccessModeSupported(accessMode)) {
232             assertTrue(!isBitwise(accessMode) && !isNumeric(accessMode));
233             MethodHandle methodHandle = varHandle.toMethodHandle(accessMode);
234             List<Object> arguments = new ArrayList<>();
235             arguments.add(arrayObject); // array receiver
236             arguments.add(0); // index
237             for (int i = 2; i < methodHandle.type().parameterCount(); i++) {
238                 arguments.add(factory.apply(i, i)); // add extra setter param
239             }
240             try {
241                 methodHandle.invokeWithArguments(arguments.toArray());
242             } catch (IllegalArgumentException ex) {
243                 assertFalse(allowsNonPlainAccess);
244             }
245         } else {
246             assertTrue(isBitwise(accessMode) || isNumeric(accessMode));
247         }
248     }
249 
250     boolean isBitwise(AccessMode accessMode) {
251         return switch (accessMode) {
252             case GET_AND_BITWISE_AND, GET_AND_BITWISE_AND_ACQUIRE,
253                  GET_AND_BITWISE_AND_RELEASE, GET_AND_BITWISE_OR,
254                  GET_AND_BITWISE_OR_ACQUIRE, GET_AND_BITWISE_OR_RELEASE,
255                  GET_AND_BITWISE_XOR, GET_AND_BITWISE_XOR_ACQUIRE,
256                  GET_AND_BITWISE_XOR_RELEASE -> true;
257             default -> false;
258         };
259     }
260 
261     boolean isNumeric(AccessMode accessMode) {
262         return switch (accessMode) {
263             case GET_AND_ADD, GET_AND_ADD_ACQUIRE, GET_AND_ADD_RELEASE -> true;
264             default -> false;
265         };
266     }
267 
268     boolean isPlain(AccessMode accessMode) {
269         return switch (accessMode) {
270             case GET, SET -> true;
271             default -> false;
272         };
273     }
274 }