1 /*
  2  * Copyright (c) 2021, 2023, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 /*
 25  * @test
 26  * @run testng/othervm --enable-native-access=ALL-UNNAMED SafeFunctionAccessTest
 27  */
 28 
 29 import java.lang.foreign.Arena;
 30 import java.lang.foreign.Linker;
 31 import java.lang.foreign.FunctionDescriptor;
 32 import java.lang.foreign.MemorySegment;
 33 import java.lang.foreign.MemoryLayout;
 34 
 35 import java.lang.invoke.MethodHandle;
 36 import java.lang.invoke.MethodHandles;
 37 import java.lang.invoke.MethodType;
 38 
 39 import java.util.stream.Stream;
 40 
 41 import org.testng.annotations.*;
 42 
 43 import static org.testng.Assert.*;
 44 
 45 public class SafeFunctionAccessTest extends NativeTestHelper {
 46     static {
 47         System.loadLibrary("SafeAccess");
 48     }
 49 
 50     static MemoryLayout POINT = MemoryLayout.structLayout(
 51             C_INT, C_INT
 52     );
 53 
 54     @Test(expectedExceptions = IllegalStateException.class)
 55     public void testClosedStruct() throws Throwable {
 56         MemorySegment segment;
 57         try (Arena arena = Arena.ofConfined()) {
 58             segment = arena.allocate(POINT);
 59         }
 60         assertFalse(segment.scope().isAlive());
 61         MethodHandle handle = Linker.nativeLinker().downcallHandle(
 62                 findNativeOrThrow("struct_func"),
 63                 FunctionDescriptor.ofVoid(POINT));
 64 
 65         handle.invokeExact(segment);
 66     }
 67 
 68     @Test
 69     public void testClosedStructAddr_6() throws Throwable {
 70         MethodHandle handle = Linker.nativeLinker().downcallHandle(
 71                 findNativeOrThrow("addr_func_6"),
 72                 FunctionDescriptor.ofVoid(C_POINTER, C_POINTER, C_POINTER, C_POINTER, C_POINTER, C_POINTER));
 73         record Allocation(Arena drop, MemorySegment segment) {
 74             static Allocation of(MemoryLayout layout) {
 75                 Arena arena = Arena.ofShared();
 76                 return new Allocation(arena, arena.allocate(layout));
 77             }
 78         }
 79         for (int i = 0 ; i < 6 ; i++) {
 80             Allocation[] allocations = new Allocation[]{
 81                     Allocation.of(POINT),
 82                     Allocation.of(POINT),
 83                     Allocation.of(POINT),
 84                     Allocation.of(POINT),
 85                     Allocation.of(POINT),
 86                     Allocation.of(POINT)
 87             };
 88             // check liveness
 89             allocations[i].drop().close();
 90             for (int j = 0 ; j < 6 ; j++) {
 91                 if (i == j) {
 92                     assertFalse(allocations[j].drop().scope().isAlive());
 93                 } else {
 94                     assertTrue(allocations[j].drop().scope().isAlive());
 95                 }
 96             }
 97             try {
 98                 handle.invokeWithArguments(Stream.of(allocations).map(Allocation::segment).toArray());
 99                 fail();
100             } catch (IllegalStateException ex) {
101                 assertTrue(ex.getMessage().contains("Already closed"));
102             }
103             for (int j = 0 ; j < 6 ; j++) {
104                 if (i != j) {
105                     allocations[j].drop().close(); // should succeed!
106                 }
107             }
108         }
109     }
110 
111     @Test(expectedExceptions = IllegalStateException.class)
112     public void testClosedUpcall() throws Throwable {
113         MemorySegment upcall;
114         try (Arena arena = Arena.ofConfined()) {
115             MethodHandle dummy = MethodHandles.lookup().findStatic(SafeFunctionAccessTest.class, "dummy", MethodType.methodType(void.class));
116             upcall = Linker.nativeLinker().upcallStub(dummy, FunctionDescriptor.ofVoid(), arena);
117         }
118         assertFalse(upcall.scope().isAlive());
119         MethodHandle handle = Linker.nativeLinker().downcallHandle(
120                 findNativeOrThrow("addr_func"),
121                 FunctionDescriptor.ofVoid(C_POINTER));
122 
123         handle.invokeExact(upcall);
124     }
125 
126     static void dummy() { }
127 
128     @Test
129     public void testClosedStructCallback() throws Throwable {
130         MethodHandle handle = Linker.nativeLinker().downcallHandle(
131                 findNativeOrThrow("addr_func_cb"),
132                 FunctionDescriptor.ofVoid(C_POINTER, C_POINTER));
133 
134         try (Arena arena = Arena.ofConfined()) {
135             MemorySegment segment = arena.allocate(POINT);
136             handle.invokeExact(segment, sessionChecker(arena));
137         }
138     }
139 
140     @Test
141     public void testClosedUpcallCallback() throws Throwable {
142         MethodHandle handle = Linker.nativeLinker().downcallHandle(
143                 findNativeOrThrow("addr_func_cb"),
144                 FunctionDescriptor.ofVoid(C_POINTER, C_POINTER));
145 
146         try (Arena arena = Arena.ofConfined()) {
147             MethodHandle dummy = MethodHandles.lookup().findStatic(SafeFunctionAccessTest.class, "dummy", MethodType.methodType(void.class));
148             MemorySegment upcall = Linker.nativeLinker().upcallStub(dummy, FunctionDescriptor.ofVoid(), arena);
149             handle.invokeExact(upcall, sessionChecker(arena));
150         }
151     }
152 
153     MemorySegment sessionChecker(Arena arena) {
154         try {
155             MethodHandle handle = MethodHandles.lookup().findStatic(SafeFunctionAccessTest.class, "checkSession",
156                     MethodType.methodType(void.class, Arena.class));
157             handle = handle.bindTo(arena);
158             return Linker.nativeLinker().upcallStub(handle, FunctionDescriptor.ofVoid(), Arena.ofAuto());
159         } catch (Throwable ex) {
160             throw new AssertionError(ex);
161         }
162     }
163 
164     static void checkSession(Arena arena) {
165         try {
166             arena.close();
167             fail("Session closed unexpectedly!");
168         } catch (IllegalStateException ex) {
169             assertTrue(ex.getMessage().contains("acquired")); //if acquired, fine
170         }
171     }
172 }