1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package hat.buffer;
 26 
 27 import optkl.ifacemapper.AccessType;
 28 import hat.callgraph.KernelCallGraph;
 29 import optkl.ifacemapper.BoundSchema;
 30 import optkl.util.carriers.ArenaAndLookupCarrier;
 31 import optkl.ifacemapper.Buffer;
 32 import optkl.ifacemapper.MappableIface;
 33 import optkl.ifacemapper.Schema;
 34 import optkl.ifacemapper.SchemaBuilder;
 35 
 36 import java.lang.annotation.Annotation;
 37 import java.lang.foreign.MemorySegment;
 38 import static java.lang.foreign.ValueLayout.JAVA_BYTE;
 39 import static java.lang.foreign.ValueLayout.JAVA_INT;
 40 
 41 public interface ArgArray extends Buffer {
 42     interface Arg extends Buffer.Struct{
 43         interface Value extends Buffer.Union{
 44             interface Buf extends Buffer.Struct{
 45                 MemorySegment address();
 46                 void address(MemorySegment address);
 47                 long bytes();
 48                 void bytes(long bytes);
 49                 byte access();
 50                 void access(byte access);
 51             }
 52             boolean z1();
 53 
 54             void z1(boolean z1);
 55 
 56             byte s8();
 57 
 58             void s8(byte s8);
 59 
 60             char u16();
 61 
 62             void u16(char u16);
 63 
 64             short s16();
 65 
 66             void s16(short s16);
 67 
 68             int u32();
 69 
 70             void u32(int u32);
 71 
 72             int s32();
 73 
 74             void s32(int s32);
 75 
 76             float f32();
 77 
 78             void f32(float f32);
 79 
 80             long u64();
 81 
 82             void u64(long u64);
 83 
 84             long s64();
 85 
 86             void s64(long s64);
 87 
 88             double f64();
 89 
 90             void f64(double f64);
 91 
 92             Buf buf();
 93         }
 94 
 95         int idx();
 96         void idx(int idx);
 97 
 98         byte variant();
 99         void variant(byte variant);
100 
101         Value value();
102 
103         default String asString() {
104             return switch (variant()) {
105                 case '&'-> Long.toHexString(u64());
106                 case 'F'-> Float.toString(f32());
107                 case 'I'-> Integer.toString(s32());
108                 case 'J'-> Long.toString(s64());
109                 case 'D'-> Double.toString(f64());
110                 case 'Z'-> Boolean.toString(z1());
111                 case 'B'-> Byte.toString(s8());
112                 case 'S'-> Short.toString(s16());
113                 case 'C'-> Character.toString(u16());
114                 default-> throw new IllegalStateException("what is this");
115             };
116         }
117 
118         default boolean z1() {
119             return value().z1();
120         }
121 
122         default void z1(boolean z1) {
123             variant((byte) 'Z');
124             value().z1(z1);
125         }
126 
127         default byte s8() {
128             return value().s8();
129         }
130 
131         default void s8(byte s8) {
132             variant((byte) 'B');
133             value().s8(s8);
134         }
135 
136         default char u16() {
137             return value().u16();
138         }
139 
140         default void u16(char u16) {
141             variant((byte) 'C');
142             value().u16(u16);
143         }
144 
145         default short s16() {
146             return value().s16();
147         }
148 
149         default void s16(short s16) {
150             variant((byte) 'S');
151             value().s16(s16);
152         }
153 
154         default int s32() {
155             return value().s32();
156         }
157 
158         default void s32(int s32) {
159             variant((byte) 'I');
160             value().s32(s32);
161         }
162 
163         default float f32() {
164             return value().f32();
165         }
166 
167         default void f32(float f32) {
168             variant((byte) 'F');
169             value().f32(f32);
170         }
171 
172         default long s64() {
173             return value().s64();
174         }
175 
176         default void s64(long s64) {
177             variant((byte) 'J');
178             value().s64(s64);
179         }
180 
181         default long u64() {
182             return value().u64();
183         }
184 
185         default void u64(long u64) {
186             variant((byte) '&');
187             value().u64(u64);
188         }
189 
190         default double f64() {
191             return value().f64();
192         }
193 
194         default void f64(double f64) {
195             variant((byte) 'D');
196             value().f64(f64);
197         }
198     }
199 
200     int argc();
201 
202 
203     Arg arg(long idx);
204 
205     int schemaLen();
206 
207     byte schemaBytes(long idx);
208     void schemaBytes(long idx, byte b);
209 
210 
211     Schema<ArgArray> schema = Schema.of(ArgArray.class, s->s
212             .arrayLen("argc")
213             .pad((int)(16-JAVA_INT.byteSize()))
214             .array("arg", arg->arg
215                     .fields("idx", "variant")
216                     .pad((int)(16-JAVA_INT.byteSize()-JAVA_BYTE.byteSize()))
217                     .field("value", val->val
218                             .fields("z1","s8","u16","s16","s32","u32","f32","s64","u64","f64")
219                             .field("buf", buf->buf
220                                     .fields("address","bytes","access")
221                                     .pad((int)(16 - JAVA_BYTE.byteSize()))
222                             )
223                     )
224             )
225             .arrayLen("schemaLen")
226             .array("schemaBytes")
227     );
228 
229 
230     static ArgArray create(ArenaAndLookupCarrier cc, KernelCallGraph kernelCallGraph, Object... args) {
231         String[] schemas = new String[args.length];
232         StringBuilder argSchema = new StringBuilder();
233         argSchema.append(args.length);
234         for (int i = 0; i < args.length; i++) {
235             Object argObject = args[i];
236             schemas[i] = switch (argObject) {
237                 case Boolean z1 -> "(?:z1)";
238                 case Byte s8 -> "(?:s8)";
239                 case Short s16 -> "(?:s16)";
240                 case Character u16 -> "(?:u16)";
241                 case Float f32 -> "(?:f32)";
242                 case Integer s32 -> "(?:s32)";
243                 case Long s64 -> "(?:s64)";
244                 case Double f64 -> "(?:f64)";
245                 case Buffer buffer -> "(?:" + SchemaBuilder.schema(buffer)+")";
246                 default -> throw new IllegalStateException("Unexpected value: " + argObject + " Did you pass an interface which is neither a Complete or Incomplete buffer");
247             };
248             if (i > 0) {
249                 argSchema.append(",");
250             }
251             argSchema.append(schemas[i]);
252         }
253         String schemaStr = argSchema.toString();
254         ArgArray argArray = BoundSchema.of(cc ,schema,args.length,schemaStr.length()+1).allocate();
255         byte[] schemaStrBytes = schemaStr.getBytes();
256         for (int i = 0; i < schemaStrBytes.length; i++) {
257             argArray.schemaBytes(i, schemaStrBytes[i]);
258         }
259         argArray.schemaBytes(schemaStrBytes.length, (byte) 0);
260         update(argArray,kernelCallGraph,args);
261         return argArray;
262     }
263 
264     static void update(ArgArray argArray, KernelCallGraph kernelCallGraph, Object... args) {
265         Annotation[][] parameterAnnotations = kernelCallGraph.entrypoint.method().getParameterAnnotations();
266         var bufferAccessList = kernelCallGraph.state.bufferAccessList;
267         for (int i = 0; i < args.length; i++) {
268             Object argObject = args[i];
269             Arg arg = argArray.arg(i); // this should be invariant, but if we are called from create it will be 0 for all
270             arg.idx(i);
271             switch (argObject) {
272                 case Boolean z1 -> arg.z1(z1);
273                 case Byte s8 -> arg.s8(s8);
274                 case Short s16 -> arg.s16(s16);
275                 case Character u16 -> arg.u16(u16);
276                 case Float f32 -> arg.f32(f32);
277                 case Integer s32 -> arg.s32(s32);
278                 case Long s64 -> arg.s64(s64);
279                 case Double f64 -> arg.f64(f64);
280                 case Buffer buffer -> {
281                     Annotation[] annotations = parameterAnnotations[i];
282                     AccessType accessType = AccessType.NA;
283                     if (annotations.length > 0) {
284                         for (Annotation annotation : annotations) {
285                             accessType = AccessType.of(annotation);
286                         }
287                     } else {
288                         throw new IllegalArgumentException("Argument " + i + " has no access annotations");
289                     }
290                     MemorySegment segment = MappableIface.getMemorySegment(buffer);
291                     arg.variant((byte) '&');
292                     Arg.Value value = arg.value();
293                     Arg.Value.Buf buf = value.buf();
294                     buf.address(segment);
295                     buf.bytes(segment.byteSize());
296                     buf.access(bufferAccessList.get(i).value); // buf.access(accessType.value);
297                     assert bufferAccessList.get(i).value == accessType.value :
298                             (kernelCallGraph.entrypoint.method().getParameters()[i].toString()
299                                     + " in " + kernelCallGraph.entrypoint.method().getName()
300                                     + ": buffertagger " + bufferAccessList.get(i).value
301                                     + " doesn't match " + accessType.value);
302                 }
303                 default -> throw new IllegalStateException("Unexpected value: " + argObject);
304             }
305         }
306     }
307 
308 
309     default String getSchemaBytes() {
310         byte[] bytes = new byte[schemaLen() + 1];
311         for (int i = 0; i < schemaLen(); i++) {
312             bytes[i] = schemaBytes(i);
313         }
314         bytes[bytes.length - 1] = '0';
315         return new String(bytes);
316     }
317 
318 
319     default String dump() {
320         StringBuilder dump = new StringBuilder();
321         dump.append("SchemaBytes:").append(getSchemaBytes()).append("\n");
322         for (int argIndex = 0; argIndex < argc(); argIndex++) {
323             Arg arg = arg(argIndex);
324             dump.append(arg.asString()).append("\n");
325         }
326         return dump.toString();
327     }
328 
329 }