1 /*
  2  * Copyright (c) 2020, 2023, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 /*
 25  * @test
 26  * @modules java.base/jdk.internal.ref
 27  * @run testng/othervm
 28  *     --enable-native-access=ALL-UNNAMED
 29  *     TestNulls
 30  */
 31 
 32 import java.lang.foreign.*;
 33 
 34 import jdk.internal.ref.CleanerFactory;
 35 import org.testng.annotations.DataProvider;
 36 import org.testng.annotations.NoInjection;
 37 import org.testng.annotations.Test;
 38 
 39 import java.lang.constant.Constable;
 40 import java.lang.foreign.Arena;
 41 import java.lang.invoke.MethodHandle;
 42 import java.lang.invoke.MethodHandles;
 43 import java.lang.invoke.MethodType;
 44 import java.lang.invoke.VarHandle;
 45 import java.lang.ref.Cleaner;
 46 import java.lang.reflect.Array;
 47 import java.lang.reflect.InvocationTargetException;
 48 import java.lang.reflect.Method;
 49 import java.lang.reflect.Modifier;
 50 import java.nio.Buffer;
 51 import java.nio.ByteBuffer;
 52 import java.nio.ByteOrder;
 53 import java.nio.channels.FileChannel;
 54 import java.nio.charset.Charset;
 55 import java.nio.file.Path;
 56 import java.util.*;
 57 import java.util.function.Consumer;
 58 import java.util.function.Supplier;
 59 import java.util.function.UnaryOperator;
 60 import java.util.stream.Collectors;
 61 import java.util.stream.Stream;
 62 
 63 import static java.lang.foreign.ValueLayout.JAVA_INT;
 64 import static java.lang.foreign.ValueLayout.JAVA_LONG;
 65 import static org.testng.Assert.*;
 66 import static org.testng.Assert.fail;
 67 
