1 /*
2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25 package hat.buffer;
26
import hat.annotations.Kernel;
import hat.callgraph.KernelCallGraph;
import optkl.ifacemapper.AccessType;
import optkl.ifacemapper.BoundSchema;
import optkl.ifacemapper.Buffer;
import optkl.ifacemapper.Schema;
import optkl.ifacemapper.SchemaBuilder;
import optkl.util.carriers.ArenaAndLookupCarrier;

import java.lang.annotation.Annotation;
import java.lang.foreign.MemorySegment;
import java.nio.charset.StandardCharsets;

import static java.lang.foreign.ValueLayout.JAVA_BYTE;
import static java.lang.foreign.ValueLayout.JAVA_INT;
import static optkl.ifacemapper.MappableIface.getMemorySegment;
42
43 public interface ArgArray extends Buffer {
44 interface Arg extends Buffer.Struct {
45 interface Value extends Buffer.Union {
46 interface Buf extends Buffer.Struct {
47 MemorySegment address();
48
49 void address(MemorySegment address);
50
51 long bytes();
52
53 void bytes(long bytes);
54
55 byte access();
56
57 void access(byte access);
58 }
59
60 boolean z1();
61
62 void z1(boolean z1);
63
64 byte s8();
65
66 void s8(byte s8);
67
68 char u16();
69
70 void u16(char u16);
71
72 short s16();
73
74 void s16(short s16);
75
76 int u32();
77
78 void u32(int u32);
79
80 int s32();
81
82 void s32(int s32);
83
84 float f32();
85
86 void f32(float f32);
87
88 long u64();
89
90 void u64(long u64);
91
92 long s64();
93
94 void s64(long s64);
95
96 double f64();
97
98 void f64(double f64);
99
100 Buf buf();
101 }
102
103 int idx();
104
105 void idx(int idx);
106
107 byte variant();
108
109 void variant(byte variant);
110
111 Value value();
112
113 default String asString() {
114 return switch (variant()) {
115 case '&' -> Long.toHexString(u64());
116 case 'F' -> Float.toString(f32());
117 case 'I' -> Integer.toString(s32());
118 case 'J' -> Long.toString(s64());
119 case 'D' -> Double.toString(f64());
120 case 'Z' -> Boolean.toString(z1());
121 case 'B' -> Byte.toString(s8());
122 case 'S' -> Short.toString(s16());
123 case 'C' -> Character.toString(u16());
124 default -> throw new IllegalStateException("Variant " + variant() + " not supported for arg");
125 };
126 }
127
128 default boolean z1() {
129 return value().z1();
130 }
131
132 default void z1(boolean z1) {
133 variant((byte) 'Z');
134 value().z1(z1);
135 }
136
137 default byte s8() {
138 return value().s8();
139 }
140
141 default void s8(byte s8) {
142 variant((byte) 'B');
143 value().s8(s8);
144 }
145
146 default char u16() {
147 return value().u16();
148 }
149
150 default void u16(char u16) {
151 variant((byte) 'C');
152 value().u16(u16);
153 }
154
155 default short s16() {
156 return value().s16();
157 }
158
159 default void s16(short s16) {
160 variant((byte) 'S');
161 value().s16(s16);
162 }
163
164 default int s32() {
165 return value().s32();
166 }
167
168 default void s32(int s32) {
169 variant((byte) 'I');
170 value().s32(s32);
171 }
172
173 default float f32() {
174 return value().f32();
175 }
176
177 default void f32(float f32) {
178 variant((byte) 'F');
179 value().f32(f32);
180 }
181
182 default long s64() {
183 return value().s64();
184 }
185
186 default void s64(long s64) {
187 variant((byte) 'J');
188 value().s64(s64);
189 }
190
191 default long u64() {
192 return value().u64();
193 }
194
195 default void u64(long u64) {
196 variant((byte) '&');
197 value().u64(u64);
198 }
199
200 default double f64() {
201 return value().f64();
202 }
203
204 default void f64(double f64) {
205 variant((byte) 'D');
206 value().f64(f64);
207 }
208 }
209
/** Length of the {@code arg} array; declared as its {@code arrayLen} in {@code schema}. */
int argc();


/**
 * Accessor for one argument slot.
 *
 * @param idx position within the {@code arg} array
 */
Arg arg(long idx);

/** Length of the {@code schemaBytes} array; declared as its {@code arrayLen} in {@code schema}. */
int schemaLen();

/** Reads one byte of the stored schema description. */
byte schemaBytes(long idx);

