1 /*
2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25 package hat.buffer;
26
import hat.Accelerator;
import hat.BufferTagger;
import hat.callgraph.KernelCallGraph;
import hat.ifacemapper.Schema;

import java.lang.annotation.Annotation;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.nio.ByteOrder;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Locale;

import static hat.buffer.ArgArray.Arg.Value.Buf.UNKNOWN_BYTE;
import static java.lang.foreign.ValueLayout.JAVA_BYTE;
40
41 public interface ArgArray extends Buffer {
42 interface Arg extends Buffer.Struct{
43 interface Value extends Buffer.Union{
44 interface Buf extends Buffer.Struct{
45 byte UNKNOWN_BYTE=(byte)0;
46 byte RO_BYTE =(byte)1<<1;
47 byte WO_BYTE =(byte)1<<2;
48 byte RW_BYTE =RO_BYTE|WO_BYTE;
49 MemorySegment address();
50 void address(MemorySegment address);
51 long bytes();
52 void bytes(long bytes);
53 byte access();
54 void access(byte access);
55 }
56 boolean z1();
57
58 void z1(boolean z1);
59
60 byte s8();
61
62 void s8(byte s8);
63
64 char u16();
65
66 void u16(char u16);
67
68 short s16();
69
70 void s16(short s16);
71
72 int u32();
73
74 void u32(int u32);
75
76 int s32();
77
78 void s32(int s32);
79
80 float f32();
81
82 void f32(float f32);
83
84 long u64();
85
86 void u64(long u64);
87
88 long s64();
89
90 void s64(long s64);
91
92 double f64();
93
94 void f64(double f64);
95
96 Buf buf();
97 }
98
99 int idx();
100 void idx(int idx);
101
102 byte variant();
103 void variant(byte variant);
104
105 Value value();
106
107 default String asString() {
108 switch (variant()) {
109 case '&':
110 return Long.toHexString(u64());
111 case 'F':
112 return Float.toString(f32());
113 case 'I':
114 return Integer.toString(s32());
115 case 'J':
116 return Long.toString(s64());
117 case 'D':
118 return Double.toString(f64());
119 case 'Z':
120 return Boolean.toString(z1());
121 case 'B':
122 return Byte.toString(s8());
123 case 'S':
124 return Short.toString(s16());
125 case 'C':
126 return Character.toString(u16());
127 }
128 throw new IllegalStateException("what is this");
129 }
130
131 default boolean z1() {
132 return value().z1();
133 }
134
135 default void z1(boolean z1) {
136 variant((byte) 'Z');
137 value().z1(z1);
138 }
139
140 default byte s8() {
141 return value().s8();
142 }
143
144 default void s8(byte s8) {
145 variant((byte) 'B');
146 value().s8(s8);
147 }
148
149 default char u16() {
150 return value().u16();
151 }
152
153 default void u16(char u16) {
154 variant((byte) 'C');
155 value().u16(u16);
156 }
157
158 default short s16() {
159 return value().s16();
160 }
161
162 default void s16(short s16) {
163 variant((byte) 'S');
164 value().s16(s16);
165 }
166
167 default int s32() {
168 return value().s32();
169 }
170
171 default void s32(int s32) {
172 variant((byte) 'I');
173 value().s32(s32);
174 }
175
176 default float f32() {
177 return value().f32();
178 }
179
180 default void f32(float f32) {
181 variant((byte) 'F');
182 value().f32(f32);
183 }
184
185 default long s64() {
186 return value().s64();
187 }
188
189 default void s64(long s64) {
190 variant((byte) 'J');
191 value().s64(s64);
192 }
193
194 default long u64() {
195 return value().u64();
196 }
197
198 default void u64(long u64) {
199 variant((byte) '&');
200 value().u64(u64);
201 }
202
203 default double f64() {
204 return value().f64();
205 }
206
207 default void f64(double f64) {
208 variant((byte) 'D');
209 value().f64(f64);
210 }
211 }
212
213 int argc();
214
215
216 Arg arg(long idx);
217
218 int schemaLen();
219
220 byte schemaBytes(long idx);
221 void schemaBytes(long idx, byte b);
222
223 Schema<ArgArray> schema = Schema.of(ArgArray.class, s->s
224 .arrayLen("argc").pad(12).array("arg", arg->arg
225 .fields("idx", "variant")
226 .pad(11/*(int)(16 - JAVA_INT.byteSize() - JAVA_BYTE.byteSize())*/)
227 .field("value", val->val
228 .fields("z1","s8","u16","s16","s32","u32","f32","s64","u64","f64")
229 .field("buf", buf->buf
230 .fields("address","bytes",/*"vendorPtr",*/"access")
231 .pad((int)(16 - JAVA_BYTE.byteSize()/* - JAVA_BYTE.byteSize()*/))
232 )
233 )
234 )
235 // .field("vendorPtr")
236 .arrayLen("schemaLen").array("schemaBytes")
237 );
238
239 static String valueLayoutToSchemaString(ValueLayout valueLayout) {
240 String descriptor = valueLayout.carrier().descriptorString();
241 String schema = switch (descriptor) {
242 case "Z" -> "Z";
243 case "B" -> "S";
244 case "C" -> "U";
245 case "S" -> "S";
246 case "I" -> "S";
247 case "F" -> "F";
248 case "D" -> "F";
249 case "J" -> "S";
250 default -> throw new IllegalStateException("Unexpected value: " + descriptor);
251 } + valueLayout.byteSize() * 8;
252 return (valueLayout.order().equals(ByteOrder.LITTLE_ENDIAN)) ? schema.toLowerCase() : schema;
253 }
254
255 static ArgArray create(Accelerator accelerator, KernelCallGraph kernelCallGraph, Object... args) {
256 String[] schemas = new String[args.length];
257 StringBuilder argSchema = new StringBuilder();
258 argSchema.append(args.length);
259 for (int i = 0; i < args.length; i++) {
260 Object argObject = args[i];
261 schemas[i] = switch (argObject) {
262 case Boolean z1 -> "(?:z1)";
263 case Byte s8 -> "(?:s8)";
264 case Short s16 -> "(?:s16)";
265 case Character u16 -> "(?:u16)";
266 case Float f32 -> "(?:f32)";
267 case Integer s32 -> "(?:s32)";
268 case Long s64 -> "(?:s64)";
269 case Double f64 -> "(?:f64)";
270 case Buffer buffer -> "(?:" +SchemaBuilder.schema(buffer)+")";
271 default -> throw new IllegalStateException("Unexpected value: " + argObject + " Did you pass an interface which is neither a Complete or Incomplete buffer");
272 };
273 if (i > 0) {
274 argSchema.append(",");
275 }
276 argSchema.append(schemas[i]);
277 }
278 String schemaStr = argSchema.toString();
279 ArgArray argArray = schema.allocate(accelerator,args.length,schemaStr.length() + 1);
280 byte[] schemaStrBytes = schemaStr.getBytes();
281 for (int i = 0; i < schemaStrBytes.length; i++) {
282 argArray.schemaBytes(i, schemaStrBytes[i]);
283 }
284 argArray.schemaBytes(schemaStrBytes.length, (byte) 0);
285 update(argArray,kernelCallGraph,args);
286 return argArray;
287 }
288
289 static void update(ArgArray argArray, KernelCallGraph kernelCallGraph, Object... args) {
290 Annotation[][] parameterAnnotations = kernelCallGraph.entrypoint.getMethod().getParameterAnnotations();
291 List<BufferTagger.AccessType> bufferAccessList = kernelCallGraph.bufferAccessList;
292 for (int i = 0; i < args.length; i++) {
293 Object argObject = args[i];
294 Arg arg = argArray.arg(i); // this should be invariant, but if we are called from create it will be 0 for all
295 arg.idx(i);
296 switch (argObject) {
297 case Boolean z1 -> arg.z1(z1);
298 case Byte s8 -> arg.s8(s8);
299 case Short s16 -> arg.s16(s16);
300 case Character u16 -> arg.u16(u16);
301 case Float f32 -> arg.f32(f32);
302 case Integer s32 -> arg.s32(s32);
303 case Long s64 -> arg.s64(s64);
304 case Double f64 -> arg.f64(f64);
305 case Buffer buffer -> {
306 Annotation[] annotations = parameterAnnotations[i];
307 byte accessByte = UNKNOWN_BYTE;
308 if (annotations.length > 0) {
309 for (Annotation annotation : annotations) {
310 accessByte = switch (annotation) {
311 case RO ro-> Arg.Value.Buf.RO_BYTE;
312 case RW rw -> Arg.Value.Buf.RW_BYTE;
313 case WO wo -> Arg.Value.Buf.WO_BYTE;
314 default -> throw new IllegalStateException("Unexpected value: " + annotation);
315 };
316 }
317 }else{
318 throw new IllegalArgumentException("Argument " + i + " has no access annotations");
319 }
320 MemorySegment segment = Buffer.getMemorySegment(buffer);
321 arg.variant((byte) '&');
322 Arg.Value value = arg.value();
323 Arg.Value.Buf buf = value.buf();
324 buf.address(segment);
325 buf.bytes(segment.byteSize());
326 buf.access(accessByte);
327
328 assert bufferAccessList.get(i).value == accessByte: "buffer tagging mismatch: "
329 + kernelCallGraph.entrypoint.getMethod().getParameters()[i].toString()
330 + " in " + kernelCallGraph.entrypoint.getMethod().getName()
331 + " annotated as " + BufferTagger.convertAccessType(accessByte)
332 + " but tagged as " + bufferAccessList.get(i).name();
333 }
334 default -> throw new IllegalStateException("Unexpected value: " + argObject);
335 }
336 }
337 }
338
339
340 default String getSchemaBytes() {
341 byte[] bytes = new byte[schemaLen() + 1];
342 for (int i = 0; i < schemaLen(); i++) {
343 bytes[i] = schemaBytes(i);
344 }
345 bytes[bytes.length - 1] = '0';
346 return new String(bytes);
347 }
348
349
350 default String dump() {
351 StringBuilder dump = new StringBuilder();
352 dump.append("SchemaBytes:").append(getSchemaBytes()).append("\n");
353 for (int argIndex = 0; argIndex < argc(); argIndex++) {
354 Arg arg = arg(argIndex);
355 dump.append(arg.asString()).append("\n");
356 }
357 return dump.toString();
358 }
359
360 }