1 /*
  2  * Copyright (c) 2020, 2025, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 
 26 package jdk.internal.foreign.abi.x64.sysv;
 27 
 28 import jdk.internal.foreign.Utils;
 29 
 30 import java.lang.foreign.GroupLayout;
 31 import java.lang.foreign.MemoryLayout;
 32 import java.lang.foreign.MemorySegment;
 33 import java.lang.foreign.PaddingLayout;
 34 import java.lang.foreign.SequenceLayout;
 35 import java.lang.foreign.StructLayout;
 36 import java.lang.foreign.ValueLayout;
 37 import java.util.ArrayList;
 38 import java.util.List;
 39 import java.util.stream.Collectors;
 40 import java.util.stream.IntStream;
 41 
 42 class TypeClass {
 43     enum Kind {
 44         STRUCT,
 45         POINTER,
 46         INTEGER,
 47         FLOAT
 48     }
 49 
 50     private final Kind kind;
 51     final List<ArgumentClassImpl> classes;
 52 
 53     private TypeClass(Kind kind, List<ArgumentClassImpl> classes) {
 54         this.kind = kind;
 55         this.classes = classes;
 56     }
 57 
 58     public static TypeClass ofValue(ValueLayout layout) {
 59         final Kind kind;
 60         ArgumentClassImpl argClass = argumentClassFor(layout);
 61         kind = switch (argClass) {
 62             case POINTER -> Kind.POINTER;
 63             case INTEGER -> Kind.INTEGER;
 64             case SSE -> Kind.FLOAT;
 65             default -> throw new IllegalStateException("Unexpected argument class: " + argClass);
 66         };
 67         return new TypeClass(kind, List.of(argClass));
 68     }
 69 
 70     public static TypeClass ofStruct(GroupLayout layout) {
 71         return new TypeClass(Kind.STRUCT, classifyStructType(layout));
 72     }
 73 
 74     boolean inMemory() {
 75         return classes.stream().anyMatch(c -> c == ArgumentClassImpl.MEMORY);
 76     }
 77 
 78     private long numClasses(ArgumentClassImpl clazz) {
 79         return classes.stream().filter(c -> c == clazz).count();
 80     }
 81 
 82     public long nIntegerRegs() {
 83         return numClasses(ArgumentClassImpl.INTEGER) + numClasses(ArgumentClassImpl.POINTER);
 84     }
 85 
 86     public long nVectorRegs() {
 87         return numClasses(ArgumentClassImpl.SSE);
 88     }
 89 
 90     public Kind kind() {
 91         return kind;
 92     }
 93 
 94     // layout classification
 95 
 96     // The AVX 512 enlightened ABI says "eight eightbytes"
 97     // Although AMD64 0.99.6 states 4 eightbytes
 98     private static final int MAX_AGGREGATE_REGS_SIZE = 8;
 99     static final List<ArgumentClassImpl> COMPLEX_X87_CLASSES = List.of(
100          ArgumentClassImpl.X87,
101          ArgumentClassImpl.X87UP,
102          ArgumentClassImpl.X87,
103          ArgumentClassImpl.X87UP
104     );
105 
106     private static List<ArgumentClassImpl> createMemoryClassArray(long size) {
107         return IntStream.range(0, (int)size)
108                 .mapToObj(i -> ArgumentClassImpl.MEMORY)
109                 .collect(Collectors.toCollection(ArrayList::new));
110     }
111 
112     private static ArgumentClassImpl argumentClassFor(ValueLayout layout) {
113         Class<?> carrier = layout.carrier();
114         if (carrier == boolean.class || carrier == byte.class || carrier == char.class ||
115                 carrier == short.class || carrier == int.class || carrier == long.class) {
116             return ArgumentClassImpl.INTEGER;
117         } else if (carrier == float.class || carrier == double.class) {
118             return ArgumentClassImpl.SSE;
119         } else if (carrier == MemorySegment.class) {
120             return ArgumentClassImpl.POINTER;
121         } else {
122             throw new IllegalStateException("Cannot get here: " + carrier.getName());
123         }
124     }
125 
126     // TODO: handle zero length arrays
127     private static List<ArgumentClassImpl> classifyStructType(GroupLayout type) {
128         List<ArgumentClassImpl>[] eightbytes = groupByEightBytes(type);
129         long nWords = eightbytes.length;
130         if (nWords > MAX_AGGREGATE_REGS_SIZE) {
131             return createMemoryClassArray(nWords);
132         }
133 
134         ArrayList<ArgumentClassImpl> classes = new ArrayList<>();
135 
136         for (int idx = 0; idx < nWords; idx++) {
137             List<ArgumentClassImpl> subclasses = eightbytes[idx];
138             ArgumentClassImpl result = subclasses.stream()
139                     .reduce(ArgumentClassImpl.NO_CLASS, ArgumentClassImpl::merge);
140             classes.add(result);
141         }
142 
143         for (int i = 0; i < classes.size(); i++) {
144             ArgumentClassImpl c = classes.get(i);
145 
146             if (c == ArgumentClassImpl.MEMORY) {
147                 // if any of the eightbytes was passed in memory, pass the whole thing in memory
148                 return createMemoryClassArray(classes.size());
149             }
150 
151             if (c == ArgumentClassImpl.X87UP) {
152                 if (i == 0) {
153                     throw new IllegalArgumentException("Unexpected leading X87UP class");
154                 }
155 
156                 if (classes.get(i - 1) != ArgumentClassImpl.X87) {
157                     return createMemoryClassArray(classes.size());
158                 }
159             }
160         }
161 
162         if (classes.size() > 2) {
163             if (classes.get(0) != ArgumentClassImpl.SSE) {
164                 return createMemoryClassArray(classes.size());
165             }
166 
167             for (int i = 1; i < classes.size(); i++) {
168                 if (classes.get(i) != ArgumentClassImpl.SSEUP) {
169                     return createMemoryClassArray(classes.size());
170                 }
171             }
172         }
173 
174         return classes;
175     }
176 
177     static TypeClass classifyLayout(MemoryLayout type) {
178         try {
179             if (type instanceof ValueLayout valueLayout) {
180                 return ofValue(valueLayout);
181             } else if (type instanceof GroupLayout groupLayout) {
182                 return ofStruct(groupLayout);
183             } else {
184                 throw new IllegalArgumentException("Unsupported layout: " + type);
185             }
186         } catch (UnsupportedOperationException e) {
187             System.err.println("Failed to classify layout: " + type);
188             throw e;
189         }
190     }
191 
192     private static List<ArgumentClassImpl>[] groupByEightBytes(GroupLayout group) {
193         long offset = 0L;
194         int nEightbytes;
195         try {
196             // alignUp can overflow the value, but it's okay since toIntExact still catches it
197             nEightbytes = Math.toIntExact(Utils.alignUp(group.byteSize(), 8) / 8);
198         } catch (ArithmeticException e) {
199             throw new IllegalArgumentException("GroupLayout is too large: " + group, e);
200         }
201         @SuppressWarnings({"unchecked", "rawtypes"})
202         List<ArgumentClassImpl>[] groups = new List[nEightbytes];
203         for (MemoryLayout l : group.memberLayouts()) {
204             groupByEightBytes(l, offset, groups);
205             if (group instanceof StructLayout) {
206                 offset += l.byteSize();
207             }
208         }
209         return groups;
210     }
211 
212     private static void groupByEightBytes(MemoryLayout layout,
213                                           long offset,
214                                           List<ArgumentClassImpl>[] groups) {
215         switch (layout) {
216             case GroupLayout group -> {
217                 for (MemoryLayout m : group.memberLayouts()) {
218                     groupByEightBytes(m, offset, groups);
219                     if (group instanceof StructLayout) {
220                         offset += m.byteSize();
221                     }
222                 }
223             }
224             case PaddingLayout  _ -> {
225             }
226             case SequenceLayout seq -> {
227                 MemoryLayout elem = seq.elementLayout();
228                 for (long i = 0; i < seq.elementCount(); i++) {
229                     groupByEightBytes(elem, offset, groups);
230                     offset += elem.byteSize();
231                 }
232             }
233             case ValueLayout vl -> {
234                 List<ArgumentClassImpl> layouts = groups[(int) offset / 8];
235                 if (layouts == null) {
236                     layouts = new ArrayList<>();
237                     groups[(int) offset / 8] = layouts;
238                 }
239                 // if the aggregate contains unaligned fields, it has class MEMORY
240                 ArgumentClassImpl argumentClass = (offset % vl.byteAlignment()) == 0 ?
241                         argumentClassFor(vl) :
242                         ArgumentClassImpl.MEMORY;
243                 layouts.add(argumentClass);
244             }
245             case null, default -> throw new IllegalStateException("Unexpected layout: " + layout);
246         }
247     }
248 }