/** Writes one byte of the stored schema description. */
void schemaBytes(long idx, byte b);
220
221
222 Schema<ArgArray> schema = Schema.of(ArgArray.class, s -> s
223 .arrayLen("argc")
224 .pad((int) (16 - JAVA_INT.byteSize()))
225 .array("arg", arg -> arg
226 .fields("idx", "variant")
227 .pad((int) (16 - JAVA_INT.byteSize() - JAVA_BYTE.byteSize()))
228 .field("value", val -> val
229 .fields("z1", "s8", "u16", "s16", "s32", "u32", "f32", "s64", "u64", "f64")
230 .field("buf", buf -> buf
231 .fields("address", "bytes", "access")
232 .pad((int) (16 - JAVA_BYTE.byteSize()))
233 )
234 )
235 )
236 .arrayLen("schemaLen")
237 .array("schemaBytes")
238 );
239
240
241 static ArgArray create(ArenaAndLookupCarrier cc, KernelCallGraph kernelCallGraph, Object... args) {
242 String[] schemas = new String[args.length];
243 StringBuilder argSchema = new StringBuilder();
244 argSchema.append(args.length);
245 for (int i = 0; i < args.length; i++) {
246 Object argObject = args[i];
247 schemas[i] = switch (argObject) {
248 case Boolean _ -> "(?:z1)";
249 case Byte _ -> "(?:s8)";
250 case Short _ -> "(?:s16)";
251 case Character _ -> "(?:u16)";
252 case Float _ -> "(?:f32)";
253 case Integer _ -> "(?:s32)";
254 case Long _ -> "(?:s64)";
255 case Double _ -> "(?:f64)";
256 case Buffer buffer -> "(?:" + SchemaBuilder.schema(buffer) + ")";
257 default ->
258 throw new IllegalStateException("Unexpected value: " + argObject + " Did you pass an interface which is neither a Complete or Incomplete buffer");
259 };
260 if (i > 0) {
261 argSchema.append(",");
262 }
263 argSchema.append(schemas[i]);
264 }
265 String schemaStr = argSchema.toString();
266 ArgArray argArray = BoundSchema.of(cc, schema, args.length, schemaStr.length() + 1).allocate();
267 byte[] schemaStrBytes = schemaStr.getBytes();
268 for (int i = 0; i < schemaStrBytes.length; i++) {
269 argArray.schemaBytes(i, schemaStrBytes[i]);
270 }
271 argArray.schemaBytes(schemaStrBytes.length, (byte) 0);
272 update(argArray, kernelCallGraph, args);
273 return argArray;
274 }
275
276 static void update(ArgArray argArray, KernelCallGraph kernelCallGraph, Object... args) {
277 Annotation[][] parameterAnnotations = kernelCallGraph.callDag.entryPoint.method().getParameterAnnotations();
278 var bufferAccessList = kernelCallGraph.bufferAccessList;
279 for (int i = 0; i < args.length; i++) {
280 Object argObject = args[i];
281 Arg arg = argArray.arg(i); // this should be invariant, but if we are called from create it will be 0 for all
282 arg.idx(i);
283 switch (argObject) {
284 case Boolean z1 -> arg.z1(z1);
285 case Byte s8 -> arg.s8(s8);
286 case Short s16 -> arg.s16(s16);
287 case Character u16 -> arg.u16(u16);
288 case Float f32 -> arg.f32(f32);
289 case Integer s32 -> arg.s32(s32);
290 case Long s64 -> arg.s64(s64);
291 case Double f64 -> arg.f64(f64);
292 case Buffer buffer -> {
293 Annotation[] annotations = parameterAnnotations[i];
294 AccessType accessType = AccessType.NA;
295 for (Annotation annotation : annotations) {
296 accessType = AccessType.of(annotation);
297 }
298 MemorySegment segment = getMemorySegment(buffer);
299 arg.variant((byte) '&');
300 Arg.Value value = arg.value();
301 Arg.Value.Buf buf = value.buf();
302 buf.address(segment);
303 buf.bytes(segment.byteSize());
304
305 if (kernelCallGraph.callDag.entryPoint.method().getAnnotation(Kernel.class) != null) {
306 // If the annotation is present, then we keep the accessor defined for each parameter
307 buf.access(accessType.value);
308 } else {
309 // otherwise, we rely on the buffer-tagger to set the accessor
310 buf.access(bufferAccessList.get(i).value);
311 }
312 }
313 default -> throw new IllegalStateException("Unexpected value: " + argObject);
314 }
315 }
316 }
317
318 default String getSchemaBytes() {
319 byte[] bytes = new byte[schemaLen() + 1];
320 for (int i = 0; i < schemaLen(); i++) {
321 bytes[i] = schemaBytes(i);
322 }
323 bytes[bytes.length - 1] = '0';
324 return new String(bytes);
325 }
326
327 default String dump() {
328 StringBuilder dump = new StringBuilder();
329 dump.append("SchemaBytes:").append(getSchemaBytes()).append("\n");
330 for (int argIndex = 0; argIndex < argc(); argIndex++) {
331 Arg arg = arg(argIndex);
332 dump.append(arg.asString()).append("\n");
333 }
334 return dump.toString();
335 }
336 }