1 /*
2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25 package hat.buffer;
26
import hat.callgraph.KernelCallGraph;
import optkl.ifacemapper.AccessType;
import optkl.ifacemapper.BoundSchema;
import optkl.ifacemapper.Buffer;
import optkl.ifacemapper.MappableIface;
import optkl.ifacemapper.Schema;
import optkl.ifacemapper.SchemaBuilder;
import optkl.util.carriers.ArenaAndLookupCarrier;

import java.lang.annotation.Annotation;
import java.lang.foreign.MemorySegment;
import java.nio.charset.StandardCharsets;

import static java.lang.foreign.ValueLayout.JAVA_BYTE;
import static java.lang.foreign.ValueLayout.JAVA_INT;
40
41 public interface ArgArray extends Buffer {
42 interface Arg extends Buffer.Struct{
43 interface Value extends Buffer.Union{
44 interface Buf extends Buffer.Struct{
45 MemorySegment address();
46 void address(MemorySegment address);
47 long bytes();
48 void bytes(long bytes);
49 byte access();
50 void access(byte access);
51 }
52 boolean z1();
53
54 void z1(boolean z1);
55
56 byte s8();
57
58 void s8(byte s8);
59
60 char u16();
61
62 void u16(char u16);
63
64 short s16();
65
66 void s16(short s16);
67
68 int u32();
69
70 void u32(int u32);
71
72 int s32();
73
74 void s32(int s32);
75
76 float f32();
77
78 void f32(float f32);
79
80 long u64();
81
82 void u64(long u64);
83
84 long s64();
85
86 void s64(long s64);
87
88 double f64();
89
90 void f64(double f64);
91
92 Buf buf();
93 }
94
95 int idx();
96 void idx(int idx);
97
98 byte variant();
99 void variant(byte variant);
100
101 Value value();
102
103 default String asString() {
104 return switch (variant()) {
105 case '&'-> Long.toHexString(u64());
106 case 'F'-> Float.toString(f32());
107 case 'I'-> Integer.toString(s32());
108 case 'J'-> Long.toString(s64());
109 case 'D'-> Double.toString(f64());
110 case 'Z'-> Boolean.toString(z1());
111 case 'B'-> Byte.toString(s8());
112 case 'S'-> Short.toString(s16());
113 case 'C'-> Character.toString(u16());
114 default-> throw new IllegalStateException("what is this");
115 };
116 }
117
118 default boolean z1() {
119 return value().z1();
120 }
121
122 default void z1(boolean z1) {
123 variant((byte) 'Z');
124 value().z1(z1);
125 }
126
127 default byte s8() {
128 return value().s8();
129 }
130
131 default void s8(byte s8) {
132 variant((byte) 'B');
133 value().s8(s8);
134 }
135
136 default char u16() {
137 return value().u16();
138 }
139
140 default void u16(char u16) {
141 variant((byte) 'C');
142 value().u16(u16);
143 }
144
145 default short s16() {
146 return value().s16();
147 }
148
149 default void s16(short s16) {
150 variant((byte) 'S');
151 value().s16(s16);
152 }
153
154 default int s32() {
155 return value().s32();
156 }
157
158 default void s32(int s32) {
159 variant((byte) 'I');
160 value().s32(s32);
161 }
162
163 default float f32() {
164 return value().f32();
165 }
166
167 default void f32(float f32) {
168 variant((byte) 'F');
169 value().f32(f32);
170 }
171
172 default long s64() {
173 return value().s64();
174 }
175
176 default void s64(long s64) {
177 variant((byte) 'J');
178 value().s64(s64);
179 }
180
181 default long u64() {
182 return value().u64();
183 }
184
185 default void u64(long u64) {
186 variant((byte) '&');
187 value().u64(u64);
188 }
189
190 default double f64() {
191 return value().f64();
192 }
193
194 default void f64(double f64) {
195 variant((byte) 'D');
196 value().f64(f64);
197 }
198 }
199
200 int argc();
201
202
203 Arg arg(long idx);
204
205 int schemaLen();
206
207 byte schemaBytes(long idx);
208 void schemaBytes(long idx, byte b);
209
210
211 Schema<ArgArray> schema = Schema.of(ArgArray.class, s->s
212 .arrayLen("argc")
213 .pad((int)(16-JAVA_INT.byteSize()))
214 .array("arg", arg->arg
215 .fields("idx", "variant")
216 .pad((int)(16-JAVA_INT.byteSize()-JAVA_BYTE.byteSize()))
217 .field("value", val->val
218 .fields("z1","s8","u16","s16","s32","u32","f32","s64","u64","f64")
219 .field("buf", buf->buf
220 .fields("address","bytes","access")
221 .pad((int)(16 - JAVA_BYTE.byteSize()))
222 )
223 )
224 )
225 .arrayLen("schemaLen")
226 .array("schemaBytes")
227 );
228
229
230 static ArgArray create(ArenaAndLookupCarrier cc, KernelCallGraph kernelCallGraph, Object... args) {
231 String[] schemas = new String[args.length];
232 StringBuilder argSchema = new StringBuilder();
233 argSchema.append(args.length);
234 for (int i = 0; i < args.length; i++) {
235 Object argObject = args[i];
236 schemas[i] = switch (argObject) {
237 case Boolean z1 -> "(?:z1)";
238 case Byte s8 -> "(?:s8)";
239 case Short s16 -> "(?:s16)";
240 case Character u16 -> "(?:u16)";
241 case Float f32 -> "(?:f32)";
242 case Integer s32 -> "(?:s32)";
243 case Long s64 -> "(?:s64)";
244 case Double f64 -> "(?:f64)";
245 case Buffer buffer -> "(?:" + SchemaBuilder.schema(buffer)+")";
246 default -> throw new IllegalStateException("Unexpected value: " + argObject + " Did you pass an interface which is neither a Complete or Incomplete buffer");
247 };
248 if (i > 0) {
249 argSchema.append(",");
250 }
251 argSchema.append(schemas[i]);
252 }
253 String schemaStr = argSchema.toString();
254 ArgArray argArray = BoundSchema.of(cc ,schema,args.length,schemaStr.length()+1).allocate();
255 byte[] schemaStrBytes = schemaStr.getBytes();
256 for (int i = 0; i < schemaStrBytes.length; i++) {
257 argArray.schemaBytes(i, schemaStrBytes[i]);
258 }
259 argArray.schemaBytes(schemaStrBytes.length, (byte) 0);
260 update(argArray,kernelCallGraph,args);
261 return argArray;
262 }
263
264 static void update(ArgArray argArray, KernelCallGraph kernelCallGraph, Object... args) {
265 Annotation[][] parameterAnnotations = kernelCallGraph.entrypoint.method().getParameterAnnotations();
266 var bufferAccessList = kernelCallGraph.state.bufferAccessList;
267 for (int i = 0; i < args.length; i++) {
268 Object argObject = args[i];
269 Arg arg = argArray.arg(i); // this should be invariant, but if we are called from create it will be 0 for all
270 arg.idx(i);
271 switch (argObject) {
272 case Boolean z1 -> arg.z1(z1);
273 case Byte s8 -> arg.s8(s8);
274 case Short s16 -> arg.s16(s16);
275 case Character u16 -> arg.u16(u16);
276 case Float f32 -> arg.f32(f32);
277 case Integer s32 -> arg.s32(s32);
278 case Long s64 -> arg.s64(s64);
279 case Double f64 -> arg.f64(f64);
280 case Buffer buffer -> {
281 Annotation[] annotations = parameterAnnotations[i];
282 AccessType accessType = AccessType.NA;
283 if (annotations.length > 0) {
284 for (Annotation annotation : annotations) {
285 accessType = AccessType.of(annotation);
286 }
287 } else {
288 throw new IllegalArgumentException("Argument " + i + " has no access annotations");
289 }
290 MemorySegment segment = MappableIface.getMemorySegment(buffer);
291 arg.variant((byte) '&');
292 Arg.Value value = arg.value();
293 Arg.Value.Buf buf = value.buf();
294 buf.address(segment);
295 buf.bytes(segment.byteSize());
296 buf.access(bufferAccessList.get(i).value); // buf.access(accessType.value);
297 assert bufferAccessList.get(i).value == accessType.value :
298 (kernelCallGraph.entrypoint.method().getParameters()[i].toString()
299 + " in " + kernelCallGraph.entrypoint.method().getName()
300 + ": buffertagger " + bufferAccessList.get(i).value
301 + " doesn't match " + accessType.value);
302 }
303 default -> throw new IllegalStateException("Unexpected value: " + argObject);
304 }
305 }
306 }
307
308
309 default String getSchemaBytes() {
310 byte[] bytes = new byte[schemaLen() + 1];
311 for (int i = 0; i < schemaLen(); i++) {
312 bytes[i] = schemaBytes(i);
313 }
314 bytes[bytes.length - 1] = '0';
315 return new String(bytes);
316 }
317
318
319 default String dump() {
320 StringBuilder dump = new StringBuilder();
321 dump.append("SchemaBytes:").append(getSchemaBytes()).append("\n");
322 for (int argIndex = 0; argIndex < argc(); argIndex++) {
323 Arg arg = arg(argIndex);
324 dump.append(arg.asString()).append("\n");
325 }
326 return dump.toString();
327 }
328
329 }