1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package hat.buffer;
 26 
import hat.Accelerator;
import hat.BufferTagger;
import hat.callgraph.KernelCallGraph;
import hat.ifacemapper.Schema;

import java.lang.annotation.Annotation;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.nio.ByteOrder;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Locale;

import static hat.buffer.ArgArray.Arg.Value.Buf.UNKNOWN_BYTE;
import static java.lang.foreign.ValueLayout.JAVA_BYTE;
 40 
 41 public interface ArgArray extends Buffer {
 42     interface Arg extends Buffer.Struct{
 43         interface Value extends Buffer.Union{
 44             interface Buf extends Buffer.Struct{
 45                  byte UNKNOWN_BYTE=(byte)0;
 46                  byte RO_BYTE =(byte)1<<1;
 47                  byte WO_BYTE =(byte)1<<2;
 48                  byte RW_BYTE =RO_BYTE|WO_BYTE;
 49                 MemorySegment address();
 50                 void address(MemorySegment address);
 51                 long bytes();
 52                 void bytes(long bytes);
 53                 byte access();
 54                 void access(byte access);
 55             }
 56             boolean z1();
 57 
 58             void z1(boolean z1);
 59 
 60             byte s8();
 61 
 62             void s8(byte s8);
 63 
 64             char u16();
 65 
 66             void u16(char u16);
 67 
 68             short s16();
 69 
 70             void s16(short s16);
 71 
 72             int u32();
 73 
 74             void u32(int u32);
 75 
 76             int s32();
 77 
 78             void s32(int s32);
 79 
 80             float f32();
 81 
 82             void f32(float f32);
 83 
 84             long u64();
 85 
 86             void u64(long u64);
 87 
 88             long s64();
 89 
 90             void s64(long s64);
 91 
 92             double f64();
 93 
 94             void f64(double f64);
 95 
 96             Buf buf();
 97         }
 98 
 99         int idx();
100         void idx(int idx);
101 
102         byte variant();
103         void variant(byte variant);
104 
105         Value value();
106 
107         default String asString() {
108             switch (variant()) {
109                 case '&':
110                     return Long.toHexString(u64());
111                 case 'F':
112                     return Float.toString(f32());
113                 case 'I':
114                     return Integer.toString(s32());
115                 case 'J':
116                     return Long.toString(s64());
117                 case 'D':
118                     return Double.toString(f64());
119                 case 'Z':
120                     return Boolean.toString(z1());
121                 case 'B':
122                     return Byte.toString(s8());
123                 case 'S':
124                     return Short.toString(s16());
125                 case 'C':
126                     return Character.toString(u16());
127             }
128             throw new IllegalStateException("what is this");
129         }
130 
131         default boolean z1() {
132             return value().z1();
133         }
134 
135         default void z1(boolean z1) {
136             variant((byte) 'Z');
137             value().z1(z1);
138         }
139 
140         default byte s8() {
141             return value().s8();
142         }
143 
144         default void s8(byte s8) {
145             variant((byte) 'B');
146             value().s8(s8);
147         }
148 
149         default char u16() {
150             return value().u16();
151         }
152 
153         default void u16(char u16) {
154             variant((byte) 'C');
155             value().u16(u16);
156         }
157 
158         default short s16() {
159             return value().s16();
160         }
161 
162         default void s16(short s16) {
163             variant((byte) 'S');
164             value().s16(s16);
165         }
166 
167         default int s32() {
168             return value().s32();
169         }
170 
171         default void s32(int s32) {
172             variant((byte) 'I');
173             value().s32(s32);
174         }
175 
176         default float f32() {
177             return value().f32();
178         }
179 
180         default void f32(float f32) {
181             variant((byte) 'F');
182             value().f32(f32);
183         }
184 
185         default long s64() {
186             return value().s64();
187         }
188 
189         default void s64(long s64) {
190             variant((byte) 'J');
191             value().s64(s64);
192         }
193 
194         default long u64() {
195             return value().u64();
196         }
197 
198         default void u64(long u64) {
199             variant((byte) '&');
200             value().u64(u64);
201         }
202 
203         default double f64() {
204             return value().f64();
205         }
206 
207         default void f64(double f64) {
208             variant((byte) 'D');
209             value().f64(f64);
210         }
211     }
212 
213     int argc();
214 
215 
216     Arg arg(long idx);
217 
218     int schemaLen();
219 
220     byte schemaBytes(long idx);
221     void schemaBytes(long idx, byte b);
222 
223     Schema<ArgArray> schema = Schema.of(ArgArray.class, s->s
224             .arrayLen("argc").pad(12).array("arg", arg->arg
225                             .fields("idx", "variant")
226                             .pad(11/*(int)(16 - JAVA_INT.byteSize() - JAVA_BYTE.byteSize())*/)
227                             .field("value", val->val
228                                             .fields("z1","s8","u16","s16","s32","u32","f32","s64","u64","f64")
229                                                     .field("buf", buf->buf
230                                                             .fields("address","bytes",/*"vendorPtr",*/"access")
231                                                             .pad((int)(16 - JAVA_BYTE.byteSize()/* - JAVA_BYTE.byteSize()*/))
232                                                     )
233                             )
234                     )
235          //   .field("vendorPtr")
236             .arrayLen("schemaLen").array("schemaBytes")
237     );
238 
239     static String valueLayoutToSchemaString(ValueLayout valueLayout) {
240         String descriptor = valueLayout.carrier().descriptorString();
241         String schema = switch (descriptor) {
242             case "Z" -> "Z";
243             case "B" -> "S";
244             case "C" -> "U";
245             case "S" -> "S";
246             case "I" -> "S";
247             case "F" -> "F";
248             case "D" -> "F";
249             case "J" -> "S";
250             default -> throw new IllegalStateException("Unexpected value: " + descriptor);
251         } + valueLayout.byteSize() * 8;
252         return (valueLayout.order().equals(ByteOrder.LITTLE_ENDIAN)) ? schema.toLowerCase() : schema;
253     }
254 
255     static ArgArray create(Accelerator accelerator, KernelCallGraph kernelCallGraph, Object... args) {
256         String[] schemas = new String[args.length];
257         StringBuilder argSchema = new StringBuilder();
258         argSchema.append(args.length);
259         for (int i = 0; i < args.length; i++) {
260             Object argObject = args[i];
261             schemas[i] = switch (argObject) {
262                 case Boolean z1 -> "(?:z1)";
263                 case Byte s8 -> "(?:s8)";
264                 case Short s16 -> "(?:s16)";
265                 case Character u16 -> "(?:u16)";
266                 case Float f32 -> "(?:f32)";
267                 case Integer s32 -> "(?:s32)";
268                 case Long s64 -> "(?:s64)";
269                 case Double f64 -> "(?:f64)";
270                 case Buffer buffer -> "(?:" +SchemaBuilder.schema(buffer)+")";
271                 default -> throw new IllegalStateException("Unexpected value: " + argObject + " Did you pass an interface which is neither a Complete or Incomplete buffer");
272             };
273             if (i > 0) {
274                 argSchema.append(",");
275             }
276             argSchema.append(schemas[i]);
277         }
278         String schemaStr = argSchema.toString();
279         ArgArray argArray = schema.allocate(accelerator,args.length,schemaStr.length() + 1);
280         byte[] schemaStrBytes = schemaStr.getBytes();
281         for (int i = 0; i < schemaStrBytes.length; i++) {
282             argArray.schemaBytes(i, schemaStrBytes[i]);
283         }
284         argArray.schemaBytes(schemaStrBytes.length, (byte) 0);
285         update(argArray,kernelCallGraph,args);
286         return argArray;
287     }
288 
289     static void update(ArgArray argArray, KernelCallGraph kernelCallGraph, Object... args) {
290         Annotation[][] parameterAnnotations = kernelCallGraph.entrypoint.getMethod().getParameterAnnotations();
291         List<BufferTagger.AccessType> bufferAccessList = kernelCallGraph.bufferAccessList;
292         for (int i = 0; i < args.length; i++) {
293             Object argObject = args[i];
294             Arg arg = argArray.arg(i); // this should be invariant, but if we are called from create it will be 0 for all
295             arg.idx(i);
296             switch (argObject) {
297                 case Boolean z1 -> arg.z1(z1);
298                 case Byte s8 -> arg.s8(s8);
299                 case Short s16 -> arg.s16(s16);
300                 case Character u16 -> arg.u16(u16);
301                 case Float f32 -> arg.f32(f32);
302                 case Integer s32 -> arg.s32(s32);
303                 case Long s64 -> arg.s64(s64);
304                 case Double f64 -> arg.f64(f64);
305                 case Buffer buffer -> {
306                     Annotation[] annotations = parameterAnnotations[i];
307                     byte accessByte = UNKNOWN_BYTE;
308                     if (annotations.length > 0) {
309                         for (Annotation annotation : annotations) {
310                             accessByte = switch (annotation) {
311                                 case RO ro-> Arg.Value.Buf.RO_BYTE;
312                                 case RW rw -> Arg.Value.Buf.RW_BYTE;
313                                 case WO wo -> Arg.Value.Buf.WO_BYTE;
314                                 default -> throw new IllegalStateException("Unexpected value: " + annotation);
315                             };
316                         }
317                     }else{
318                         throw new IllegalArgumentException("Argument " + i + " has no access annotations");
319                     }
320                     MemorySegment segment = Buffer.getMemorySegment(buffer);
321                     arg.variant((byte) '&');
322                     Arg.Value value = arg.value();
323                     Arg.Value.Buf buf = value.buf();
324                     buf.address(segment);
325                     buf.bytes(segment.byteSize());
326                     buf.access(accessByte);
327 
328                     assert bufferAccessList.get(i).value == accessByte: "buffer tagging mismatch: "
329                             + kernelCallGraph.entrypoint.getMethod().getParameters()[i].toString()
330                             + " in " + kernelCallGraph.entrypoint.getMethod().getName()
331                             + " annotated as " + BufferTagger.convertAccessType(accessByte)
332                             + " but tagged as " + bufferAccessList.get(i).name();
333                 }
334                 default -> throw new IllegalStateException("Unexpected value: " + argObject);
335             }
336         }
337     }
338 
339 
340     default String getSchemaBytes() {
341         byte[] bytes = new byte[schemaLen() + 1];
342         for (int i = 0; i < schemaLen(); i++) {
343             bytes[i] = schemaBytes(i);
344         }
345         bytes[bytes.length - 1] = '0';
346         return new String(bytes);
347     }
348 
349 
350     default String dump() {
351         StringBuilder dump = new StringBuilder();
352         dump.append("SchemaBytes:").append(getSchemaBytes()).append("\n");
353         for (int argIndex = 0; argIndex < argc(); argIndex++) {
354             Arg arg = arg(argIndex);
355             dump.append(arg.asString()).append("\n");
356         }
357         return dump.toString();
358     }
359 
360 }