1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 
 26 package hat.ifacemapper;
 27 
 28 
 29 import hat.ifacemapper.accessor.AccessorInfo;
 30 import hat.ifacemapper.accessor.Accessors;
 31 import hat.ifacemapper.accessor.ValueType;
 32 import hat.ifacemapper.component.Util;
 33 //import jdk.internal.vm.annotation.Stable;
 34 
 35 import java.io.IOException;
 36 import java.lang.classfile.ClassFile;
 37 import java.lang.classfile.ClassHierarchyResolver;
 38 import java.lang.constant.ClassDesc;
 39 import java.lang.foreign.GroupLayout;
 40 import java.lang.foreign.MemorySegment;
 41 import java.lang.invoke.MethodHandle;
 42 import java.lang.invoke.MethodHandles;
 43 import java.lang.invoke.MethodType;
 44 import java.nio.file.Files;
 45 import java.nio.file.Path;
 46 import java.nio.file.StandardOpenOption;
 47 import java.util.ArrayList;
 48 import java.util.Comparator;
 49 import java.util.List;
 50 import java.util.Optional;
 51 import java.util.OptionalLong;
 52 import java.util.Set;
 53 import java.util.function.Function;
 54 
 55 import static java.lang.classfile.ClassFile.ClassHierarchyResolverOption;
 56 
 57 /**
 58  * A mapper that is matching components of an interface with elements in a GroupLayout.
 59  */
 60 public final class SegmentInterfaceMapper<T>
 61         extends AbstractSegmentMapper<T>
 62         implements SegmentMapper<T> {
 63 
 64     private static final MethodHandles.Lookup LOCAL_LOOKUP = MethodHandles.lookup();
 65 
 66    // @Stable
 67     private final Class<T> implClass;
 68    // @Stable
 69     private final MethodHandle getHandle;
 70    // @Stable
 71     private final MethodHandle setHandle;
 72    // @Stable
 73     // Capability to extract the segment from an instance of the generated implClass
 74     private final MethodHandle segmentGetHandle;
 75    // @Stable
 76     // Capability to extract the offset from an instance of the generated implClass
 77     private final MethodHandle offsetGetHandle;
 78     private final List<AffectedMemory> affectedMemories;
 79 
 80     private SegmentInterfaceMapper(MethodHandles.Lookup lookup,
 81                                    Class<T> type,
 82                                    GroupLayout layout,
 83                                    BoundSchema<?> boundSchema,
 84                                    boolean leaf,
 85                                    List<AffectedMemory> affectedMemories) {
 86         super(lookup, type, layout, boundSchema,leaf,
 87                 MapperUtil::requireImplementableInterfaceType, Accessors::ofInterface);
 88         this.affectedMemories = affectedMemories;
 89 
 90         // Add affected memory for all the setters seen on this level (mutation)
 91         accessors().stream(AccessorInfo.AccessorType.SETTER)
 92                 .map(AffectedMemory::from)
 93                 .forEach(affectedMemories::add);
 94 
 95         this.implClass = generateClass();
 96         this.getHandle = computeGetHandle();
 97         this.setHandle = computeSetHandle();
 98 
 99         try {
100             this.segmentGetHandle = lookup.unreflect(implClass.getMethod(MapperUtil.SECRET_SEGMENT_METHOD_NAME))
101                     .asType(MethodType.methodType(MemorySegment.class, Object.class));
102             this.offsetGetHandle = lookup.unreflect(implClass.getMethod(MapperUtil.SECRET_OFFSET_METHOD_NAME))
103                     .asType(MethodType.methodType(long.class, Object.class));
104         } catch (ReflectiveOperationException e) {
105             throw new RuntimeException(e);
106         }
107         // No need for this now
108         this.accessors = null;
109     }
110 
111     @Override
112     public MethodHandle getHandle() {
113         return getHandle;
114     }
115 
116     @Override
117     public MethodHandle setHandle() {
118         return setHandle;
119     }
120 
121     @Override
122     public Optional<MemorySegment> segment(T source) {
123         if (implClass == source.getClass()) {
124             try {
125                 return Optional.of((MemorySegment) segmentGetHandle.invokeExact(source));
126             } catch (Throwable _) {
127             }
128         }
129         return Optional.empty();
130     }
131 
132     @Override
133     public OptionalLong offset(T source) {
134         // Implicit null check of source
135         if (implClass == source.getClass()) {
136             try {
137                 return OptionalLong.of((long) offsetGetHandle.invokeExact(source));
138             } catch (Throwable _) {
139             }
140         }
141         return OptionalLong.empty();
142     }
143 
144     @Override
145     public <R> SegmentMapper<R> map(Class<R> newType, Function<? super T, ? extends R> toMapper) {
146         return new Mapped<>(lookup(), newType ,layout(),boundSchema(), getHandle(), toMapper);
147     }
148 
149     // @Override
150     //  public <R> SegmentMapper<R> map(Class<R> newType,
151     //Function<? super T, ? extends R> toMapper,
152     // Function<? super R, ? extends T> fromMapper) {
153     //  throw twoWayMappersUnsupported();
154     //  }
155 
156     @Override
157     protected MethodHandle computeGetHandle() {
158         try {
159             // (MemorySegment, long)void
160             var ctor = lookup().findConstructor(implClass, MethodType.methodType(void.class, MemorySegment.class, GroupLayout.class, BoundSchema.class,
161             long.class));
162 
163             // try? var ctor = lookup().findConstructor(implClass, MethodType.methodType(void.class, MemorySegment.class, long.class));
164             // -> (MemorySegment, long)Object
165             ctor = ctor.asType(ctor.type().changeReturnType(Object.class));
166             return ctor;
167         } catch (ReflectiveOperationException e) {
168             throw new IllegalArgumentException("Unable to find constructor for " + implClass, e);
169         }
170     }
171 
172     // This method will return a MethodHandle that will update memory that
173     // is mapped to a setter. Memory that is not mapped to a setter will be
174     // unaffected.
175     @Override
176     protected MethodHandle computeSetHandle() {
177         List<AffectedMemory> fragments = affectedMemories.stream()
178                 .sorted(Comparator.comparingLong(AffectedMemory::offset))
179                 .toList();
180 
181         fragments = AffectedMemory.coalesce(fragments);
182 
183         try {
184             return switch (fragments.size()) {
185                 case 0 -> MethodHandles.empty(Util.SET_TYPE);
186                 case 1 -> {
187                     MethodType mt = MethodType.methodType(void.class, MemorySegment.class, long.class, Object.class);
188                     yield LOCAL_LOOKUP.findVirtual(SegmentInterfaceMapper.class, "setAll", mt)
189                             .bindTo(this);
190                 }
191                 default -> {
192                     MethodType mt = MethodType.methodType(void.class, MemorySegment.class, long.class, Object.class, List.class);
193                     MethodHandle mh = LOCAL_LOOKUP.findVirtual(SegmentInterfaceMapper.class, "setFragments", mt)
194                             .bindTo(this);
195                     yield MethodHandles.insertArguments(mh, 3, fragments);
196                 }
197             };
198         } catch (ReflectiveOperationException e) {
199             throw new IllegalArgumentException("Unable to find setter", e);
200         }
201     }
202 
203     List<AffectedMemory> affectedMemories() {
204         return affectedMemories;
205     }
206 
207     // Private methods and classes
208 
209     private Class<T> generateClass() {
210         String packageName = lookup().lookupClass().getPackageName();
211         String className = packageName.isEmpty()
212                 ? ""
213                 : packageName + ".";
214         className = className + type().getSimpleName() + "InterfaceMapper";
215         ClassDesc classDesc = ClassDesc.of(className);
216         ClassLoader loader = type().getClassLoader();
217 
218         // We need to materialize these methods so that the order is preserved
219         // during generation of the class.
220         List<AccessorInfo> virtualMethods = accessors().stream()
221                 .filter(mi -> mi.key().valueType().equals(ValueType.INTERFACE))
222                 .toList();
223 
224         byte[] bytes = ClassFile.of(ClassHierarchyResolverOption.of(ClassHierarchyResolver.ofClassLoading(loader)))
225                 .build(classDesc, cb -> {
226                     ByteCodeGenerator generator = ByteCodeGenerator.of(type(), classDesc, cb);
227 
228                     // public final XxInterfaceMapper implements Xx {
229                     //     private final MemorySegment segment;
230                     //     private final long offset;
231                     generator.classDefinition();
232 
233                     // void XxInterfaceMapper(MemorySegment segment, long offset) {
234                     //    this.segment = segment;
235                     //    this.offset = offset;
236                     // }
237                     generator.constructor(layout().byteSize());
238 
239                     // MemorySegment $_$_$sEgMeNt$_$_$() {
240                     //     return segment;
241                     // }
242                     generator.obscuredSegment();
243 
244                     // MemorySegment $_$_$lAyOuT$_$_$() {
245                     //     return layout;
246                     // }
247                     generator.obscuredLayout();
248                     // MemorySegment $_$_$bOuNdScHeMa$_$_$() {
249                     //     return layout;
250                     // }
251                     generator.obscuredBoundSchema();
252 
253                     // long $_$_$oFfSeT$_$_$() {
254                     //     return offset;
255                     // }
256                     generator.obscuredOffset();
257 
258                     // @Override
259                     // <t> gX(c1, c2, ..., cN) {
260                     //     long indexOffset = f(dimensions, c1, c2, ..., long cN);
261                     //     return segment.get(JAVA_t, offset + elementOffset + indexOffset);
262                     // }
263                     accessors().stream(Set.of(AccessorInfo.Key.SCALAR_VALUE_GETTER, AccessorInfo.Key.ARRAY_VALUE_GETTER))
264                             .forEach(generator::valueGetter);
265 
266                     // @Override
267                     // void gX(c1, c2, ..., cN, <t> t) {
268                     //     long indexOffset = f(dimensions, c1, c2, ..., long cN);
269                     //     segment.set(JAVA_t, offset + elementOffset + indexOffset, t);
270                     // }
271                     accessors().stream(Set.of(AccessorInfo.Key.SCALAR_VALUE_SETTER, AccessorInfo.Key.ARRAY_VALUE_SETTER))
272                             .forEach(generator::valueSetter);
273 
274                     for (int i = 0; i < virtualMethods.size(); i++) {
275                         AccessorInfo a = virtualMethods.get(i);
276                         switch (a.key().accessorType()) {
277                             // @Override
278                             // <T> T gX(long c1, long c2, ..., long cN) {
279                             //     long indexOffset = f(dimensions, c1, c2, ..., long cN);
280                             //     return (T) mh[x].invokeExact(segment, offset + elementOffset + indexOffset);
281                             // }
282                             case GETTER -> generator.invokeVirtualGetter(a, i);
283                             // @Override
284                             // <T> void gX(T t) {
285                             //     long indexOffset = f(dimensions, c1, c2, ..., long cN);
286                             //     mh[x].invokeExact(segment, offset + elementOffset + indexOffset, t);
287                             // }
288                             case SETTER -> generator.invokeVirtualSetter(a, i);
289                         }
290                     }
291 
292                     // @Override
293                     // int hashCode() {
294                     //     return System.identityHashCode(this);
295                     // }
296                     generator.hashCode_();
297 
298                     // @Override
299                     // boolean equals(Object o) {
300                     //     return this == o;
301                     // }
302                     generator.equals_();
303 
304                     //  @Override
305                     //  public String toString() {
306                     //      return "Foo[g0()=" + g0() + ", g1()=" + g1() + ... "]";
307                     //  }
308                     List<AccessorInfo> getters = accessors().stream(AccessorInfo.AccessorType.GETTER)
309                             .toList();
310                     generator.toString_(getters);
311                 });
312         try {
313             List<MethodHandle> classData = virtualMethods.stream()
314                     .map(a -> switch (a.key()) {
315                                 case SCALAR_INTERFACE_GETTER, ARRAY_INTERFACE_GETTER ->
316                                         mapperCache().interfaceGetMethodHandleFor(a, affectedMemories::add);
317                                 default -> throw new InternalError("Should not reach here " + a);
318                             }
319                     )
320                     .toList();
321 
322             if (MapperUtil.isDebug()) {
323                 Path path = Path.of(classDesc.displayName() + ".class");
324                 try {
325                     Files.write(path, bytes, StandardOpenOption.CREATE, StandardOpenOption.TRUNCATE_EXISTING);
326                     System.out.println("Wrote class file " + path.toAbsolutePath());
327                 } catch (IOException e) {
328                     System.out.println("Unable to write class file: " + path.toAbsolutePath() + " " + e.getMessage());
329                 }
330             }
331 
332             @SuppressWarnings("unchecked")
333             Class<T> c = (Class<T>) lookup()
334                     .defineHiddenClassWithClassData(bytes, classData, true)
335                     .lookupClass();
336             return c;
337         } catch (IllegalAccessException | VerifyError e) {
338             throw new IllegalArgumentException("Unable to define proxy class for " + type() + " using " + layout(), e);
339         }
340     }
341 
342     private MethodHandle changeReturnTypeToObject(MethodHandle mh) {
343         return mh.asType(mh.type().changeReturnType(Object.class));
344     }
345 
346     private MethodHandle changeParam2ToObject(MethodHandle mh) {
347         return mh.asType(mh.type().changeParameterType(2, Object.class));
348     }
349 
350     // Invoked reflectively
351     private void setAll(MemorySegment segment, long offset, T t) {
352         MemorySegment srcSegment = segment(t)
353                 .orElseThrow(SegmentInterfaceMapper::notImplType);
354         long srcOffset = offset(t)
355                 .orElseThrow(SegmentInterfaceMapper::notImplType);
356         MemorySegment.copy(srcSegment, srcOffset, segment, offset, layout().byteSize());
357     }
358 
359     // Invoked reflectively
360     private void setFragments(MemorySegment segment, long offset, T t, List<AffectedMemory> fragments) {
361         MemorySegment srcSegment = segment(t)
362                 .orElseThrow(SegmentInterfaceMapper::notImplType);
363         long srcOffset = offset(t)
364                 .orElseThrow(SegmentInterfaceMapper::notImplType);
365         for (AffectedMemory m : fragments) {
366             MemorySegment.copy(srcSegment, srcOffset + m.offset(), segment, offset + m.offset(), m.size());
367         }
368     }
369 
370     private static IllegalArgumentException notImplType() {
371         return new IllegalArgumentException("The provided object of type T is not created by this mapper.");
372     }
373 
374     // Used to keep track of which memory shards gets accessed
375     // by setters. We need this when computing the setHandle
376     record AffectedMemory(long offset,
377                           long size) {
378 
379         //  AffectedMemory {
380         //   long offset;
381         //  long size; // requireNonNegative(offset);
382         // requireNonNegative(size);
383         // }
384 
385         static AffectedMemory from(AccessorInfo mi) {
386             return new AffectedMemory(mi.offset(), mi.layoutInfo().layout().byteSize());
387         }
388 
389         AffectedMemory translate(long delta) {
390             return new AffectedMemory(offset() + delta, size());
391         }
392 
393         static List<AffectedMemory> coalesce(List<AffectedMemory> items) {
394             List<AffectedMemory> result = new ArrayList<>();
395 
396             for (int i = 0; i < items.size(); i++) {
397                 AffectedMemory current = items.get(i);
398                 for (int j = i + 1; j < result.size(); j++) {
399                     AffectedMemory next = items.get(j);
400                     if (current.isBefore(next)) {
401                         current = current.merge(next);
402                     } else {
403                         break;
404                     }
405                 }
406                 result.add(current);
407             }
408             return result;
409         }
410 
411         private boolean isBefore(AffectedMemory other) {
412             return offset + size == other.offset();
413         }
414 
415         private AffectedMemory merge(AffectedMemory other) {
416             return new AffectedMemory(offset, size + other.size());
417         }
418 
419     }
420 
421     public static <T> SegmentInterfaceMapper<T> create(MethodHandles.Lookup lookup,
422                                                        Class<T> type,
423                                                        GroupLayout layout,
424                                                        BoundSchema<?> boundSchema) {
425         return new SegmentInterfaceMapper<>(lookup, type,  layout, boundSchema, false, new ArrayList<>());
426     }
427 
428     // Mapping
429 
430     /**
431      * This class models composed record mappers.
432      *
433      * @param lookup    to use for reflective operations
434      * @param type      new type to map to/from
435      * @param layout    original layout
436      * @param getHandle for get operations
437      * @param toMapper  a function that goes from T to R
438      * @param <T>       original mapper type
439      * @param <R>       composed mapper type
440      */
441     record Mapped<T, R>(
442             MethodHandles.Lookup lookup,
443             @Override Class<R> type,
444             @Override GroupLayout layout,
445             @Override BoundSchema<?> boundSchema,
446             @Override MethodHandle getHandle,
447             Function<? super T, ? extends R> toMapper
448     ) implements SegmentMapper<R> {
449 
450         static final MethodHandle SET_OPERATIONS_UNSUPPORTED;
451 
452         static {
453             try {
454                 MethodType methodType = MethodType.methodType(void.class, MemorySegment.class, long.class, Object.class);
455                 SET_OPERATIONS_UNSUPPORTED = LOCAL_LOOKUP.findStatic(Mapped.class, "setOperationsUnsupported", methodType);
456             } catch (ReflectiveOperationException e) {
457                 throw new ExceptionInInitializerError(e);
458             }
459         }
460 
461         Mapped(MethodHandles.Lookup lookup,
462                Class<R> type,
463                GroupLayout layout,
464                BoundSchema<?> boundSchema,
465                MethodHandle getHandle,
466                Function<? super T, ? extends R> toMapper
467         ) {
468             this.lookup = lookup;
469             this.type = type;
470             this.boundSchema =boundSchema;
471             this.layout = layout;
472             this.toMapper = toMapper;
473             MethodHandle toMh = findVirtual("mapTo").bindTo(this);
474             this.getHandle = MethodHandles.filterReturnValue(getHandle, toMh);
475         }
476 
477         @Override
478         public MethodHandle setHandle() {
479             return SET_OPERATIONS_UNSUPPORTED;
480         }
481 
482         @Override
483         public <R1> SegmentMapper<R1> map(Class<R1> newType,
484                                           Function<? super R, ? extends R1> toMapper) {
485             return new Mapped<>(lookup, newType,  layout(), boundSchema(), getHandle(), toMapper);
486         }
487 
488         // Used reflective when obtaining a MethodHandle
489         R mapTo(T t) {
490             return toMapper.apply(t);
491         }
492 
493         // Used reflective when obtaining a MethodHandle
494         /*T mapFrom(R r) {
495             return fromMapper.apply(r);
496         }*/
497 
498         private static MethodHandle findVirtual(String name) {
499             try {
500                 var mt = MethodType.methodType(Object.class, Object.class);
501                 return LOCAL_LOOKUP.findVirtual(Mapped.class, name, mt);
502             } catch (ReflectiveOperationException e) {
503                 // Should not happen
504                 throw new InternalError(e);
505             }
506         }
507 
508         private static void setOperationsUnsupported(MemorySegment s, long o, Object t) {
509             throw new UnsupportedOperationException("SegmentMapper::set operations are not supported for mapped interface mappers");
510         }
511 
512     }
513 
514 
515 }