1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package hat.buffer;
 26 
import hat.Accelerator;
import hat.ComputeContext;
import hat.callgraph.KernelCallGraph;
import hat.ifacemapper.Schema;

import java.lang.annotation.Annotation;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.lang.invoke.MethodHandles;
import java.nio.ByteOrder;
import java.nio.charset.StandardCharsets;

import static hat.buffer.ArgArray.Arg.Value.Buf.UNKNOWN_BYTE;
import static java.lang.foreign.ValueLayout.JAVA_BYTE;
import static java.lang.foreign.ValueLayout.JAVA_INT;
 41 
 42 public interface ArgArray extends Buffer {
 43     interface Arg extends Buffer.Struct{
 44         interface Value extends Buffer.Union{
 45             interface Buf extends Buffer.Struct{
 46                  byte UNKNOWN_BYTE=(byte)0;
 47                  byte RO_BYTE =(byte)1<<1;
 48                  byte WO_BYTE =(byte)1<<2;
 49                  byte RW_BYTE =RO_BYTE|WO_BYTE;
 50                 MemorySegment address();
 51                 void address(MemorySegment address);
 52                 long bytes();
 53                 void bytes(long bytes);
 54                 byte access();
 55                 void access(byte access);
 56             }
 57             boolean z1();
 58 
 59             void z1(boolean z1);
 60 
 61             byte s8();
 62 
 63             void s8(byte s8);
 64 
 65             char u16();
 66 
 67             void u16(char u16);
 68 
 69             short s16();
 70 
 71             void s16(short s16);
 72 
 73             int u32();
 74 
 75             void u32(int u32);
 76 
 77             int s32();
 78 
 79             void s32(int s32);
 80 
 81             float f32();
 82 
 83             void f32(float f32);
 84 
 85             long u64();
 86 
 87             void u64(long u64);
 88 
 89             long s64();
 90 
 91             void s64(long s64);
 92 
 93             double f64();
 94 
 95             void f64(double f64);
 96 
 97             Buf buf();
 98         }
 99 
100         int idx();
101         void idx(int idx);
102 
103         byte variant();
104         void variant(byte variant);
105 
106         Value value();
107 
108         default String asString() {
109             switch (variant()) {
110                 case '&':
111                     return Long.toHexString(u64());
112                 case 'F':
113                     return Float.toString(f32());
114                 case 'I':
115                     return Integer.toString(s32());
116                 case 'J':
117                     return Long.toString(s64());
118                 case 'D':
119                     return Double.toString(f64());
120                 case 'Z':
121                     return Boolean.toString(z1());
122                 case 'B':
123                     return Byte.toString(s8());
124                 case 'S':
125                     return Short.toString(s16());
126                 case 'C':
127                     return Character.toString(u16());
128             }
129             throw new IllegalStateException("what is this");
130         }
131 
132         default boolean z1() {
133             return value().z1();
134         }
135 
136         default void z1(boolean z1) {
137             variant((byte) 'Z');
138             value().z1(z1);
139         }
140 
141         default byte s8() {
142             return value().s8();
143         }
144 
145         default void s8(byte s8) {
146             variant((byte) 'B');
147             value().s8(s8);
148         }
149 
150         default char u16() {
151             return value().u16();
152         }
153 
154         default void u16(char u16) {
155             variant((byte) 'C');
156             value().u16(u16);
157         }
158 
159         default short s16() {
160             return value().s16();
161         }
162 
163         default void s16(short s16) {
164             variant((byte) 'S');
165             value().s16(s16);
166         }
167 
168         default int s32() {
169             return value().s32();
170         }
171 
172         default void s32(int s32) {
173             variant((byte) 'I');
174             value().s32(s32);
175         }
176 
177         default float f32() {
178             return value().f32();
179         }
180 
181         default void f32(float f32) {
182             variant((byte) 'F');
183             value().f32(f32);
184         }
185 
186         default long s64() {
187             return value().s64();
188         }
189 
190         default void s64(long s64) {
191             variant((byte) 'J');
192             value().s64(s64);
193         }
194 
195         default long u64() {
196             return value().u64();
197         }
198 
199         default void u64(long u64) {
200             variant((byte) '&');
201             value().u64(u64);
202         }
203 
204         default double f64() {
205             return value().f64();
206         }
207 
208         default void f64(double f64) {
209             variant((byte) 'D');
210             value().f64(f64);
211         }
212     }
213 
214     int argc();
215 
216 
217     Arg arg(long idx);
218 
219     int schemaLen();
220 
221     byte schemaBytes(long idx);
222     void schemaBytes(long idx, byte b);
223 
224     Schema<ArgArray> schema = Schema.of(ArgArray.class, s->s
225             .arrayLen("argc").pad(12).array("arg", arg->arg
226                             .fields("idx", "variant")
227                             .pad(11/*(int)(16 - JAVA_INT.byteSize() - JAVA_BYTE.byteSize())*/)
228                             .field("value", value->value
229                                             .fields("z1","s8","u16","s16","s32","u32","f32","s64","u64","f64")
230                                                     .field("buf", buf->buf
231                                                             .fields("address","bytes",/*"vendorPtr",*/"access")
232                                                             .pad((int)(16 - JAVA_BYTE.byteSize()/* - JAVA_BYTE.byteSize()*/))
233                                                     )
234                             )
235                     )
236          //   .field("vendorPtr")
237             .arrayLen("schemaLen").array("schemaBytes")
238     );
239 
240     static String valueLayoutToSchemaString(ValueLayout valueLayout) {
241         String descriptor = valueLayout.carrier().descriptorString();
242         String schema = switch (descriptor) {
243             case "Z" -> "Z";
244             case "B" -> "S";
245             case "C" -> "U";
246             case "S" -> "S";
247             case "I" -> "S";
248             case "F" -> "F";
249             case "D" -> "F";
250             case "J" -> "S";
251             default -> throw new IllegalStateException("Unexpected value: " + descriptor);
252         } + valueLayout.byteSize() * 8;
253         return (valueLayout.order().equals(ByteOrder.LITTLE_ENDIAN)) ? schema.toLowerCase() : schema;
254     }
255 
256     static ArgArray create(Accelerator accelerator, KernelCallGraph kernelCallGraph, Object... args) {
257         String[] schemas = new String[args.length];
258         StringBuilder argSchema = new StringBuilder();
259         argSchema.append(args.length);
260         for (int i = 0; i < args.length; i++) {
261             Object argObject = args[i];
262             schemas[i] = switch (argObject) {
263                 case Boolean z1 -> "(?:z1)";
264                 case Byte s8 -> "(?:s8)";
265                 case Short s16 -> "(?:s16)";
266                 case Character u16 -> "(?:u16)";
267                 case Float f32 -> "(?:f32)";
268                 case Integer s32 -> "(?:s32)";
269                 case Long s64 -> "(?:s64)";
270                 case Double f64 -> "(?:f64)";
271                 case Buffer buffer -> "(?:" +SchemaBuilder.schema(buffer)+")";
272                 default -> throw new IllegalStateException("Unexpected value: " + argObject + " Did you pass an interface which is neither a Complete or Incomplete buffer");
273             };
274             if (i > 0) {
275                 argSchema.append(",");
276             }
277             argSchema.append(schemas[i]);
278         }
279         String schemaStr = argSchema.toString();
280         ArgArray argArray = schema.allocate(accelerator,args.length,schemaStr.length() + 1);
281         byte[] schemaStrBytes = schemaStr.getBytes();
282         for (int i = 0; i < schemaStrBytes.length; i++) {
283             argArray.schemaBytes(i, schemaStrBytes[i]);
284         }
285         argArray.schemaBytes(schemaStrBytes.length, (byte) 0);
286         update(argArray,kernelCallGraph,args);
287         return argArray;
288     }
289 
290     static void update(ArgArray argArray, KernelCallGraph kernelCallGraph, Object... args) {
291         Annotation[][] parameterAnnotations = kernelCallGraph.entrypoint.getMethod().getParameterAnnotations();
292         for (int i = 0; i < args.length; i++) {
293             Object argObject = args[i];
294             Arg arg = argArray.arg(i); // this should be invariant, but if we are called from create it will be 0 for all
295             arg.idx(i);
296             switch (argObject) {
297                 case Boolean z1 -> arg.z1(z1);
298                 case Byte s8 -> arg.s8(s8);
299                 case Short s16 -> arg.s16(s16);
300                 case Character u16 -> arg.u16(u16);
301                 case Float f32 -> arg.f32(f32);
302                 case Integer s32 -> arg.s32(s32);
303                 case Long s64 -> arg.s64(s64);
304                 case Double f64 -> arg.f64(f64);
305                 case Buffer buffer -> {
306                     Annotation[] annotations = parameterAnnotations[i];
307                     byte accessByte = UNKNOWN_BYTE;
308                     if (annotations.length > 0) {
309                         for (Annotation annotation : annotations) {
310                             accessByte = switch (annotation) {
311                                 case RO ro-> Arg.Value.Buf.RO_BYTE;
312                                 case RW rw -> Arg.Value.Buf.RW_BYTE;
313                                 case WO wo -> Arg.Value.Buf.WO_BYTE;
314                                 default -> throw new IllegalStateException("Unexpected value: " + annotation);
315                             };
316                         }
317                     }else{
318                         throw new IllegalArgumentException("Argument " + i + " has no access annotations");
319                     }
320                     MemorySegment segment = Buffer.getMemorySegment(buffer);
321                     arg.variant((byte) '&');
322                     Arg.Value value = arg.value();
323                     Arg.Value.Buf buf = value.buf();
324                     buf.address(segment);
325                     buf.bytes(segment.byteSize());
326                     buf.access(accessByte);
327                 }
328                 default -> throw new IllegalStateException("Unexpected value: " + argObject);
329             }
330         }
331     }
332 
333 
334     default String getSchemaBytes() {
335         byte[] bytes = new byte[schemaLen() + 1];
336         for (int i = 0; i < schemaLen(); i++) {
337             bytes[i] = schemaBytes(i);
338         }
339         bytes[bytes.length - 1] = '0';
340         return new String(bytes);
341     }
342 
343 
344     default String dump() {
345         StringBuilder dump = new StringBuilder();
346         dump.append("SchemaBytes:").append(getSchemaBytes()).append("\n");
347         for (int argIndex = 0; argIndex < argc(); argIndex++) {
348             Arg arg = arg(argIndex);
349             dump.append(arg.asString()).append("\n");
350         }
351         return dump.toString();
352     }
353 
354 }