1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package experiments;
 26 
 27 
 28 import hat.ifacemapper.Schema;
 29 import hat.buffer.Buffer;
 30 import hat.ifacemapper.MappableIface;
 31 
 32 import java.lang.constant.ClassDesc;
 33 import java.lang.foreign.*;
 34 import java.lang.invoke.MethodHandles;
 35 import java.lang.reflect.Field;
 36 import java.lang.reflect.Method;
 37 import jdk.incubator.code.*;
 38 import jdk.incubator.code.analysis.SSA;
 39 import jdk.incubator.code.op.CoreOp;
 40 import jdk.incubator.code.op.ExternalizableOp;
 41 import jdk.incubator.code.op.OpFactory;
 42 import jdk.incubator.code.type.FunctionType;
 43 import jdk.incubator.code.type.JavaType;
 44 import jdk.incubator.code.type.PrimitiveType;
 45 import jdk.incubator.code.CodeReflection;
 46 import java.util.*;
 47 import java.util.stream.Stream;
 48 
 49 public class LayoutExample {
 50 
 51     /*
 52        struct {
 53           StructTwo struct;
 54           int i;
 55        }
 56      */
 57 
 58         public interface Outer extends Buffer {
 59             interface Inner extends Struct {
 60                 int i();
 61 
 62                 void i(int v);
 63 
 64                 float f();
 65 
 66                 void f(float v);
 67 
 68               //  Schema schema = Schema.of(Inner.class, b->b.primitive("i").primitive("f"));
 69             }
 70 
 71             Inner right();
 72             Inner left();
 73             int i();
 74             void i(int v);
 75 
 76 
 77             Schema schema = Schema.of(Outer.class, b->b
 78                             .struct("left", left->left
 79                                     .field("i")
 80                                     .field("f")
 81                             )
 82                            // .struct("right", Inner.schema)
 83                             .field("i")
 84             );
 85         }
 86 
 87 
 88     @CodeReflection
 89     static float m(Outer s1) {
 90         // StructOne* s1
 91         // s1 -> i
 92         int i = s1.i();
 93         // s1 -> *s2
 94         Outer.Inner s2 = s1.left();
 95         // s2 -> i
 96         i += s2.i();
 97         // s2 -> f
 98         float f = s2.f();
 99         return i + f;
100     }
101 
102 
103     public static void main(String[] args) {
104         Optional<Method> om = Stream.of(LayoutExample.class.getDeclaredMethods())
105                 .filter(m -> m.getName().equals("m"))
106                 .findFirst();
107 
108         Method m = om.orElseThrow();
109         CoreOp.FuncOp f= Op.ofMethod(m).orElseThrow();
110         f = SSA.transform(f);
111         System.out.println(f.toText());
112         FunctionType functionType = transformStructClassToPtr(MethodHandles.lookup(), f);
113         System.out.println(f.toText());
114         CoreOp.FuncOp pm = transformInvokesToPtrs(MethodHandles.lookup(), f, functionType);
115         System.out.println(pm.toText());
116     }
117     static FunctionType transformStructClassToPtr(MethodHandles.Lookup l,
118                                                 CoreOp.FuncOp f) {
119         List<TypeElement> pTypes = new ArrayList<>();
120         for (Block.Parameter p : f.parameters()) {
121             pTypes.add(transformStructClassToPtr(l, p.type()));
122         }
123         return FunctionType.functionType(
124                 transformStructClassToPtr(l, f.invokableType().returnType()), pTypes);
125     }
126 
127     static CoreOp.FuncOp transformInvokesToPtrs(MethodHandles.Lookup l,
128                                                 CoreOp.FuncOp f, FunctionType functionType) {
129 
130         var builder= CoreOp.func(f.funcName(), functionType);
131 
132         var funcOp = builder.body(funcBlock -> {
133             funcBlock.transformBody(f.body(), funcBlock.parameters(), (b, op) -> {
134                 if (op instanceof CoreOp.InvokeOp invokeOp
135                         && invokeOp.hasReceiver()
136                         && invokeOp.operands().getFirst() instanceof Value receiver) {
137                     if (bufferOrBufferChildClass(l, receiver.type()) != null) {
138                         Value ptr = b.context().getValue(receiver);
139                         PtrToMember ptrToMemberOp = new PtrToMember(ptr, invokeOp.invokeDescriptor().name());
140                         Op.Result memberPtr = b.op(ptrToMemberOp);
141 
142                         if (invokeOp.operands().size() == 1) {
143                             // Pointer access and (possibly) value load
144                             if (ptrToMemberOp.resultType().layout() instanceof ValueLayout) {
145                                 Op.Result v = b.op(new PtrLoadValue(memberPtr));
146                                 b.context().mapValue(invokeOp.result(), v);
147                             } else {
148                                 b.context().mapValue(invokeOp.result(), memberPtr);
149                             }
150                         } else {
151                             // @@@
152                             // Value store
153                             throw new UnsupportedOperationException();
154                         }
155                     } else {
156                         b.op(op);
157                     }
158                 } else {
159                     b.op(op);
160                 }
161                 return b;
162             });
163         });
164         return funcOp;
165     }
166 
167 
168 
169     static boolean isBufferOrBufferChild(Class<?> maybeIface) {
170         return  maybeIface.isInterface() && (
171                 MappableIface.class.isAssignableFrom(maybeIface)
172         );
173 
174     }
175     static Schema bufferOrBufferChildSchema(MethodHandles.Lookup l, Class<?> maybeBufferOrBufferChild) {
176         if (isBufferOrBufferChild(maybeBufferOrBufferChild)) {
177             throw new IllegalArgumentException();
178         }
179         Field schemaField;
180         try {
181             schemaField = maybeBufferOrBufferChild.getField("schema");
182            return  (Schema)schemaField.get(null);
183         } catch (NoSuchFieldException | IllegalAccessException e) {
184             throw new RuntimeException(e);
185         }
186     }
187     static Class<?> bufferOrBufferChildClass(MethodHandles.Lookup l, TypeElement t) {
188         try {
189             if (!(t instanceof JavaType jt) || !(jt.resolve(l) instanceof Class<?> c)) {
190                 return null;
191             }
192             return isBufferOrBufferChild(c) ? c : null;
193         } catch (ReflectiveOperationException e) {
194             throw new RuntimeException(e);
195         }
196     }
197     static TypeElement transformStructClassToPtr(MethodHandles.Lookup l, TypeElement type) {
198         if (bufferOrBufferChildClass(l, type) instanceof Class<?> sc) {
199             return new PtrType(bufferOrBufferChildSchema(l, sc));
200         } else {
201             return type;
202         }
203     }
204 
205     public static final class PtrType implements TypeElement {
206         static final String NAME = "ptr";
207         MemoryLayout layout;
208         Schema schema;
209         final JavaType returnType;
210 
211         public PtrType(MemoryLayout layout) {
212             this.layout = layout;
213             this.returnType = switch (layout) {
214                 case StructLayout _ -> JavaType.type(ClassDesc.of(layout.name().orElseThrow()));
215                 case AddressLayout _ -> throw new UnsupportedOperationException("Unsupported member layout: " + layout);
216                 case ValueLayout valueLayout -> JavaType.type(valueLayout.carrier());
217                 default -> throw new UnsupportedOperationException("Unsupported member layout: " + layout);
218             };
219         }
220         public PtrType(Schema schema) {
221             this.schema = schema;
222             this.layout= null;//schema.layout();
223             this.returnType = switch (layout) {
224                 case StructLayout _ -> JavaType.type(ClassDesc.of(layout.name().orElseThrow()));
225                 case AddressLayout _ -> throw new UnsupportedOperationException("Unsupported member layout: " + layout);
226                 case ValueLayout valueLayout -> JavaType.type(valueLayout.carrier());
227                 default -> throw new UnsupportedOperationException("Unsupported member layout: " + layout);
228             };
229         }
230 
231         public JavaType returnType() {
232             return returnType;
233         }
234 
235         public MemoryLayout layout() {
236             return layout;
237         }
238         public Schema schema() {
239             return schema;
240         }
241 
242         @Override
243         public boolean equals(Object o) {
244             if (this == o) return true;
245             if (o == null || getClass() != o.getClass()) return false;
246             PtrType ptrType = (PtrType) o;
247             return Objects.equals(layout, ptrType.layout);
248         }
249 
250         @Override
251         public int hashCode() {
252             return Objects.hash(layout);
253         }
254 
255         @Override
256         public ExternalizedTypeElement externalize() {
257             return new ExternalizedTypeElement(NAME, List.of(returnType.externalize()));
258         }
259 
260         @Override
261         public String toString() {
262             return externalize().toString();
263         }
264     }
265 
266     @OpFactory.OpDeclaration(PtrToMember.NAME)
267     public static final class PtrToMember extends ExternalizableOp {
268         public static final String NAME = "ptr.to.member";
269         public static final String ATTRIBUTE_OFFSET = "offset";
270         public static final String ATTRIBUTE_NAME = "name";
271 
272         final String simpleMemberName;
273         final long memberOffset;
274         final PtrType resultType;
275 
276         PtrToMember(PtrToMember that, CopyContext cc) {
277             super(that, cc);
278             this.simpleMemberName = that.simpleMemberName;
279             this.memberOffset = that.memberOffset;
280             this.resultType = that.resultType;
281         }
282 
283         @Override
284         public PtrToMember transform(CopyContext cc, OpTransformer ot) {
285             return new PtrToMember(this, cc);
286         }
287 
288         public PtrToMember(Value ptr, String simpleMemberName) {
289             super(NAME, List.of(ptr));
290             this.simpleMemberName = simpleMemberName;
291 
292             if (!(ptr.type() instanceof PtrType ptrType)) {
293                 throw new IllegalArgumentException("Pointer value is not of pointer type: " + ptr.type());
294             }
295             // @@@ Support group layout
296             if (!(ptrType.layout() instanceof StructLayout structLayout)) {
297                 throw new IllegalArgumentException("Pointer type layout is not a struct layout: " + ptrType.layout());
298             }
299 
300             // Find the actual member name from the simple member name
301             String memberName = findMemberName(structLayout, simpleMemberName);
302             MemoryLayout.PathElement p = MemoryLayout.PathElement.groupElement(memberName);
303             this.memberOffset = structLayout.byteOffset(p);
304             MemoryLayout memberLayout = structLayout.select(p);
305             // Remove any simple member name from the layout
306             MemoryLayout ptrLayout = memberLayout instanceof StructLayout
307                     ? memberLayout.withName(className(memberName))
308                     : memberLayout.withoutName();
309             this.resultType = new PtrType(ptrLayout);
310         }
311 
312         // @@@ Change to return member index
313         static String findMemberName(StructLayout sl, String simpleMemberName) {
314             for (MemoryLayout layout : sl.memberLayouts()) {
315                 String memberName = layout.name().orElseThrow();
316                 if (simpleMemberName(memberName).equals(simpleMemberName)) {
317                     return memberName;
318                 }
319             }
320             throw new NoSuchElementException("No member found: " + simpleMemberName + " " + sl);
321         }
322 
323         static String simpleMemberName(String memberName) {
324             int i = memberName.indexOf("::");
325             return i != -1
326                     ? memberName.substring(i + 2)
327                     : memberName;
328         }
329 
330         static String className(String memberName) {
331             int i = memberName.indexOf("::");
332             return i != -1
333                     ? memberName.substring(0, i)
334                     : null;
335         }
336 
337         @Override
338         public PtrType resultType() {
339             return resultType;
340         }
341 
342         @Override
343         public Map<String, Object> attributes() {
344             HashMap<String, Object> attrs = new HashMap<>(super.attributes());
345             attrs.put("", simpleMemberName);
346             attrs.put(ATTRIBUTE_OFFSET, memberOffset);
347             return attrs;
348         }
349 
350         public String simpleMemberName() {
351             return simpleMemberName;
352         }
353 
354         public long memberOffset() {
355             return memberOffset;
356         }
357 
358         public Value ptrValue() {
359             return operands().get(0);
360         }
361     }
362 
363 
364     @OpFactory.OpDeclaration(PtrToMember.NAME)
365     public static final class PtrAddOffset extends Op {
366         public static final String NAME = "ptr.add.offset";
367 
368         PtrAddOffset(PtrAddOffset that, CopyContext cc) {
369             super(that, cc);
370         }
371 
372         @Override
373         public PtrAddOffset transform(CopyContext cc, OpTransformer ot) {
374             return new PtrAddOffset(this, cc);
375         }
376 
377         public PtrAddOffset(Value ptr, Value offset) {
378             super(NAME, List.of(ptr, offset));
379 
380             if (!(ptr.type() instanceof PtrType)) {
381                 throw new IllegalArgumentException("Pointer value is not of pointer type: " + ptr.type());
382             }
383             if (!(offset.type() instanceof PrimitiveType pt && pt.equals(JavaType.LONG))) {
384                 throw new IllegalArgumentException("Offset value is not of primitve long type: " + offset.type());
385             }
386         }
387 
388         @Override
389         public TypeElement resultType() {
390             return ptrValue().type();
391         }
392 
393         public Value ptrValue() {
394             return operands().get(0);
395         }
396 
397         public Value offsetValue() {
398             return operands().get(1);
399         }
400     }
401 
402     @OpFactory.OpDeclaration(PtrToMember.NAME)
403     public static final class PtrLoadValue extends Op {
404         public static final String NAME = "ptr.load.value";
405 
406         final JavaType resultType;
407 
408         PtrLoadValue(PtrLoadValue that, CopyContext cc) {
409             super(that, cc);
410             this.resultType = that.resultType;
411         }
412 
413         @Override
414         public PtrLoadValue transform(CopyContext cc, OpTransformer ot) {
415             return new PtrLoadValue(this, cc);
416         }
417 
418         public PtrLoadValue(Value ptr) {
419             super(NAME, List.of(ptr));
420 
421             if (!(ptr.type() instanceof PtrType ptrType)) {
422                 throw new IllegalArgumentException("Pointer value is not of pointer type: " + ptr.type());
423             }
424             if (!(ptrType.layout() instanceof ValueLayout)) {
425                 throw new IllegalArgumentException("Pointer type layout is not a value layout: " + ptrType.layout());
426             }
427             this.resultType = ptrType.returnType();
428         }
429 
430         @Override
431         public TypeElement resultType() {
432             return resultType;
433         }
434 
435         public Value ptrValue() {
436             return operands().get(0);
437         }
438     }
439 
440     @OpFactory.OpDeclaration(PtrToMember.NAME)
441     public static final class PtrStoreValue extends Op {
442         public static final String NAME = "ptr.store.value";
443 
444         PtrStoreValue(PtrStoreValue that, CopyContext cc) {
445             super(that, cc);
446         }
447 
448         @Override
449         public PtrStoreValue transform(CopyContext cc, OpTransformer ot) {
450             return new PtrStoreValue(this, cc);
451         }
452 
453         public PtrStoreValue(Value ptr, Value v) {
454             super(NAME, List.of(ptr));
455 
456             if (!(ptr.type() instanceof PtrType ptrType)) {
457                 throw new IllegalArgumentException("Pointer value is not of pointer type: " + ptr.type());
458             }
459             if (!(ptrType.layout() instanceof ValueLayout)) {
460                 throw new IllegalArgumentException("Pointer type layout is not a value layout: " + ptrType.layout());
461             }
462             if (!(ptrType.returnType().equals(v.type()))) {
463                 throw new IllegalArgumentException("Pointer reference type is not same as value to store type: "
464                         + ptrType.returnType() + " " + v.type());
465             }
466         }
467 
468         @Override
469         public TypeElement resultType() {
470             return JavaType.VOID;
471         }
472 
473         public Value ptrValue() {
474             return operands().get(0);
475         }
476     }
477 }
478