1 /*
  2  * Copyright (c) 2022, 2023, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 /*
 25  * @test
 26  * @enablePreview
 27  * @library ../ /test/lib
 28  * @requires jdk.foreign.linker != "UNSUPPORTED"
 29  * @run testng/othervm --enable-native-access=ALL-UNNAMED TestCaptureCallState
 30  */
 31 
 32 import org.testng.annotations.DataProvider;
 33 import org.testng.annotations.Test;
 34 
 35 import java.lang.foreign.Arena;
 36 import java.lang.foreign.FunctionDescriptor;
 37 import java.lang.foreign.Linker;
 38 import java.lang.foreign.MemoryLayout;
 39 import java.lang.foreign.MemorySegment;
 40 import java.lang.foreign.StructLayout;
 41 import java.lang.invoke.MethodHandle;
 42 import java.lang.invoke.VarHandle;
 43 import java.util.ArrayList;
 44 import java.util.List;
 45 import java.util.Map;
 46 import java.util.function.Consumer;
 47 
 48 import static java.lang.foreign.MemoryLayout.PathElement.groupElement;
 49 import static java.lang.foreign.ValueLayout.JAVA_DOUBLE;
 50 import static java.lang.foreign.ValueLayout.JAVA_INT;
 51 import static java.lang.foreign.ValueLayout.JAVA_LONG;
 52 import static org.testng.Assert.assertEquals;
 53 import static org.testng.Assert.assertTrue;
 54 
 55 public class TestCaptureCallState extends NativeTestHelper {
 56 
 57     static {
 58         System.loadLibrary("CaptureCallState");
 59         if (IS_WINDOWS) {
 60             String system32 = System.getenv("SystemRoot") + "\\system32";
 61             System.load(system32 + "\\Kernel32.dll");
 62             System.load(system32 + "\\Ws2_32.dll");
 63         }
 64     }
 65 
 66     private record SaveValuesCase(String nativeTarget, FunctionDescriptor nativeDesc, boolean trivial, String threadLocalName, Consumer<Object> resultCheck) {}
 67 
 68     @Test(dataProvider = "cases")
 69     public void testSavedThreadLocal(SaveValuesCase testCase) throws Throwable {
 70         List<Linker.Option> options = new ArrayList<>();
 71         options.add(Linker.Option.captureCallState(testCase.threadLocalName()));
 72         if (testCase.trivial()) {
 73             options.add(Linker.Option.isTrivial());
 74         }
 75         MethodHandle handle = downcallHandle(testCase.nativeTarget(), testCase.nativeDesc(), options.toArray(Linker.Option[]::new));
 76 
 77         StructLayout capturedStateLayout = Linker.Option.captureStateLayout();
 78         VarHandle errnoHandle = capturedStateLayout.varHandle(groupElement(testCase.threadLocalName()));
 79 
 80         try (Arena arena = Arena.ofConfined()) {
 81             MemorySegment saveSeg = arena.allocate(capturedStateLayout);
 82             int testValue = 42;
 83             boolean needsAllocator = testCase.nativeDesc().returnLayout().map(StructLayout.class::isInstance).orElse(false);
 84             Object result = needsAllocator
 85                 ? handle.invoke(arena, saveSeg, testValue)
 86                 : handle.invoke(saveSeg, testValue);
 87             testCase.resultCheck().accept(result);
 88             int savedErrno = (int) errnoHandle.get(saveSeg);
 89             assertEquals(savedErrno, testValue);
 90         }
 91     }
 92 
 93     @Test(dataProvider = "invalidCaptureSegmentCases")
 94     public void testInvalidCaptureSegment(MemorySegment captureSegment,
 95                                           Class<?> expectedExceptionType, String expectedExceptionMessage) {
 96         Linker.Option stl = Linker.Option.captureCallState("errno");
 97         MethodHandle handle = downcallHandle("set_errno_V", FunctionDescriptor.ofVoid(C_INT), stl);
 98 
 99         try {
100             int testValue = 42;
101             handle.invoke(captureSegment, testValue); // should throw
102         } catch (Throwable t) {
103             assertTrue(expectedExceptionType.isInstance(t));
104             assertTrue(t.getMessage().matches(expectedExceptionMessage));
105         }
106     }
107 
108     interface CaseAdder {
109       void addCase(String nativeTarget, FunctionDescriptor nativeDesc, String threadLocalName, Consumer<Object> resultCheck);
110     }
111 
112     @DataProvider
113     public static Object[][] cases() {
114         List<SaveValuesCase> cases = new ArrayList<>();
115         CaseAdder adder = (nativeTarget, nativeDesc, threadLocalName, resultCheck) -> {
116           cases.add(new SaveValuesCase(nativeTarget, nativeDesc, false, threadLocalName, resultCheck));
117           cases.add(new SaveValuesCase(nativeTarget, nativeDesc, true, threadLocalName, resultCheck));
118         };
119 
120         adder.addCase("set_errno_V", FunctionDescriptor.ofVoid(JAVA_INT), "errno", o -> {});
121         adder.addCase("set_errno_I", FunctionDescriptor.of(JAVA_INT, JAVA_INT), "errno", o -> assertEquals((int) o, 42));
122         adder.addCase("set_errno_D", FunctionDescriptor.of(JAVA_DOUBLE, JAVA_INT), "errno", o -> assertEquals((double) o, 42.0));
123 
124         structCase(adder, "SL",  Map.of(JAVA_LONG.withName("x"), 42L));
125         structCase(adder, "SLL", Map.of(JAVA_LONG.withName("x"), 42L,
126                                          JAVA_LONG.withName("y"), 42L));
127         structCase(adder, "SLLL", Map.of(JAVA_LONG.withName("x"), 42L,
128                                          JAVA_LONG.withName("y"), 42L,
129                                          JAVA_LONG.withName("z"), 42L));
130         structCase(adder, "SD",  Map.of(JAVA_DOUBLE.withName("x"), 42D));
131         structCase(adder, "SDD", Map.of(JAVA_DOUBLE.withName("x"), 42D,
132                                          JAVA_DOUBLE.withName("y"), 42D));
133         structCase(adder, "SDDD", Map.of(JAVA_DOUBLE.withName("x"), 42D,
134                                          JAVA_DOUBLE.withName("y"), 42D,
135                                          JAVA_DOUBLE.withName("z"), 42D));
136 
137         if (IS_WINDOWS) {
138             adder.addCase("SetLastError", FunctionDescriptor.ofVoid(JAVA_INT), "GetLastError", o -> {});
139             adder.addCase("WSASetLastError", FunctionDescriptor.ofVoid(JAVA_INT), "WSAGetLastError", o -> {});
140         }
141 
142         return cases.stream().map(tc -> new Object[] {tc}).toArray(Object[][]::new);
143     }
144 
145     static void structCase(CaseAdder adder, String name, Map<MemoryLayout, Object> fields) {
146         StructLayout layout = MemoryLayout.structLayout(fields.keySet().toArray(MemoryLayout[]::new));
147 
148         Consumer<Object> check = o -> {};
149         for (var field : fields.entrySet()) {
150             MemoryLayout fieldLayout = field.getKey();
151             VarHandle fieldHandle = layout.varHandle(MemoryLayout.PathElement.groupElement(fieldLayout.name().get()));
152             Object value = field.getValue();
153             check = check.andThen(o -> assertEquals(fieldHandle.get(o), value));
154         }
155 
156         adder.addCase("set_errno_" + name, FunctionDescriptor.of(layout, JAVA_INT), "errno", check);
157     }
158 
159     @DataProvider
160     public static Object[][] invalidCaptureSegmentCases() {
161         return new Object[][]{
162             {Arena.ofAuto().allocate(1), IndexOutOfBoundsException.class, ".*Out of bound access on segment.*"},
163             {MemorySegment.NULL, IllegalArgumentException.class, ".*Capture segment is NULL.*"},
164             {Arena.ofAuto().allocate(Linker.Option.captureStateLayout().byteSize() + 3).asSlice(3), // misaligned
165                     IllegalArgumentException.class, ".*Target offset incompatible with alignment constraints.*"},
166         };
167     }
168 }