< prev index next >

src/java.base/share/classes/jdk/internal/foreign/AbstractMemorySegmentImpl.java

Print this page
@@ -46,16 +46,16 @@
  
  import jdk.internal.access.JavaNioAccess;
  import jdk.internal.access.SharedSecrets;
  import jdk.internal.access.foreign.UnmapperProxy;
  import jdk.internal.misc.ScopedMemoryAccess;
- import jdk.internal.misc.Unsafe;
  import jdk.internal.reflect.CallerSensitive;
  import jdk.internal.reflect.Reflection;
  import jdk.internal.util.ArraysSupport;
  import jdk.internal.util.Preconditions;
  import jdk.internal.vm.annotation.ForceInline;
+ import sun.nio.ch.DirectBuffer;
  
  import static java.lang.foreign.ValueLayout.JAVA_BYTE;
  
  /**
   * This abstract class provides an immutable implementation for the {@code MemorySegment} interface. This class contains information

@@ -262,27 +262,18 @@
              final long thatStart = that.unsafeGetOffset();
              final long thisEnd = thisStart + this.byteSize();
              final long thatEnd = thatStart + that.byteSize();
  
              if (thisStart < thatEnd && thisEnd > thatStart) {  //overlap occurs
-                 long offsetToThat = this.segmentOffset(that);
+                 long offsetToThat = that.address() - this.address(); // replaces segmentOffset(that): signed distance between the two addresses, negative when that starts before this
                  long newOffset = offsetToThat >= 0 ? offsetToThat : 0;
                  return Optional.of(asSlice(newOffset, Math.min(this.byteSize() - newOffset, that.byteSize() + offsetToThat)));
              }
          }
          return Optional.empty();
      }
  
-     @Override
-     public final long segmentOffset(MemorySegment other) {
-         AbstractMemorySegmentImpl that = (AbstractMemorySegmentImpl) Objects.requireNonNull(other);
-         if (unsafeGetBase() == that.unsafeGetBase()) {
-             return that.unsafeGetOffset() - this.unsafeGetOffset();
-         }
-         throw new UnsupportedOperationException("Cannot compute offset from native to heap (or vice versa).");
-     }
- 
      @Override
      public void load() {
          throw notAMappedSegment();
      }
  

@@ -544,55 +535,46 @@
          int scaleFactor = getScaleFactor(bb);
          final MemorySessionImpl bufferScope;
          if (bufferSegment != null) {
              bufferScope = bufferSegment.scope;
          } else {
-             bufferScope = MemorySessionImpl.heapSession(bb);
+             bufferScope = MemorySessionImpl.createHeap(bufferRef(bb)); // no owning segment: build the heap session from the buffer's root object (see bufferRef below)
          }
          if (base != null) {
-             if (base instanceof byte[]) {
-                 return new HeapMemorySegmentImpl.OfByte(bbAddress + (pos << scaleFactor), base, size << scaleFactor, readOnly, bufferScope);
-             } else if (base instanceof short[]) {
-                 return new HeapMemorySegmentImpl.OfShort(bbAddress + (pos << scaleFactor), base, size << scaleFactor, readOnly, bufferScope);
-             } else if (base instanceof char[]) {
-                 return new HeapMemorySegmentImpl.OfChar(bbAddress + (pos << scaleFactor), base, size << scaleFactor, readOnly, bufferScope);
-             } else if (base instanceof int[]) {
-                 return new HeapMemorySegmentImpl.OfInt(bbAddress + (pos << scaleFactor), base, size << scaleFactor, readOnly, bufferScope);
-             } else if (base instanceof float[]) {
-                 return new HeapMemorySegmentImpl.OfFloat(bbAddress + (pos << scaleFactor), base, size << scaleFactor, readOnly, bufferScope);
-             } else if (base instanceof long[]) {
-                 return new HeapMemorySegmentImpl.OfLong(bbAddress + (pos << scaleFactor), base, size << scaleFactor, readOnly, bufferScope);
-             } else if (base instanceof double[]) {
-                 return new HeapMemorySegmentImpl.OfDouble(bbAddress + (pos << scaleFactor), base, size << scaleFactor, readOnly, bufferScope);
-             } else {
-                 throw new AssertionError("Cannot get here");
-             }
+             return switch (base) { // NOTE(review): base is presumably the buffer's backing array; the heap segment class is chosen by its array type
+                 case byte[] __ -> // in every case, pos and size are shifted by scaleFactor to get byte offsets/lengths
+                         new HeapMemorySegmentImpl.OfByte(bbAddress + (pos << scaleFactor), base, size << scaleFactor, readOnly, bufferScope);
+                 case short[] __ ->
+                         new HeapMemorySegmentImpl.OfShort(bbAddress + (pos << scaleFactor), base, size << scaleFactor, readOnly, bufferScope);
+                 case char[] __ ->
+                         new HeapMemorySegmentImpl.OfChar(bbAddress + (pos << scaleFactor), base, size << scaleFactor, readOnly, bufferScope);
+                 case int[] __ ->
+                         new HeapMemorySegmentImpl.OfInt(bbAddress + (pos << scaleFactor), base, size << scaleFactor, readOnly, bufferScope);
+                 case float[] __ ->
+                         new HeapMemorySegmentImpl.OfFloat(bbAddress + (pos << scaleFactor), base, size << scaleFactor, readOnly, bufferScope);
+                 case long[] __ ->
+                         new HeapMemorySegmentImpl.OfLong(bbAddress + (pos << scaleFactor), base, size << scaleFactor, readOnly, bufferScope);
+                 case double[] __ ->
+                         new HeapMemorySegmentImpl.OfDouble(bbAddress + (pos << scaleFactor), base, size << scaleFactor, readOnly, bufferScope);
+                 default -> throw new AssertionError("Cannot get here"); // only the seven primitive array types above are expected
+             };
          } else if (unmapper == null) {
              return new NativeMemorySegmentImpl(bbAddress + (pos << scaleFactor), size << scaleFactor, readOnly, bufferScope);
          } else {
              // we can ignore scale factor here, a mapped buffer is always a byte buffer, so scaleFactor == 0.
              return new MappedMemorySegmentImpl(bbAddress + pos, unmapper, size, readOnly, bufferScope);
          }
      }
  
-     private static int getScaleFactor(Buffer buffer) {
-         if (buffer instanceof ByteBuffer) {
-             return 0;
-         } else if (buffer instanceof CharBuffer) {
-             return 1;
-         } else if (buffer instanceof ShortBuffer) {
-             return 1;
-         } else if (buffer instanceof IntBuffer) {
-             return 2;
-         } else if (buffer instanceof FloatBuffer) {
-             return 2;
-         } else if (buffer instanceof LongBuffer) {
-             return 3;
-         } else if (buffer instanceof DoubleBuffer) {
-             return 3;
+     private static Object bufferRef(Buffer buffer) { // root object passed to MemorySessionImpl.createHeap when the buffer has no owning segment
+         if (buffer instanceof DirectBuffer directBuffer) {
+             // direct buffer, return either the buffer attachment (for slices and views), or the buffer itself
+             return directBuffer.attachment() != null ?
+                     directBuffer.attachment() : directBuffer;
          } else {
-             throw new AssertionError("Cannot get here");
+             // heap buffer, return the underlying array
+             return NIO_ACCESS.getBufferBase(buffer);
          }
      }
  
      @ForceInline
      public static void copy(MemorySegment srcSegment, ValueLayout srcElementLayout, long srcOffset,

@@ -629,60 +611,56 @@
      @ForceInline
      public static void copy(MemorySegment srcSegment, ValueLayout srcLayout, long srcOffset,
                              Object dstArray, int dstIndex,
                              int elementCount) {
  
-         long baseAndScale = getBaseAndScale(dstArray.getClass());
+         var dstInfo = Utils.BaseAndScale.of(dstArray); // array base offset + element width; NOTE(review): the removed code widened the width to long for the multiplications below; confirm scale() is long
          if (dstArray.getClass().componentType() != srcLayout.carrier()) {
              throw new IllegalArgumentException("Incompatible value layout: " + srcLayout);
          }
-         int dstBase = (int)baseAndScale;
-         long dstWidth = (int)(baseAndScale >> 32); // Use long arithmetics below
          AbstractMemorySegmentImpl srcImpl = (AbstractMemorySegmentImpl)srcSegment;
          Utils.checkElementAlignment(srcLayout, "Source layout alignment greater than its size");
          if (!srcImpl.isAlignedForElement(srcOffset, srcLayout)) {
              throw new IllegalArgumentException("Source segment incompatible with alignment constraints");
          }
-         srcImpl.checkAccess(srcOffset, elementCount * dstWidth, true);
+         srcImpl.checkAccess(srcOffset, elementCount * dstInfo.scale(), true); // same byte length that is copied below
          Objects.checkFromIndexSize(dstIndex, elementCount, Array.getLength(dstArray));
-         if (dstWidth == 1 || srcLayout.order() == ByteOrder.nativeOrder()) {
+         if (dstInfo.scale() == 1 || srcLayout.order() == ByteOrder.nativeOrder()) { // single-byte elements or native order: plain copy, no byte swapping
              ScopedMemoryAccess.getScopedMemoryAccess().copyMemory(srcImpl.sessionImpl(), null,
                      srcImpl.unsafeGetBase(), srcImpl.unsafeGetOffset() + srcOffset,
-                     dstArray, dstBase + (dstIndex * dstWidth), elementCount * dstWidth);
+                     dstArray, dstInfo.base() + (dstIndex * dstInfo.scale()), elementCount * dstInfo.scale());
          } else {
              ScopedMemoryAccess.getScopedMemoryAccess().copySwapMemory(srcImpl.sessionImpl(), null,
                      srcImpl.unsafeGetBase(), srcImpl.unsafeGetOffset() + srcOffset,
-                     dstArray, dstBase + (dstIndex * dstWidth), elementCount * dstWidth, dstWidth);
+                     dstArray, dstInfo.base() + (dstIndex * dstInfo.scale()), elementCount * dstInfo.scale(), dstInfo.scale()); // last argument: swap unit = element width
          }
      }
  
      @ForceInline
      public static void copy(Object srcArray, int srcIndex,
                              MemorySegment dstSegment, ValueLayout dstLayout, long dstOffset,
                              int elementCount) {
  
-         long baseAndScale = getBaseAndScale(srcArray.getClass());
+         var srcInfo = Utils.BaseAndScale.of(srcArray); // array base offset + element width; NOTE(review): the removed code widened the width to long for the multiplications below; confirm scale() is long
          if (srcArray.getClass().componentType() != dstLayout.carrier()) {
              throw new IllegalArgumentException("Incompatible value layout: " + dstLayout);
          }
-         int srcBase = (int)baseAndScale;
-         long srcWidth = (int)(baseAndScale >> 32); // Use long arithmetics below
          Objects.checkFromIndexSize(srcIndex, elementCount, Array.getLength(srcArray));
          AbstractMemorySegmentImpl destImpl = (AbstractMemorySegmentImpl)dstSegment;
          Utils.checkElementAlignment(dstLayout, "Destination layout alignment greater than its size");
          if (!destImpl.isAlignedForElement(dstOffset, dstLayout)) {
              throw new IllegalArgumentException("Destination segment incompatible with alignment constraints");
          }
-         destImpl.checkAccess(dstOffset, elementCount * srcWidth, false);
-         if (srcWidth == 1 || dstLayout.order() == ByteOrder.nativeOrder()) {
+         destImpl.checkAccess(dstOffset, elementCount * srcInfo.scale(), false); // same byte length that is copied below
+         if (srcInfo.scale() == 1 || dstLayout.order() == ByteOrder.nativeOrder()) { // single-byte elements or native order: plain copy, no byte swapping
              ScopedMemoryAccess.getScopedMemoryAccess().copyMemory(null, destImpl.sessionImpl(),
-                     srcArray, srcBase + (srcIndex * srcWidth),
-                     destImpl.unsafeGetBase(), destImpl.unsafeGetOffset() + dstOffset, elementCount * srcWidth);
+                     srcArray, srcInfo.base() + (srcIndex * srcInfo.scale()),
+                     destImpl.unsafeGetBase(), destImpl.unsafeGetOffset() + dstOffset, elementCount * srcInfo.scale());
          } else {
              ScopedMemoryAccess.getScopedMemoryAccess().copySwapMemory(null, destImpl.sessionImpl(),
-                     srcArray, srcBase + (srcIndex * srcWidth),
-                     destImpl.unsafeGetBase(), destImpl.unsafeGetOffset() + dstOffset, elementCount * srcWidth, srcWidth);
+                     srcArray, srcInfo.base() + (srcIndex * srcInfo.scale()),
+                     destImpl.unsafeGetBase(), destImpl.unsafeGetOffset() + dstOffset, elementCount * srcInfo.scale(), srcInfo.scale()); // last argument: swap unit = element width
          }
      }
  
      public static long mismatch(MemorySegment srcSegment, long srcFromOffset, long srcToOffset,
                                  MemorySegment dstSegment, long dstFromOffset, long dstToOffset) {

@@ -720,25 +698,17 @@
              }
          }
          return srcBytes != dstBytes ? bytes : -1;
      }
  
-     private static long getBaseAndScale(Class<?> arrayType) {
-         if (arrayType.equals(byte[].class)) {
-             return (long) Unsafe.ARRAY_BYTE_BASE_OFFSET | ((long)Unsafe.ARRAY_BYTE_INDEX_SCALE << 32);
-         } else if (arrayType.equals(char[].class)) {
-             return (long) Unsafe.ARRAY_CHAR_BASE_OFFSET | ((long)Unsafe.ARRAY_CHAR_INDEX_SCALE << 32);
-         } else if (arrayType.equals(short[].class)) {
-             return (long)Unsafe.ARRAY_SHORT_BASE_OFFSET | ((long)Unsafe.ARRAY_SHORT_INDEX_SCALE << 32);
-         } else if (arrayType.equals(int[].class)) {
-             return (long)Unsafe.ARRAY_INT_BASE_OFFSET | ((long) Unsafe.ARRAY_INT_INDEX_SCALE << 32);
-         } else if (arrayType.equals(float[].class)) {
-             return (long)Unsafe.ARRAY_FLOAT_BASE_OFFSET | ((long)Unsafe.ARRAY_FLOAT_INDEX_SCALE << 32);
-         } else if (arrayType.equals(long[].class)) {
-             return (long)Unsafe.ARRAY_LONG_BASE_OFFSET | ((long)Unsafe.ARRAY_LONG_INDEX_SCALE << 32);
-         } else if (arrayType.equals(double[].class)) {
-             return (long)Unsafe.ARRAY_DOUBLE_BASE_OFFSET | ((long)Unsafe.ARRAY_DOUBLE_INDEX_SCALE << 32);
-         } else {
-             throw new IllegalArgumentException("Not a supported array class: " + arrayType.getSimpleName());
-         }
+     private static int getScaleFactor(Buffer buffer) { // left-shift that turns this buffer type's element positions/counts into byte offsets/lengths
+         return switch (buffer) { // no default: compiles only if these seven cases are exhaustive for Buffer (presumably a sealed hierarchy; confirm)
+             case ByteBuffer   __ -> 0;
+             case CharBuffer   __ -> 1;
+             case ShortBuffer  __ -> 1;
+             case IntBuffer    __ -> 2;
+             case FloatBuffer  __ -> 2;
+             case LongBuffer   __ -> 3;
+             case DoubleBuffer __ -> 3;
+         };
      }
  }
< prev index next >