1 /*
  2  * Copyright (c) 2021, 2023, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 /*
 25  * @test
 26  * @enablePreview
 27  * @requires jdk.foreign.linker != "UNSUPPORTED"
 28  * @run testng/othervm --enable-native-access=ALL-UNNAMED SafeFunctionAccessTest
 29  */
 30 
 31 import java.lang.foreign.Arena;
 32 import java.lang.foreign.Linker;
 33 import java.lang.foreign.FunctionDescriptor;
 34 import java.lang.foreign.MemorySegment;
 35 import java.lang.foreign.MemoryLayout;
 36 
 37 import java.lang.invoke.MethodHandle;
 38 import java.lang.invoke.MethodHandles;
 39 import java.lang.invoke.MethodType;
 40 
 41 import java.util.stream.Stream;
 42 
 43 import org.testng.annotations.*;
 44 
 45 import static org.testng.Assert.*;
 46 
 47 public class SafeFunctionAccessTest extends NativeTestHelper {
 48     static {
 49         System.loadLibrary("SafeAccess");
 50     }
 51 
 52     static MemoryLayout POINT = MemoryLayout.structLayout(
 53             C_INT, C_INT
 54     );
 55 
 56     @Test(expectedExceptions = IllegalStateException.class)
 57     public void testClosedStruct() throws Throwable {
 58         MemorySegment segment;
 59         try (Arena arena = Arena.ofConfined()) {
 60             segment = arena.allocate(POINT);
 61         }
 62         assertFalse(segment.scope().isAlive());
 63         MethodHandle handle = Linker.nativeLinker().downcallHandle(
 64                 findNativeOrThrow("struct_func"),
 65                 FunctionDescriptor.ofVoid(POINT));
 66 
 67         handle.invokeExact(segment);
 68     }
 69 
 70     @Test
 71     public void testClosedStructAddr_6() throws Throwable {
 72         MethodHandle handle = Linker.nativeLinker().downcallHandle(
 73                 findNativeOrThrow("addr_func_6"),
 74                 FunctionDescriptor.ofVoid(C_POINTER, C_POINTER, C_POINTER, C_POINTER, C_POINTER, C_POINTER));
 75         record Allocation(Arena drop, MemorySegment segment) {
 76             static Allocation of(MemoryLayout layout) {
 77                 Arena arena = Arena.ofShared();
 78                 return new Allocation(arena, arena.allocate(layout));
 79             }
 80         }
 81         for (int i = 0 ; i < 6 ; i++) {
 82             Allocation[] allocations = new Allocation[]{
 83                     Allocation.of(POINT),
 84                     Allocation.of(POINT),
 85                     Allocation.of(POINT),
 86                     Allocation.of(POINT),
 87                     Allocation.of(POINT),
 88                     Allocation.of(POINT)
 89             };
 90             // check liveness
 91             allocations[i].drop().close();
 92             for (int j = 0 ; j < 6 ; j++) {
 93                 if (i == j) {
 94                     assertFalse(allocations[j].drop().scope().isAlive());
 95                 } else {
 96                     assertTrue(allocations[j].drop().scope().isAlive());
 97                 }
 98             }
 99             try {
100                 handle.invokeWithArguments(Stream.of(allocations).map(Allocation::segment).toArray());
101                 fail();
102             } catch (IllegalStateException ex) {
103                 assertTrue(ex.getMessage().contains("Already closed"));
104             }
105             for (int j = 0 ; j < 6 ; j++) {
106                 if (i != j) {
107                     allocations[j].drop().close(); // should succeed!
108                 }
109             }
110         }
111     }
112 
113     @Test(expectedExceptions = IllegalStateException.class)
114     public void testClosedUpcall() throws Throwable {
115         MemorySegment upcall;
116         try (Arena arena = Arena.ofConfined()) {
117             MethodHandle dummy = MethodHandles.lookup().findStatic(SafeFunctionAccessTest.class, "dummy", MethodType.methodType(void.class));
118             upcall = Linker.nativeLinker().upcallStub(dummy, FunctionDescriptor.ofVoid(), arena);
119         }
120         assertFalse(upcall.scope().isAlive());
121         MethodHandle handle = Linker.nativeLinker().downcallHandle(
122                 findNativeOrThrow("addr_func"),
123                 FunctionDescriptor.ofVoid(C_POINTER));
124 
125         handle.invokeExact(upcall);
126     }
127 
128     static void dummy() { }
129 
130     @Test
131     public void testClosedStructCallback() throws Throwable {
132         MethodHandle handle = Linker.nativeLinker().downcallHandle(
133                 findNativeOrThrow("addr_func_cb"),
134                 FunctionDescriptor.ofVoid(C_POINTER, C_POINTER));
135 
136         try (Arena arena = Arena.ofConfined()) {
137             MemorySegment segment = arena.allocate(POINT);
138             handle.invokeExact(segment, sessionChecker(arena));
139         }
140     }
141 
142     @Test
143     public void testClosedUpcallCallback() throws Throwable {
144         MethodHandle handle = Linker.nativeLinker().downcallHandle(
145                 findNativeOrThrow("addr_func_cb"),
146                 FunctionDescriptor.ofVoid(C_POINTER, C_POINTER));
147 
148         try (Arena arena = Arena.ofConfined()) {
149             MethodHandle dummy = MethodHandles.lookup().findStatic(SafeFunctionAccessTest.class, "dummy", MethodType.methodType(void.class));
150             MemorySegment upcall = Linker.nativeLinker().upcallStub(dummy, FunctionDescriptor.ofVoid(), arena);
151             handle.invokeExact(upcall, sessionChecker(arena));
152         }
153     }
154 
155     MemorySegment sessionChecker(Arena arena) {
156         try {
157             MethodHandle handle = MethodHandles.lookup().findStatic(SafeFunctionAccessTest.class, "checkSession",
158                     MethodType.methodType(void.class, Arena.class));
159             handle = handle.bindTo(arena);
160             return Linker.nativeLinker().upcallStub(handle, FunctionDescriptor.ofVoid(), Arena.ofAuto());
161         } catch (Throwable ex) {
162             throw new AssertionError(ex);
163         }
164     }
165 
166     static void checkSession(Arena arena) {
167         try {
168             arena.close();
169             fail("Session closed unexpectedly!");
170         } catch (IllegalStateException ex) {
171             assertTrue(ex.getMessage().contains("acquired")); //if acquired, fine
172         }
173     }
174 }