1 /*
  2  *  Copyright (c) 2019, 2023, Oracle and/or its affiliates. All rights reserved.
  3  *  DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  *  This code is free software; you can redistribute it and/or modify it
  6  *  under the terms of the GNU General Public License version 2 only, as
  7  *  published by the Free Software Foundation.  Oracle designates this
  8  *  particular file as subject to the "Classpath" exception as provided
  9  *  by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  *  This code is distributed in the hope that it will be useful, but WITHOUT
 12  *  ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  *  FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  *  version 2 for more details (a copy is included in the LICENSE file that
 15  *  accompanied this code).
 16  *
 17  *  You should have received a copy of the GNU General Public License version
 18  *  2 along with this work; if not, write to the Free Software Foundation,
 19  *  Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  *   Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  *  or visit www.oracle.com if you need additional information or have any
 23  *  questions.
 24  *
 25  */
 26 package jdk.internal.foreign;
 27 
 28 import jdk.internal.vm.annotation.ForceInline;
 29 import jdk.internal.vm.annotation.Stable;
 30 
 31 import java.lang.foreign.AddressLayout;
 32 import java.lang.foreign.GroupLayout;
 33 import java.lang.foreign.MemoryLayout;
 34 import java.lang.foreign.MemorySegment;
 35 import java.lang.foreign.SequenceLayout;
 36 import java.lang.foreign.StructLayout;
 37 import java.lang.foreign.ValueLayout;
 38 import java.lang.invoke.MethodHandle;
 39 import java.lang.invoke.MethodHandles;
 40 import java.lang.invoke.MethodType;
 41 import java.lang.invoke.VarHandle;
 42 import java.util.Arrays;
 43 import java.util.List;
 44 import java.util.Locale;
 45 import java.util.Objects;
 46 import java.util.function.UnaryOperator;
 47 import java.util.stream.IntStream;
 48 import java.util.stream.Stream;
 49 
 50 import static java.util.stream.Collectors.joining;
 51 
 52 /**
 53  * This class provide support for constructing layout paths; that is, starting from a root path (see {@link #rootPath(MemoryLayout)},
 54  * a path can be constructed by selecting layout elements using the selector methods provided by this class
 55  * (see {@link #sequenceElement()}, {@link #sequenceElement(long)}, {@link #sequenceElement(long, long)}, {@link #groupElement(String)}).
 56  * Once a path has been fully constructed, clients can ask for the offset associated with the layout element selected
 57  * by the path (see {@link #offset}), or obtain var handle to access the selected layout element
 58  * given an address pointing to a segment associated with the root layout (see {@link #dereferenceHandle()}).
 59  */
 60 public class LayoutPath {
 61 
 62     private static final long[] EMPTY_STRIDES = new long[0];
 63     private static final long[] EMPTY_BOUNDS = new long[0];
 64     private static final MethodHandle[] EMPTY_DEREF_HANDLES = new MethodHandle[0];
 65 
 66     private static final MethodHandle MH_ADD_SCALED_OFFSET;
 67     private static final MethodHandle MH_SLICE;
 68     private static final MethodHandle MH_SLICE_LAYOUT;
 69     private static final MethodHandle MH_CHECK_ALIGN;
 70     private static final MethodHandle MH_SEGMENT_RESIZE;
 71     private static final MethodHandle MH_ADD;
 72 
 73     static {
 74         try {
 75             MethodHandles.Lookup lookup = MethodHandles.lookup();
 76             MH_ADD_SCALED_OFFSET = lookup.findStatic(LayoutPath.class, "addScaledOffset",
 77                     MethodType.methodType(long.class, long.class, long.class, long.class, long.class));
 78             MH_SLICE = lookup.findVirtual(MemorySegment.class, "asSlice",
 79                     MethodType.methodType(MemorySegment.class, long.class, long.class));
 80             MH_SLICE_LAYOUT = lookup.findVirtual(MemorySegment.class, "asSlice",
 81                     MethodType.methodType(MemorySegment.class, long.class, MemoryLayout.class));
 82             MH_CHECK_ALIGN = lookup.findStatic(LayoutPath.class, "checkAlign",
 83                     MethodType.methodType(void.class, MemorySegment.class, long.class, MemoryLayout.class));
 84             MH_SEGMENT_RESIZE = lookup.findStatic(LayoutPath.class, "resizeSegment",
 85                     MethodType.methodType(MemorySegment.class, MemorySegment.class, MemoryLayout.class));
 86             MH_ADD = lookup.findStatic(Long.class, "sum",
 87                     MethodType.methodType(long.class, long.class, long.class));
 88         } catch (Throwable ex) {
 89             throw new ExceptionInInitializerError(ex);
 90         }
 91     }
 92 
 93     private final MemoryLayout layout;
 94     private final long offset;
 95     private final LayoutPath enclosing;
 96     private final long[] strides;
 97     private final long[] bounds;
 98     private final MethodHandle[] derefAdapters;
 99 
100     private LayoutPath(MemoryLayout layout, long offset, long[] strides, long[] bounds, MethodHandle[] derefAdapters, LayoutPath enclosing) {
101         this.layout = layout;
102         this.offset = offset;
103         this.strides = strides;
104         this.bounds = bounds;
105         this.derefAdapters = derefAdapters;
106         this.enclosing = enclosing;
107     }
108 
109     // Layout path selector methods
110 
111     public LayoutPath sequenceElement() {
112         SequenceLayout seq = requireSequenceLayout();
113         MemoryLayout elem = seq.elementLayout();
114         return LayoutPath.nestedPath(elem, offset, addStride(elem.byteSize()), addBound(seq.elementCount()), derefAdapters, this);
115     }
116 
117     public LayoutPath sequenceElement(long start, long step) {
118         SequenceLayout seq = requireSequenceLayout();
119         checkSequenceBounds(seq, start);
120         MemoryLayout elem = seq.elementLayout();
121         long elemSize = elem.byteSize();
122         long nelems = step > 0 ?
123                 seq.elementCount() - start :
124                 start + 1;
125         long maxIndex = Math.ceilDiv(nelems, Math.abs(step));
126         return LayoutPath.nestedPath(elem, offset + (start * elemSize),
127                 addStride(elemSize * step), addBound(maxIndex), derefAdapters, this);
128     }
129 
130     public LayoutPath sequenceElement(long index) {
131         SequenceLayout seq = requireSequenceLayout();
132         checkSequenceBounds(seq, index);
133         long elemSize = seq.elementLayout().byteSize();
134         long elemOffset = elemSize * index;
135         return LayoutPath.nestedPath(seq.elementLayout(), offset + elemOffset, strides, bounds, derefAdapters, this);
136     }
137 
138     public LayoutPath groupElement(String name) {
139         GroupLayout g = requireGroupLayout();
140         long offset = 0;
141         MemoryLayout elem = null;
142         for (int i = 0; i < g.memberLayouts().size(); i++) {
143             MemoryLayout l = g.memberLayouts().get(i);
144             if (l.name().isPresent() &&
145                 l.name().get().equals(name)) {
146                 elem = l;
147                 break;
148             } else if (g instanceof StructLayout) {
149                 offset += l.byteSize();
150             }
151         }
152         if (elem == null) {
153             throw badLayoutPath(
154                     String.format("cannot resolve '%s' in layout %s", name, breadcrumbs()));
155         }
156         return LayoutPath.nestedPath(elem, this.offset + offset, strides, bounds, derefAdapters, this);
157     }
158 
159     public LayoutPath groupElement(long index) {
160         GroupLayout g = requireGroupLayout();
161         long elemSize = g.memberLayouts().size();
162         long offset = 0;
163         MemoryLayout elem = null;
164         for (int i = 0; i <= index; i++) {
165             if (i == elemSize) {
166                 throw badLayoutPath(
167                         String.format("cannot resolve element %d in layout: %s", index, breadcrumbs()));
168             }
169             elem = g.memberLayouts().get(i);
170             if (g instanceof StructLayout && i < index) {
171                 offset += elem.byteSize();
172             }
173         }
174         return LayoutPath.nestedPath(elem, this.offset + offset, strides, bounds, derefAdapters, this);
175     }
176 
177     public LayoutPath derefElement() {
178         if (!(layout instanceof AddressLayout addressLayout) ||
179                 addressLayout.targetLayout().isEmpty()) {
180             throw badLayoutPath(
181                     String.format("Cannot dereference layout: %s", breadcrumbs()));
182         }
183         MemoryLayout derefLayout = addressLayout.targetLayout().get();
184         MethodHandle handle = dereferenceHandle(false).toMethodHandle(VarHandle.AccessMode.GET);
185         handle = MethodHandles.filterReturnValue(handle,
186                 MethodHandles.insertArguments(MH_SEGMENT_RESIZE, 1, derefLayout));
187         return derefPath(derefLayout, handle, this);
188     }
189 
190     private static MemorySegment resizeSegment(MemorySegment segment, MemoryLayout layout) {
191         return Utils.longToAddress(segment.address(), layout.byteSize(), layout.byteAlignment());
192     }
193 
194     // Layout path projections
195 
196     public long offset() {
197         return offset;
198     }
199 
200     public VarHandle dereferenceHandle() {
201         return dereferenceHandle(true);
202     }
203 
204     public VarHandle dereferenceHandle(boolean adapt) {
205         if (!(layout instanceof ValueLayout valueLayout)) {
206             throw new IllegalArgumentException(
207                     String.format("Path does not select a value layout: %s", breadcrumbs()));
208         }
209 
210         // If we have an enclosing layout, drop the alignment check for the accessed element,
211         // we check the root layout instead
212         ValueLayout accessedLayout = enclosing != null ? valueLayout.withByteAlignment(1) : valueLayout;
213         VarHandle handle = accessedLayout.varHandle();
214         handle = MethodHandles.collectCoordinates(handle, 1, offsetHandle());
215 
216         // we only have to check the alignment of the root layout for the first dereference we do,
217         // as each dereference checks the alignment of the target address when constructing its segment
218         // (see Utils::longToAddress)
219         if (derefAdapters.length == 0 && enclosing != null) {
220             // insert align check for the root layout on the initial MS + offset
221             List<Class<?>> coordinateTypes = handle.coordinateTypes();
222             MethodHandle alignCheck = MethodHandles.insertArguments(MH_CHECK_ALIGN, 2, rootLayout());
223             handle = MethodHandles.collectCoordinates(handle, 0, alignCheck);
224             int[] reorder = IntStream.concat(IntStream.of(0, 1), IntStream.range(0, coordinateTypes.size())).toArray();
225             handle = MethodHandles.permuteCoordinates(handle, coordinateTypes, reorder);
226         }
227 
228         if (adapt) {
229             if (derefAdapters.length > 0) {
230                 // plug up the base offset if we have at least 1 enclosing dereference
231                 handle = MethodHandles.insertCoordinates(handle, 1, 0);
232             }
233             for (int i = derefAdapters.length; i > 0; i--) {
234                 MethodHandle adapter = derefAdapters[i - 1];
235                 // the first/outermost adapter will have a base offset coordinate, the rest are constant 0
236                 if (i > 1) {
237                     // plug in a constant 0 base offset for all but the outermost access in a deref chain
238                     adapter = MethodHandles.insertArguments(adapter, 1, 0);
239                 }
240                 handle = MethodHandles.collectCoordinates(handle, 0, adapter);
241             }
242         }
243         return handle;
244     }
245 
246     @ForceInline
247     private static long addScaledOffset(long base, long index, long stride, long bound) {
248         Objects.checkIndex(index, bound);
249         return base + (stride * index);
250     }
251 
252     public MethodHandle offsetHandle() {
253         MethodHandle mh = MethodHandles.insertArguments(MH_ADD, 0, offset);
254         for (int i = strides.length - 1; i >= 0; i--) {
255             MethodHandle collector = MethodHandles.insertArguments(MH_ADD_SCALED_OFFSET, 2, strides[i], bounds[i]);
256             // (J, ...) -> J to (J, J, ...) -> J
257             // i.e. new coord is prefixed. Last coord will correspond to innermost layout
258             mh = MethodHandles.collectArguments(mh, 0, collector);
259         }
260 
261         return mh;
262     }
263 
264     private MemoryLayout rootLayout() {
265         return enclosing != null ? enclosing.rootLayout() : this.layout;
266     }
267 
268     public MethodHandle sliceHandle() {
269         MethodHandle sliceHandle;
270         if (enclosing != null) {
271             // drop the alignment check for the accessed element, we check the root layout instead
272             sliceHandle = MH_SLICE; // (MS, long, long) -> MS
273             sliceHandle = MethodHandles.insertArguments(sliceHandle, 2, layout.byteSize()); // (MS, long) -> MS
274         } else {
275             sliceHandle = MH_SLICE_LAYOUT; // (MS, long, MemoryLayout) -> MS
276             sliceHandle = MethodHandles.insertArguments(sliceHandle, 2, layout); // (MS, long) -> MS
277         }
278         sliceHandle = MethodHandles.collectArguments(sliceHandle, 1, offsetHandle()); // (MS, long, ...) -> MS
279 
280         if (enclosing != null) {
281             // insert align check for the root layout on the initial MS + offset
282             MethodType oldType = sliceHandle.type();
283             MethodHandle alignCheck = MethodHandles.insertArguments(MH_CHECK_ALIGN, 2, rootLayout());
284             sliceHandle = MethodHandles.collectArguments(sliceHandle, 0, alignCheck); // (MS, long, MS, long) -> MS
285             int[] reorder = IntStream.concat(IntStream.of(0, 1), IntStream.range(0, oldType.parameterCount())).toArray();
286             sliceHandle = MethodHandles.permuteArguments(sliceHandle, oldType, reorder); // (MS, long, ...) -> MS
287         }
288 
289         return sliceHandle;
290     }
291 
292     private static void checkAlign(MemorySegment segment, long offset, MemoryLayout constraint) {
293         if (!((AbstractMemorySegmentImpl) segment).isAlignedForElement(offset, constraint)) {
294             throw new IllegalArgumentException(String.format(
295                     "Target offset %d is incompatible with alignment constraint %d (of %s) for segment %s"
296                     , offset, constraint.byteAlignment(), constraint, segment));
297         }
298     }
299 
300     public MemoryLayout layout() {
301         return layout;
302     }
303 
304     // Layout path construction
305 
306     public static LayoutPath rootPath(MemoryLayout layout) {
307         return new LayoutPath(layout, 0L, EMPTY_STRIDES, EMPTY_BOUNDS, EMPTY_DEREF_HANDLES, null);
308     }
309 
310     private static LayoutPath nestedPath(MemoryLayout layout, long offset, long[] strides, long[] bounds, MethodHandle[] derefAdapters, LayoutPath encl) {
311         return new LayoutPath(layout, offset, strides, bounds, derefAdapters, encl);
312     }
313 
314     private static LayoutPath derefPath(MemoryLayout layout, MethodHandle handle, LayoutPath encl) {
315         MethodHandle[] handles = Arrays.copyOf(encl.derefAdapters, encl.derefAdapters.length + 1);
316         handles[encl.derefAdapters.length] = handle;
317         return new LayoutPath(layout, 0L, EMPTY_STRIDES, EMPTY_BOUNDS, handles, null);
318     }
319 
320     // Helper methods
321 
322     private SequenceLayout requireSequenceLayout() {
323         return requireLayoutType(SequenceLayout.class, "sequence");
324     }
325 
326     private GroupLayout requireGroupLayout() {
327         return requireLayoutType(GroupLayout.class, "group");
328     }
329 
330     private <T extends MemoryLayout> T requireLayoutType(Class<T> layoutClass, String name) {
331         if (!layoutClass.isAssignableFrom(layout.getClass())) {
332             throw badLayoutPath(
333                     String.format("attempting to select a %s element from a non-%s layout: %s",
334                             name, name, breadcrumbs()));
335         }
336         return layoutClass.cast(layout);
337     }
338 
339     private void checkSequenceBounds(SequenceLayout seq, long index) {
340         if (index >= seq.elementCount()) {
341             throw badLayoutPath(String.format("sequence index out of bounds; index: %d, elementCount is %d for layout %s",
342                     index, seq.elementCount(), breadcrumbs()));
343         }
344     }
345 
346     private static IllegalArgumentException badLayoutPath(String cause) {
347         return new IllegalArgumentException("Bad layout path: " + cause);
348     }
349 
350     private long[] addStride(long stride) {
351         long[] newStrides = Arrays.copyOf(strides, strides.length + 1);
352         newStrides[strides.length] = stride;
353         return newStrides;
354     }
355 
356     private long[] addBound(long maxIndex) {
357         long[] newBounds = Arrays.copyOf(bounds, bounds.length + 1);
358         newBounds[bounds.length] = maxIndex;
359         return newBounds;
360     }
361 
362     private String breadcrumbs() {
363         return Stream.iterate(this, Objects::nonNull, lp -> lp.enclosing)
364                 .map(LayoutPath::layout)
365                 .map(Object::toString)
366                 .collect(joining(", selected from: "));
367     }
368 
369     /**
370      * This class provides an immutable implementation for the {@code PathElement} interface. A path element implementation
371      * is simply a pointer to one of the selector methods provided by the {@code LayoutPath} class.
372      */
373     public static final class PathElementImpl implements MemoryLayout.PathElement, UnaryOperator<LayoutPath> {
374 
375         public enum PathKind {
376             SEQUENCE_ELEMENT("unbound sequence element"),
377             SEQUENCE_ELEMENT_INDEX("bound sequence element"),
378             SEQUENCE_RANGE("sequence range"),
379             GROUP_ELEMENT("group element"),
380             DEREF_ELEMENT("dereference element");
381 
382             final String description;
383 
384             PathKind(String description) {
385                 this.description = description;
386             }
387 
388             public String description() {
389                 return description;
390             }
391         }
392 
393         final PathKind kind;
394         final UnaryOperator<LayoutPath> pathOp;
395 
396         public PathElementImpl(PathKind kind, UnaryOperator<LayoutPath> pathOp) {
397             this.kind = kind;
398             this.pathOp = pathOp;
399         }
400 
401         @Override
402         public LayoutPath apply(LayoutPath layoutPath) {
403             return pathOp.apply(layoutPath);
404         }
405 
406         public PathKind kind() {
407             return kind;
408         }
409     }
410 }