1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package experiments;
 26 
 27 
 28 import hat.ifacemapper.Schema;
 29 import hat.buffer.Buffer;
 30 import hat.ifacemapper.MappableIface;
 31 
 32 import java.lang.constant.ClassDesc;
 33 import java.lang.foreign.*;
 34 import java.lang.invoke.MethodHandles;
 35 import java.lang.reflect.Field;
 36 import java.lang.reflect.Method;
 37 import jdk.incubator.code.*;
 38 import jdk.incubator.code.analysis.SSA;
 39 import jdk.incubator.code.op.CoreOp;
 40 import jdk.incubator.code.op.ExternalizableOp;
 41 import jdk.incubator.code.op.OpFactory;
 42 import jdk.incubator.code.type.FunctionType;
 43 import jdk.incubator.code.type.JavaType;
 44 import jdk.incubator.code.type.PrimitiveType;
 45 import jdk.incubator.code.CodeReflection;
 46 import java.util.*;
 47 import java.util.stream.Stream;
 48 
 49 public class LayoutExample {
 50 
    /*
       C-like sketch of the layout declared by Outer.schema below:

       struct Outer {
          struct Inner { int i; float f; } left;
          int i;
       }
     */
 57 
        /**
         * Example buffer interface used by {@link #m(Outer)} and the pointer
         * lowering in this file.
         *
         * <p>The {@code schema} constant below names a nested struct member
         * {@code "left"} (with members {@code "i"} and {@code "f"}) followed by a
         * member {@code "i"}. The accessor methods declared here use the same
         * names as those schema members.
         */
        public interface Outer extends Buffer {
            /**
             * Nested struct-like view exposing an {@code int} member {@code i}
             * and a {@code float} member {@code f} through getter/setter pairs.
             */
            interface Inner extends Struct {
                /** Reads member {@code i}. */
                int i();

                /** Writes member {@code i}. */
                void i(int v);

                /** Reads member {@code f}. */
                float f();

                /** Writes member {@code f}. */
                void f(float v);

              // Disabled: a stand-alone schema for Inner (Outer.schema describes "left" inline instead).
              //  Schema schema = Schema.of(Inner.class, b->b.primitive("i").primitive("f"));
            }

            /**
             * Accessor for a second nested struct.
             * NOTE(review): "right" is not part of {@code schema} below (its
             * builder line is commented out) - confirm whether calling this is
             * supported.
             */
            Inner right();

            /** Accessor for the nested struct named "left" in {@code schema}. */
            Inner left();

            /** Reads member {@code i} of the outer struct. */
            int i();

            /** Writes member {@code i} of the outer struct. */
            void i(int v);


            // Schema for Outer: nested struct "left" {i, f}, then member "i".
            // This is the public static field that bufferOrBufferChildSchema reads
            // reflectively via getField("schema").
            Schema schema = Schema.of(Outer.class, b->b
                            .struct("left", left->left
                                    .field("i")
                                    .field("f")
                            )
                           // Disabled: second nested struct "right".
                           // .struct("right", Inner.schema)
                            .field("i")
            );
        }
 86 
 87 
    /**
     * Sample method whose code model is obtained and transformed by
     * {@link #main(String[])} (looked up there by the name {@code "m"}).
     *
     * <p>It reads the outer {@code i}, the nested {@code left.i} and
     * {@code left.f}, and returns their sum as a {@code float}. The inline
     * comments give the C-style pointer reading of each access.
     *
     * <p>NOTE(review): because the method is annotated {@code @CodeReflection}
     * and {@code main} prints its model, rewriting these statements presumably
     * changes the printed output - edit with care.
     *
     * @param s1 the outer buffer to read from
     * @return {@code s1.i() + s1.left().i() + s1.left().f()}
     */
    @CodeReflection
    static float m(Outer s1) {
        // C view: Outer* s1
        // s1->i
        int i = s1.i();
        // s2 = &s1->left (pointer to the nested struct)
        Outer.Inner s2 = s1.left();
        // s2->i
        i += s2.i();
        // s2->f
        float f = s2.f();
        // int + float: the int is widened, the result is a float
        return i + f;
    }
101 
102 
103     public static void main(String[] args) {
104         var lookup =     MethodHandles.lookup();
105         Optional<Method> om = Stream.of(LayoutExample.class.getDeclaredMethods())
106                 .filter(m -> m.getName().equals("m"))
107                 .findFirst();
108 
109         Method m = om.orElseThrow();
110         CoreOp.FuncOp f= Op.ofMethod(m).orElseThrow();
111         f = SSA.transform(f);
112         System.out.println(f.toText());
113         FunctionType functionType = transformStructClassToPtr(lookup, f);
114         System.out.println(f.toText());
115         CoreOp.FuncOp pm = transformInvokesToPtrs(lookup, f, functionType);
116         System.out.println(pm.toText());
117     }
118     static FunctionType transformStructClassToPtr(MethodHandles.Lookup l,
119                                                 CoreOp.FuncOp f) {
120         List<TypeElement> pTypes = new ArrayList<>();
121         for (Block.Parameter p : f.parameters()) {
122             pTypes.add(transformStructClassToPtr(l, p.type()));
123         }
124         return FunctionType.functionType(
125                 transformStructClassToPtr(l, f.invokableType().returnType()), pTypes);
126     }
127 
128     static CoreOp.FuncOp transformInvokesToPtrs(MethodHandles.Lookup l,
129                                                 CoreOp.FuncOp f, FunctionType functionType) {
130 
131         var builder= CoreOp.func(f.funcName(), functionType);
132 
133         var funcOp = builder.body(funcBlock -> {
134             funcBlock.transformBody(f.body(), funcBlock.parameters(), (b, op) -> {
135                 if (op instanceof CoreOp.InvokeOp invokeOp
136                         && invokeOp.hasReceiver()
137                         && invokeOp.operands().getFirst() instanceof Value receiver) {
138                     if (bufferOrBufferChildClass(l, receiver.type()) != null) {
139                         Value ptr = b.context().getValue(receiver);
140                         PtrToMember ptrToMemberOp = new PtrToMember(ptr, invokeOp.invokeDescriptor().name());
141                         Op.Result memberPtr = b.op(ptrToMemberOp);
142 
143                         if (invokeOp.operands().size() == 1) {
144                             // Pointer access and (possibly) value load
145                             if (ptrToMemberOp.resultType().layout() instanceof ValueLayout) {
146                                 Op.Result v = b.op(new PtrLoadValue(memberPtr));
147                                 b.context().mapValue(invokeOp.result(), v);
148                             } else {
149                                 b.context().mapValue(invokeOp.result(), memberPtr);
150                             }
151                         } else {
152                             // @@@
153                             // Value store
154                             throw new UnsupportedOperationException();
155                         }
156                     } else {
157                         b.op(op);
158                     }
159                 } else {
160                     b.op(op);
161                 }
162                 return b;
163             });
164         });
165         return funcOp;
166     }
167 
168 
169 
170     static boolean isBufferOrBufferChild(Class<?> maybeIface) {
171         return  maybeIface.isInterface() && (
172                 MappableIface.class.isAssignableFrom(maybeIface)
173         );
174 
175     }
176     static Schema bufferOrBufferChildSchema(MethodHandles.Lookup l, Class<?> maybeBufferOrBufferChild) {
177         if (isBufferOrBufferChild(maybeBufferOrBufferChild)) {
178             throw new IllegalArgumentException();
179         }
180         Field schemaField;
181         try {
182             schemaField = maybeBufferOrBufferChild.getField("schema");
183            return  (Schema)schemaField.get(null);
184         } catch (NoSuchFieldException | IllegalAccessException e) {
185             throw new RuntimeException(e);
186         }
187     }
188     static Class<?> bufferOrBufferChildClass(MethodHandles.Lookup l, TypeElement t) {
189         try {
190             if (!(t instanceof JavaType jt) || !(jt.resolve(l) instanceof Class<?> c)) {
191                 return null;
192             }
193             return isBufferOrBufferChild(c) ? c : null;
194         } catch (ReflectiveOperationException e) {
195             throw new RuntimeException(e);
196         }
197     }
198     static TypeElement transformStructClassToPtr(MethodHandles.Lookup l, TypeElement type) {
199         if (bufferOrBufferChildClass(l, type) instanceof Class<?> sc) {
200             return new PtrType(bufferOrBufferChildSchema(l, sc));
201         } else {
202             return type;
203         }
204     }
205 
206     public static final class PtrType implements TypeElement {
207         static final String NAME = "ptr";
208         MemoryLayout layout;
209         Schema schema;
210         final JavaType returnType;
211 
212         public PtrType(MemoryLayout layout) {
213             this.layout = layout;
214             this.returnType = switch (layout) {
215                 case StructLayout _ -> JavaType.type(ClassDesc.of(layout.name().orElseThrow()));
216                 case AddressLayout _ -> throw new UnsupportedOperationException("Unsupported member layout: " + layout);
217                 case ValueLayout valueLayout -> JavaType.type(valueLayout.carrier());
218                 default -> throw new UnsupportedOperationException("Unsupported member layout: " + layout);
219             };
220         }
221         public PtrType(Schema schema) {
222             this.schema = schema;
223             this.layout= null;//schema.layout();
224             this.returnType = switch (layout) {
225                 case StructLayout _ -> JavaType.type(ClassDesc.of(layout.name().orElseThrow()));
226                 case AddressLayout _ -> throw new UnsupportedOperationException("Unsupported member layout: " + layout);
227                 case ValueLayout valueLayout -> JavaType.type(valueLayout.carrier());
228                 default -> throw new UnsupportedOperationException("Unsupported member layout: " + layout);
229             };
230         }
231 
232         public JavaType returnType() {
233             return returnType;
234         }
235 
236         public MemoryLayout layout() {
237             return layout;
238         }
239         public Schema schema() {
240             return schema;
241         }
242 
243         @Override
244         public boolean equals(Object o) {
245             if (this == o) return true;
246             if (o == null || getClass() != o.getClass()) return false;
247             PtrType ptrType = (PtrType) o;
248             return Objects.equals(layout, ptrType.layout);
249         }
250 
251         @Override
252         public int hashCode() {
253             return Objects.hash(layout);
254         }
255 
256         @Override
257         public ExternalizedTypeElement externalize() {
258             return new ExternalizedTypeElement(NAME, List.of(returnType.externalize()));
259         }
260 
261         @Override
262         public String toString() {
263             return externalize().toString();
264         }
265     }
266 
267     @OpFactory.OpDeclaration(PtrToMember.NAME)
268     public static final class PtrToMember extends ExternalizableOp {
269         public static final String NAME = "ptr.to.member";
270         public static final String ATTRIBUTE_OFFSET = "offset";
271         public static final String ATTRIBUTE_NAME = "name";
272 
273         final String simpleMemberName;
274         final long memberOffset;
275         final PtrType resultType;
276 
277         PtrToMember(PtrToMember that, CopyContext cc) {
278             super(that, cc);
279             this.simpleMemberName = that.simpleMemberName;
280             this.memberOffset = that.memberOffset;
281             this.resultType = that.resultType;
282         }
283 
284         @Override
285         public PtrToMember transform(CopyContext cc, OpTransformer ot) {
286             return new PtrToMember(this, cc);
287         }
288 
289         public PtrToMember(Value ptr, String simpleMemberName) {
290             super(NAME, List.of(ptr));
291             this.simpleMemberName = simpleMemberName;
292 
293             if (!(ptr.type() instanceof PtrType ptrType)) {
294                 throw new IllegalArgumentException("Pointer value is not of pointer type: " + ptr.type());
295             }
296             // @@@ Support group layout
297             if (!(ptrType.layout() instanceof StructLayout structLayout)) {
298                 throw new IllegalArgumentException("Pointer type layout is not a struct layout: " + ptrType.layout());
299             }
300 
301             // Find the actual member name from the simple member name
302             String memberName = findMemberName(structLayout, simpleMemberName);
303             MemoryLayout.PathElement p = MemoryLayout.PathElement.groupElement(memberName);
304             this.memberOffset = structLayout.byteOffset(p);
305             MemoryLayout memberLayout = structLayout.select(p);
306             // Remove any simple member name from the layout
307             MemoryLayout ptrLayout = memberLayout instanceof StructLayout
308                     ? memberLayout.withName(className(memberName))
309                     : memberLayout.withoutName();
310             this.resultType = new PtrType(ptrLayout);
311         }
312 
313         // @@@ Change to return member index
314         static String findMemberName(StructLayout sl, String simpleMemberName) {
315             for (MemoryLayout layout : sl.memberLayouts()) {
316                 String memberName = layout.name().orElseThrow();
317                 if (simpleMemberName(memberName).equals(simpleMemberName)) {
318                     return memberName;
319                 }
320             }
321             throw new NoSuchElementException("No member found: " + simpleMemberName + " " + sl);
322         }
323 
324         static String simpleMemberName(String memberName) {
325             int i = memberName.indexOf("::");
326             return i != -1
327                     ? memberName.substring(i + 2)
328                     : memberName;
329         }
330 
331         static String className(String memberName) {
332             int i = memberName.indexOf("::");
333             return i != -1
334                     ? memberName.substring(0, i)
335                     : null;
336         }
337 
338         @Override
339         public PtrType resultType() {
340             return resultType;
341         }
342 
343         @Override
344         public Map<String, Object> attributes() {
345             HashMap<String, Object> attrs = new HashMap<>(super.attributes());
346             attrs.put("", simpleMemberName);
347             attrs.put(ATTRIBUTE_OFFSET, memberOffset);
348             return attrs;
349         }
350 
351         public String simpleMemberName() {
352             return simpleMemberName;
353         }
354 
355         public long memberOffset() {
356             return memberOffset;
357         }
358 
359         public Value ptrValue() {
360             return operands().get(0);
361         }
362     }
363 
364 
365     @OpFactory.OpDeclaration(PtrToMember.NAME)
366     public static final class PtrAddOffset extends Op {
367         public static final String NAME = "ptr.add.offset";
368 
369         PtrAddOffset(PtrAddOffset that, CopyContext cc) {
370             super(that, cc);
371         }
372 
373         @Override
374         public PtrAddOffset transform(CopyContext cc, OpTransformer ot) {
375             return new PtrAddOffset(this, cc);
376         }
377 
378         public PtrAddOffset(Value ptr, Value offset) {
379             super(NAME, List.of(ptr, offset));
380 
381             if (!(ptr.type() instanceof PtrType)) {
382                 throw new IllegalArgumentException("Pointer value is not of pointer type: " + ptr.type());
383             }
384             if (!(offset.type() instanceof PrimitiveType pt && pt.equals(JavaType.LONG))) {
385                 throw new IllegalArgumentException("Offset value is not of primitve long type: " + offset.type());
386             }
387         }
388 
389         @Override
390         public TypeElement resultType() {
391             return ptrValue().type();
392         }
393 
394         public Value ptrValue() {
395             return operands().get(0);
396         }
397 
398         public Value offsetValue() {
399             return operands().get(1);
400         }
401     }
402 
403     @OpFactory.OpDeclaration(PtrToMember.NAME)
404     public static final class PtrLoadValue extends Op {
405         public static final String NAME = "ptr.load.value";
406 
407         final JavaType resultType;
408 
409         PtrLoadValue(PtrLoadValue that, CopyContext cc) {
410             super(that, cc);
411             this.resultType = that.resultType;
412         }
413 
414         @Override
415         public PtrLoadValue transform(CopyContext cc, OpTransformer ot) {
416             return new PtrLoadValue(this, cc);
417         }
418 
419         public PtrLoadValue(Value ptr) {
420             super(NAME, List.of(ptr));
421 
422             if (!(ptr.type() instanceof PtrType ptrType)) {
423                 throw new IllegalArgumentException("Pointer value is not of pointer type: " + ptr.type());
424             }
425             if (!(ptrType.layout() instanceof ValueLayout)) {
426                 throw new IllegalArgumentException("Pointer type layout is not a value layout: " + ptrType.layout());
427             }
428             this.resultType = ptrType.returnType();
429         }
430 
431         @Override
432         public TypeElement resultType() {
433             return resultType;
434         }
435 
436         public Value ptrValue() {
437             return operands().get(0);
438         }
439     }
440 
441     @OpFactory.OpDeclaration(PtrToMember.NAME)
442     public static final class PtrStoreValue extends Op {
443         public static final String NAME = "ptr.store.value";
444 
445         PtrStoreValue(PtrStoreValue that, CopyContext cc) {
446             super(that, cc);
447         }
448 
449         @Override
450         public PtrStoreValue transform(CopyContext cc, OpTransformer ot) {
451             return new PtrStoreValue(this, cc);
452         }
453 
454         public PtrStoreValue(Value ptr, Value v) {
455             super(NAME, List.of(ptr));
456 
457             if (!(ptr.type() instanceof PtrType ptrType)) {
458                 throw new IllegalArgumentException("Pointer value is not of pointer type: " + ptr.type());
459             }
460             if (!(ptrType.layout() instanceof ValueLayout)) {
461                 throw new IllegalArgumentException("Pointer type layout is not a value layout: " + ptrType.layout());
462             }
463             if (!(ptrType.returnType().equals(v.type()))) {
464                 throw new IllegalArgumentException("Pointer reference type is not same as value to store type: "
465                         + ptrType.returnType() + " " + v.type());
466             }
467         }
468 
469         @Override
470         public TypeElement resultType() {
471             return JavaType.VOID;
472         }
473 
474         public Value ptrValue() {
475             return operands().get(0);
476         }
477     }
478 }
479