/**
 * This test makes sure that public API classes (listed in {@link TestNulls#CLASSES}) throw NPEs whenever
 * nulls are provided. The test looks at all the public methods in all the listed classes, and injects
 * values automatically. If an API takes a reference, the test will try to inject nulls. For APIs taking
 * either reference arrays or collections, the framework will also generate additional <em>replacements</em>
 * (i.e. values other than the array or collection itself being replaced with null), such as an array or
 * collection with null elements. The test can be customized by adding/removing classes to the
 * {@link #CLASSES} array, by adding/removing default mappings for standard carrier types (see
 * {@link #DEFAULT_VALUES}), or by adding/removing custom replacements (see {@link #REPLACEMENT_VALUES}).
 * Individual generated cases can be skipped by listing their names in {@link #EXCLUDE_LIST}.
 */
 78 public class TestNulls {
 79 
 80     static final Class<?>[] CLASSES = new Class<?>[] {
 81             Arena.class,
 82             MemorySegment.class,
 83             MemoryLayout.class,
 84             MemoryLayout.PathElement.class,
 85             SequenceLayout.class,
 86             ValueLayout.class,
 87             ValueLayout.OfBoolean.class,
 88             ValueLayout.OfByte.class,
 89             ValueLayout.OfChar.class,
 90             ValueLayout.OfShort.class,
 91             ValueLayout.OfInt.class,
 92             ValueLayout.OfFloat.class,
 93             ValueLayout.OfLong.class,
 94             ValueLayout.OfDouble.class,
 95             AddressLayout.class,
 96             PaddingLayout.class,
 97             GroupLayout.class,
 98             StructLayout.class,
 99             UnionLayout.class,
100             Linker.class,
101             Linker.Option.class,
102             FunctionDescriptor.class,
103             SegmentAllocator.class,
104             MemorySegment.Scope.class,
105             SymbolLookup.class
106     };
107 
108     static final Set<String> EXCLUDE_LIST = Set.of(
109             "java.lang.foreign.MemorySegment/reinterpret(java.lang.foreign.Arena,java.util.function.Consumer)/1/0",
110             "java.lang.foreign.MemorySegment/reinterpret(long,java.lang.foreign.Arena,java.util.function.Consumer)/2/0"
111     );
112 
113     static final Set<String> OBJECT_METHODS = Stream.of(Object.class.getMethods())
114             .map(Method::getName)
115             .collect(Collectors.toSet());
116 
117     static final Map<Class<?>, Object> DEFAULT_VALUES = new HashMap<>();
118 
119     static <Z> void addDefaultMapping(Class<Z> carrier, Z value) {
120         DEFAULT_VALUES.put(carrier, value);
121     }
122 
123     static {
124         addDefaultMapping(char.class, (char)0);
125         addDefaultMapping(byte.class, (byte)0);
126         addDefaultMapping(short.class, (short)0);
127         addDefaultMapping(int.class, 0);
128         addDefaultMapping(float.class, 0f);
129         addDefaultMapping(long.class, 0L);
130         addDefaultMapping(double.class, 0d);
131         addDefaultMapping(boolean.class, true);
132         addDefaultMapping(ByteOrder.class, ByteOrder.nativeOrder());
133         addDefaultMapping(Thread.class, Thread.currentThread());
134         addDefaultMapping(Cleaner.class, CleanerFactory.cleaner());
135         addDefaultMapping(Buffer.class, ByteBuffer.wrap(new byte[10]));
136         addDefaultMapping(ByteBuffer.class, ByteBuffer.wrap(new byte[10]));
137         addDefaultMapping(Path.class, Path.of("nonExistent"));
138         addDefaultMapping(FileChannel.MapMode.class, FileChannel.MapMode.PRIVATE);
139         addDefaultMapping(UnaryOperator.class, UnaryOperator.identity());
140         addDefaultMapping(String.class, "Hello!");
141         addDefaultMapping(Constable.class, "Hello!");
142         addDefaultMapping(Class.class, String.class);
143         addDefaultMapping(Runnable.class, () -> {});
144         addDefaultMapping(Object.class, new Object());
145         addDefaultMapping(VarHandle.class, JAVA_INT.varHandle());
146         addDefaultMapping(MethodHandle.class, MethodHandles.identity(int.class));
147         addDefaultMapping(List.class, List.of());
148         addDefaultMapping(Charset.class, Charset.defaultCharset());
149         addDefaultMapping(Consumer.class, x -> {});
150         addDefaultMapping(MethodType.class, MethodType.methodType(void.class));
151         addDefaultMapping(MemoryLayout.class, ValueLayout.JAVA_INT);
152         addDefaultMapping(ValueLayout.class, ValueLayout.JAVA_INT);
153         addDefaultMapping(AddressLayout.class, ValueLayout.ADDRESS);
154         addDefaultMapping(ValueLayout.OfByte.class, ValueLayout.JAVA_BYTE);
155         addDefaultMapping(ValueLayout.OfBoolean.class, ValueLayout.JAVA_BOOLEAN);
156         addDefaultMapping(ValueLayout.OfChar.class, ValueLayout.JAVA_CHAR);
157         addDefaultMapping(ValueLayout.OfShort.class, ValueLayout.JAVA_SHORT);
158         addDefaultMapping(ValueLayout.OfInt.class, ValueLayout.JAVA_INT);
159         addDefaultMapping(ValueLayout.OfFloat.class, ValueLayout.JAVA_FLOAT);
160         addDefaultMapping(ValueLayout.OfLong.class, JAVA_LONG);
161         addDefaultMapping(ValueLayout.OfDouble.class, ValueLayout.JAVA_DOUBLE);
162         addDefaultMapping(PaddingLayout.class, MemoryLayout.paddingLayout(4));
163         addDefaultMapping(GroupLayout.class, MemoryLayout.structLayout(ValueLayout.JAVA_INT));
164         addDefaultMapping(StructLayout.class, MemoryLayout.structLayout(ValueLayout.JAVA_INT));
165         addDefaultMapping(UnionLayout.class, MemoryLayout.unionLayout(ValueLayout.JAVA_INT));
166         addDefaultMapping(SequenceLayout.class, MemoryLayout.sequenceLayout(1, ValueLayout.JAVA_INT));
167         addDefaultMapping(SymbolLookup.class, SymbolLookup.loaderLookup());
168         addDefaultMapping(MemorySegment.class, MemorySegment.ofArray(new byte[10]));
169         addDefaultMapping(FunctionDescriptor.class, FunctionDescriptor.ofVoid());
170         addDefaultMapping(Linker.class, Linker.nativeLinker());
171         addDefaultMapping(Arena.class, Arena.ofConfined());
172         addDefaultMapping(MemorySegment.Scope.class, Arena.ofAuto().scope());
173         addDefaultMapping(SegmentAllocator.class, SegmentAllocator.prefixAllocator(MemorySegment.ofArray(new byte[10])));
174         addDefaultMapping(Supplier.class, () -> null);
175         addDefaultMapping(ClassLoader.class, TestNulls.class.getClassLoader());
176         addDefaultMapping(Thread.UncaughtExceptionHandler.class, (thread, ex) -> {});
177     }
178 
179     static final Map<Class<?>, Object[]> REPLACEMENT_VALUES = new HashMap<>();
180 
181     @SafeVarargs
182     static <Z> void addReplacements(Class<Z> carrier, Z... value) {
183         REPLACEMENT_VALUES.put(carrier, value);
184     }
185 
186     static {
187         addReplacements(Collection.class, null, Stream.of(new Object[] { null }).collect(Collectors.toList()));
188         addReplacements(List.class, null, Stream.of(new Object[] { null }).collect(Collectors.toList()));
189         addReplacements(Set.class, null, Stream.of(new Object[] { null }).collect(Collectors.toSet()));
190     }
191 
192     @Test(dataProvider = "cases")
193     public void testNulls(String testName, @NoInjection Method meth, Object receiver, Object[] args) {
194         try {
195             meth.invoke(receiver, args);
196             fail("Method invocation completed normally");
197         } catch (InvocationTargetException ex) {
198             Class<?> cause = ex.getCause().getClass();
199             assertEquals(cause, NullPointerException.class, "got " + cause.getName() + " - expected NullPointerException");
200         } catch (Throwable ex) {
201             fail("Unexpected exception: " + ex);
202         }
203     }
204 
205     @DataProvider(name = "cases")
206     static Iterator<Object[]> cases() {
207         List<Object[]> cases = new ArrayList<>();
208         for (Class<?> clazz : CLASSES) {
209             for (Method m : clazz.getMethods()) {
210                 if (OBJECT_METHODS.contains(m.getName())) continue;
211                 boolean isStatic = (m.getModifiers() & Modifier.STATIC) != 0;
212                 List<Integer> refIndices = new ArrayList<>();
213                 for (int i = 0; i < m.getParameterCount(); i++) {
214                     Class<?> param = m.getParameterTypes()[i];
215                     if (!param.isPrimitive()) {
216                         refIndices.add(i);
217                     }
218                 }
219                 for (int i : refIndices) {
220                     Object[] replacements = replacements(m.getParameterTypes()[i]);
221                     for (int r = 0 ; r < replacements.length ; r++) {
222                         String testName = clazz.getName() + "/" + shortSig(m) + "/" + i + "/" + r;
223                         if (EXCLUDE_LIST.contains(testName)) continue;
224                         Object[] args = new Object[m.getParameterCount()];
225                         for (int j = 0; j < args.length; j++) {
226                             args[j] = defaultValue(m.getParameterTypes()[j]);
227                         }
228                         args[i] = replacements[r];
229                         Object receiver = isStatic ? null : defaultValue(clazz);
230                         cases.add(new Object[]{testName, m, receiver, args});
231                     }
232                 }
233             }
234         }
235         return cases.iterator();
236     };
237 
238     static String shortSig(Method m) {
239         StringJoiner sj = new StringJoiner(",", m.getName() + "(", ")");
240         for (Class<?> parameterType : m.getParameterTypes()) {
241             sj.add(parameterType.getTypeName());
242         }
243         return sj.toString();
244     }
245 
246     static Object defaultValue(Class<?> carrier) {
247         if (carrier.isArray()) {
248             return Array.newInstance(carrier.componentType(), 0);
249         }
250         Object value = DEFAULT_VALUES.get(carrier);
251         if (value == null) {
252             throw new UnsupportedOperationException(carrier.getName());
253         }
254         return value;
255     }
256 
257     static Object[] replacements(Class<?> carrier) {
258         if (carrier.isArray() && !carrier.getComponentType().isPrimitive()) {
259             Object arr = Array.newInstance(carrier.componentType(), 1);
260             Array.set(arr, 0, null);
261             return new Object[] { null, arr };
262         }
263         return REPLACEMENT_VALUES.getOrDefault(carrier, new Object[] { null });
264     }
265 }