1 /*
  2  * Copyright (c) 2020, 2022, Oracle and/or its affiliates. All rights reserved.
  3  *  DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  *  This code is free software; you can redistribute it and/or modify it
  6  *  under the terms of the GNU General Public License version 2 only, as
  7  *  published by the Free Software Foundation.
  8  *
  9  *  This code is distributed in the hope that it will be useful, but WITHOUT
 10  *  ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  *  FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  *  version 2 for more details (a copy is included in the LICENSE file that
 13  *  accompanied this code).
 14  *
 15  *  You should have received a copy of the GNU General Public License version
 16  *  2 along with this work; if not, write to the Free Software Foundation,
 17  *  Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  *   Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  *  or visit www.oracle.com if you need additional information or have any
 21  *  questions.
 22  *
 23  */
 24 
 25 import jdk.incubator.foreign.CLinker;
 26 import jdk.incubator.foreign.FunctionDescriptor;
 27 import jdk.incubator.foreign.NativeSymbol;
 28 import jdk.incubator.foreign.SymbolLookup;
 29 import jdk.incubator.foreign.MemoryLayout;
 30 import jdk.incubator.foreign.MemorySegment;
 31 
 32 import jdk.incubator.foreign.ResourceScope;
 33 import org.testng.annotations.BeforeClass;
 34 
 35 import java.lang.invoke.MethodHandle;
 36 import java.lang.invoke.MethodHandles;
 37 import java.lang.invoke.MethodType;
 38 import java.util.List;
 39 import java.util.concurrent.atomic.AtomicReference;
 40 import java.util.function.Consumer;
 41 import java.util.stream.Collectors;
 42 import java.util.stream.Stream;
 43 
 44 import static java.lang.invoke.MethodHandles.insertArguments;
 45 import static org.testng.Assert.assertEquals;
 46 
 47 public abstract class TestUpcallBase extends CallGeneratorHelper {
 48 
 49     static CLinker ABI = CLinker.systemCLinker();
 50     static final SymbolLookup LOOKUP = SymbolLookup.loaderLookup();
 51 
 52     private static MethodHandle DUMMY;
 53     private static MethodHandle PASS_AND_SAVE;
 54 
 55     static {
 56         try {
 57             DUMMY = MethodHandles.lookup().findStatic(TestUpcallBase.class, "dummy", MethodType.methodType(void.class));
 58             PASS_AND_SAVE = MethodHandles.lookup().findStatic(TestUpcallBase.class, "passAndSave",
 59                     MethodType.methodType(Object.class, Object[].class, AtomicReference.class, int.class));
 60         } catch (Throwable ex) {
 61             throw new IllegalStateException(ex);
 62         }
 63     }
 64 
 65     private static NativeSymbol DUMMY_STUB;
 66 
 67     @BeforeClass
 68     void setup() {
 69         DUMMY_STUB = ABI.upcallStub(DUMMY, FunctionDescriptor.ofVoid(), ResourceScope.newImplicitScope());
 70     }
 71 
 72     static FunctionDescriptor function(Ret ret, List<ParamType> params, List<StructFieldType> fields) {
 73         return function(ret, params, fields, List.of());
 74     }
 75 
 76     static FunctionDescriptor function(Ret ret, List<ParamType> params, List<StructFieldType> fields, List<MemoryLayout> prefix) {
 77         List<MemoryLayout> paramLayouts = params.stream().map(p -> p.layout(fields)).collect(Collectors.toList());
 78         paramLayouts.add(C_POINTER); // the callback
 79         MemoryLayout[] layouts = Stream.concat(prefix.stream(), paramLayouts.stream()).toArray(MemoryLayout[]::new);
 80         return ret == Ret.VOID ?
 81                 FunctionDescriptor.ofVoid(layouts) :
 82                 FunctionDescriptor.of(layouts[prefix.size()], layouts);
 83     }
 84 
 85     static Object[] makeArgs(ResourceScope scope, Ret ret, List<ParamType> params, List<StructFieldType> fields, List<Consumer<Object>> checks, List<Consumer<Object[]>> argChecks) throws ReflectiveOperationException {
 86         return makeArgs(scope, ret, params, fields, checks, argChecks, List.of());
 87     }
 88 
 89     static Object[] makeArgs(ResourceScope scope, Ret ret, List<ParamType> params, List<StructFieldType> fields, List<Consumer<Object>> checks, List<Consumer<Object[]>> argChecks, List<MemoryLayout> prefix) throws ReflectiveOperationException {
 90         Object[] args = new Object[prefix.size() + params.size() + 1];
 91         int argNum = 0;
 92         for (MemoryLayout layout : prefix) {
 93             args[argNum++] = makeArg(layout, null, false);
 94         }
 95         for (int i = 0 ; i < params.size() ; i++) {
 96             args[argNum++] = makeArg(params.get(i).layout(fields), checks, i == 0);
 97         }
 98         args[argNum] = makeCallback(scope, ret, params, fields, checks, argChecks, prefix);
 99         return args;
100     }
101 
102     static NativeSymbol makeCallback(ResourceScope scope, Ret ret, List<ParamType> params, List<StructFieldType> fields, List<Consumer<Object>> checks, List<Consumer<Object[]>> argChecks, List<MemoryLayout> prefix) {
103         if (params.isEmpty()) {
104             return DUMMY_STUB;
105         }
106 
107         AtomicReference<Object[]> box = new AtomicReference<>();
108         MethodHandle mh = insertArguments(PASS_AND_SAVE, 1, box, prefix.size());
109         mh = mh.asCollector(Object[].class, prefix.size() + params.size());
110 
111         for(int i = 0; i < prefix.size(); i++) {
112             mh = mh.asType(mh.type().changeParameterType(i, carrier(prefix.get(i), false)));
113         }
114 
115         for (int i = 0; i < params.size(); i++) {
116             ParamType pt = params.get(i);
117             MemoryLayout layout = pt.layout(fields);
118             Class<?> carrier = carrier(layout, false);
119             mh = mh.asType(mh.type().changeParameterType(prefix.size() + i, carrier));
120 
121             final int finalI = prefix.size() + i;
122             if (carrier == MemorySegment.class) {
123                 argChecks.add(o -> assertStructEquals((MemorySegment) box.get()[finalI], (MemorySegment) o[finalI], layout));
124             } else {
125                 argChecks.add(o -> assertEquals(box.get()[finalI], o[finalI]));
126             }
127         }
128 
129         ParamType firstParam = params.get(0);
130         MemoryLayout firstlayout = firstParam.layout(fields);
131         Class<?> firstCarrier = carrier(firstlayout, true);
132 
133         if (firstCarrier == MemorySegment.class) {
134             checks.add(o -> assertStructEquals((MemorySegment) box.get()[prefix.size()], (MemorySegment) o, firstlayout));
135         } else {
136             checks.add(o -> assertEquals(o, box.get()[prefix.size()]));
137         }
138 
139         mh = mh.asType(mh.type().changeReturnType(ret == Ret.VOID ? void.class : firstCarrier));
140 
141         MemoryLayout[] paramLayouts = Stream.concat(prefix.stream(), params.stream().map(p -> p.layout(fields))).toArray(MemoryLayout[]::new);
142         FunctionDescriptor func = ret != Ret.VOID
143                 ? FunctionDescriptor.of(firstlayout, paramLayouts)
144                 : FunctionDescriptor.ofVoid(paramLayouts);
145         return ABI.upcallStub(mh, func, scope);
146     }
147 
148     static Object passAndSave(Object[] o, AtomicReference<Object[]> ref, int retArg) {
149         for (int i = 0; i < o.length; i++) {
150             if (o[i] instanceof MemorySegment) {
151                 MemorySegment ms = (MemorySegment) o[i];
152                 MemorySegment copy = MemorySegment.allocateNative(ms.byteSize(), ResourceScope.newImplicitScope());
153                 copy.copyFrom(ms);
154                 o[i] = copy;
155             }
156         }
157         ref.set(o);
158         return o[retArg];
159     }
160 
161     static void dummy() {
162         //do nothing
163     }
164 }