1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package hat.buffer;
 26 
import hat.annotations.Kernel;
import hat.callgraph.KernelCallGraph;
import optkl.ifacemapper.AccessType;
import optkl.ifacemapper.BoundSchema;
import optkl.ifacemapper.Buffer;
import optkl.ifacemapper.Schema;
import optkl.ifacemapper.SchemaBuilder;
import optkl.util.carriers.ArenaAndLookupCarrier;

import java.lang.annotation.Annotation;
import java.lang.foreign.MemorySegment;
import java.nio.charset.StandardCharsets;

import static java.lang.foreign.ValueLayout.JAVA_BYTE;
import static java.lang.foreign.ValueLayout.JAVA_INT;
import static optkl.ifacemapper.MappableIface.getMemorySegment;
 42 
 43 public interface ArgArray extends Buffer {
 44     interface Arg extends Buffer.Struct {
 45         interface Value extends Buffer.Union {
 46             interface Buf extends Buffer.Struct {
 47                 MemorySegment address();
 48 
 49                 void address(MemorySegment address);
 50 
 51                 long bytes();
 52 
 53                 void bytes(long bytes);
 54 
 55                 byte access();
 56 
 57                 void access(byte access);
 58             }
 59 
 60             boolean z1();
 61 
 62             void z1(boolean z1);
 63 
 64             byte s8();
 65 
 66             void s8(byte s8);
 67 
 68             char u16();
 69 
 70             void u16(char u16);
 71 
 72             short s16();
 73 
 74             void s16(short s16);
 75 
 76             int u32();
 77 
 78             void u32(int u32);
 79 
 80             int s32();
 81 
 82             void s32(int s32);
 83 
 84             float f32();
 85 
 86             void f32(float f32);
 87 
 88             long u64();
 89 
 90             void u64(long u64);
 91 
 92             long s64();
 93 
 94             void s64(long s64);
 95 
 96             double f64();
 97 
 98             void f64(double f64);
 99 
100             Buf buf();
101         }
102 
103         int idx();
104 
105         void idx(int idx);
106 
107         byte variant();
108 
109         void variant(byte variant);
110 
111         Value value();
112 
113         default String asString() {
114             return switch (variant()) {
115                 case '&' -> Long.toHexString(u64());
116                 case 'F' -> Float.toString(f32());
117                 case 'I' -> Integer.toString(s32());
118                 case 'J' -> Long.toString(s64());
119                 case 'D' -> Double.toString(f64());
120                 case 'Z' -> Boolean.toString(z1());
121                 case 'B' -> Byte.toString(s8());
122                 case 'S' -> Short.toString(s16());
123                 case 'C' -> Character.toString(u16());
124                 default -> throw new IllegalStateException("Variant " + variant() + " not supported for arg");
125             };
126         }
127 
128         default boolean z1() {
129             return value().z1();
130         }
131 
132         default void z1(boolean z1) {
133             variant((byte) 'Z');
134             value().z1(z1);
135         }
136 
137         default byte s8() {
138             return value().s8();
139         }
140 
141         default void s8(byte s8) {
142             variant((byte) 'B');
143             value().s8(s8);
144         }
145 
146         default char u16() {
147             return value().u16();
148         }
149 
150         default void u16(char u16) {
151             variant((byte) 'C');
152             value().u16(u16);
153         }
154 
155         default short s16() {
156             return value().s16();
157         }
158 
159         default void s16(short s16) {
160             variant((byte) 'S');
161             value().s16(s16);
162         }
163 
164         default int s32() {
165             return value().s32();
166         }
167 
168         default void s32(int s32) {
169             variant((byte) 'I');
170             value().s32(s32);
171         }
172 
173         default float f32() {
174             return value().f32();
175         }
176 
177         default void f32(float f32) {
178             variant((byte) 'F');
179             value().f32(f32);
180         }
181 
182         default long s64() {
183             return value().s64();
184         }
185 
186         default void s64(long s64) {
187             variant((byte) 'J');
188             value().s64(s64);
189         }
190 
191         default long u64() {
192             return value().u64();
193         }
194 
195         default void u64(long u64) {
196             variant((byte) '&');
197             value().u64(u64);
198         }
199 
200         default double f64() {
201             return value().f64();
202         }
203 
204         default void f64(double f64) {
205             variant((byte) 'D');
206             value().f64(f64);
207         }
208     }
209 
210     int argc();
211 
212 
213     Arg arg(long idx);
214 
215     int schemaLen();
216 
217     byte schemaBytes(long idx);
218 
219     void schemaBytes(long idx, byte b);
220 
221 
222     Schema<ArgArray> schema = Schema.of(ArgArray.class, s -> s
223             .arrayLen("argc")
224             .pad((int) (16 - JAVA_INT.byteSize()))
225             .array("arg", arg -> arg
226                     .fields("idx", "variant")
227                     .pad((int) (16 - JAVA_INT.byteSize() - JAVA_BYTE.byteSize()))
228                     .field("value", val -> val
229                             .fields("z1", "s8", "u16", "s16", "s32", "u32", "f32", "s64", "u64", "f64")
230                             .field("buf", buf -> buf
231                                     .fields("address", "bytes", "access")
232                                     .pad((int) (16 - JAVA_BYTE.byteSize()))
233                             )
234                     )
235             )
236             .arrayLen("schemaLen")
237             .array("schemaBytes")
238     );
239 
240 
241     static ArgArray create(ArenaAndLookupCarrier cc, KernelCallGraph kernelCallGraph, Object... args) {
242         String[] schemas = new String[args.length];
243         StringBuilder argSchema = new StringBuilder();
244         argSchema.append(args.length);
245         for (int i = 0; i < args.length; i++) {
246             Object argObject = args[i];
247             schemas[i] = switch (argObject) {
248                 case Boolean _ -> "(?:z1)";
249                 case Byte _ -> "(?:s8)";
250                 case Short _ -> "(?:s16)";
251                 case Character _ -> "(?:u16)";
252                 case Float _ -> "(?:f32)";
253                 case Integer _ -> "(?:s32)";
254                 case Long _ -> "(?:s64)";
255                 case Double _ -> "(?:f64)";
256                 case Buffer buffer -> "(?:" + SchemaBuilder.schema(buffer) + ")";
257                 default ->
258                         throw new IllegalStateException("Unexpected value: " + argObject + " Did you pass an interface which is neither a Complete or Incomplete buffer");
259             };
260             if (i > 0) {
261                 argSchema.append(",");
262             }
263             argSchema.append(schemas[i]);
264         }
265         String schemaStr = argSchema.toString();
266         ArgArray argArray = BoundSchema.of(cc, schema, args.length, schemaStr.length() + 1).allocate();
267         byte[] schemaStrBytes = schemaStr.getBytes();
268         for (int i = 0; i < schemaStrBytes.length; i++) {
269             argArray.schemaBytes(i, schemaStrBytes[i]);
270         }
271         argArray.schemaBytes(schemaStrBytes.length, (byte) 0);
272         update(argArray, kernelCallGraph, args);
273         return argArray;
274     }
275 
276     static void update(ArgArray argArray, KernelCallGraph kernelCallGraph, Object... args) {
277         Annotation[][] parameterAnnotations = kernelCallGraph.callDag.entryPoint.method().getParameterAnnotations();
278         var bufferAccessList = kernelCallGraph.bufferAccessList;
279         for (int i = 0; i < args.length; i++) {
280             Object argObject = args[i];
281             Arg arg = argArray.arg(i); // this should be invariant, but if we are called from create it will be 0 for all
282             arg.idx(i);
283             switch (argObject) {
284                 case Boolean z1 -> arg.z1(z1);
285                 case Byte s8 -> arg.s8(s8);
286                 case Short s16 -> arg.s16(s16);
287                 case Character u16 -> arg.u16(u16);
288                 case Float f32 -> arg.f32(f32);
289                 case Integer s32 -> arg.s32(s32);
290                 case Long s64 -> arg.s64(s64);
291                 case Double f64 -> arg.f64(f64);
292                 case Buffer buffer -> {
293                     Annotation[] annotations = parameterAnnotations[i];
294                     AccessType accessType = AccessType.NA;
295                     for (Annotation annotation : annotations) {
296                         accessType = AccessType.of(annotation);
297                     }
298                     MemorySegment segment = getMemorySegment(buffer);
299                     arg.variant((byte) '&');
300                     Arg.Value value = arg.value();
301                     Arg.Value.Buf buf = value.buf();
302                     buf.address(segment);
303                     buf.bytes(segment.byteSize());
304 
305                     if (kernelCallGraph.callDag.entryPoint.method().getAnnotation(Kernel.class) != null) {
306                         // If the annotation is present, then we keep the accessor defined for each parameter
307                         buf.access(accessType.value);
308                     } else {
309                         // otherwise, we rely on the buffer-tagger to set the accessor
310                         buf.access(bufferAccessList.get(i).value);
311                     }
312                 }
313                 default -> throw new IllegalStateException("Unexpected value: " + argObject);
314             }
315         }
316     }
317 
318     default String getSchemaBytes() {
319         byte[] bytes = new byte[schemaLen() + 1];
320         for (int i = 0; i < schemaLen(); i++) {
321             bytes[i] = schemaBytes(i);
322         }
323         bytes[bytes.length - 1] = '0';
324         return new String(bytes);
325     }
326 
327     default String dump() {
328         StringBuilder dump = new StringBuilder();
329         dump.append("SchemaBytes:").append(getSchemaBytes()).append("\n");
330         for (int argIndex = 0; argIndex < argc(); argIndex++) {
331             Arg arg = arg(argIndex);
332             dump.append(arg.asString()).append("\n");
333         }
334         return dump.toString();
335     }
336 }