1 /* 2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. Oracle designates this 8 * particular file as subject to the "Classpath" exception as provided 9 * by Oracle in the LICENSE file that accompanied this code. 10 * 11 * This code is distributed in the hope that it will be useful, but WITHOUT 12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 14 * version 2 for more details (a copy is included in the LICENSE file that 15 * accompanied this code). 16 * 17 * You should have received a copy of the GNU General Public License version 18 * 2 along with this work; if not, write to the Free Software Foundation, 19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 20 * 21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 22 * or visit www.oracle.com if you need additional information or have any 23 * questions. 24 */ 25 package hat.buffer; 26 27 import hat.Accelerator; 28 import hat.ComputeContext; 29 import hat.callgraph.KernelCallGraph; 30 import hat.ifacemapper.Schema; 31 32 import java.lang.annotation.Annotation; 33 import java.lang.foreign.MemorySegment; 34 import java.lang.foreign.ValueLayout; 35 import java.lang.invoke.MethodHandles; 36 import java.nio.ByteOrder; 37 38 import static hat.buffer.ArgArray.Arg.Value.Buf.UNKNOWN_BYTE; 39 import static java.lang.foreign.ValueLayout.JAVA_BYTE; 40 import static java.lang.foreign.ValueLayout.JAVA_INT; 41 42 public interface ArgArray extends Buffer { 43 interface Arg extends Buffer.Struct{ 44 interface Value extends Buffer.Union{ 45 interface Buf extends Buffer.Struct{ 46 byte UNKNOWN_BYTE=(byte)0; 47 byte RO_BYTE =(byte)1<<1; 48 byte WO_BYTE =(byte)1<<2; 49 byte RW_BYTE =RO_BYTE|WO_BYTE; 50 MemorySegment address(); 51 void address(MemorySegment address); 52 long bytes(); 53 void bytes(long bytes); 54 byte access(); 55 void access(byte access); 56 } 57 boolean z1(); 58 59 void z1(boolean z1); 60 61 byte s8(); 62 63 void s8(byte s8); 64 65 char u16(); 66 67 void u16(char u16); 68 69 short s16(); 70 71 void s16(short s16); 72 73 int u32(); 74 75 void u32(int u32); 76 77 int s32(); 78 79 void s32(int s32); 80 81 float f32(); 82 83 void f32(float f32); 84 85 long u64(); 86 87 void u64(long u64); 88 89 long s64(); 90 91 void s64(long s64); 92 93 double f64(); 94 95 void f64(double f64); 96 97 Buf buf(); 98 } 99 100 int idx(); 101 void idx(int idx); 102 103 byte variant(); 104 void variant(byte variant); 105 106 Value value(); 107 108 default String asString() { 109 switch (variant()) { 110 case '&': 111 return Long.toHexString(u64()); 112 case 'F': 113 return Float.toString(f32()); 114 case 'I': 115 return Integer.toString(s32()); 116 case 'J': 117 return Long.toString(s64()); 118 case 'D': 119 return Double.toString(f64()); 120 case 'Z': 121 return Boolean.toString(z1()); 122 case 'B': 123 return Byte.toString(s8()); 124 case 'S': 125 return Short.toString(s16()); 126 case 'C': 127 return Character.toString(u16()); 128 } 129 throw new IllegalStateException("what is this"); 130 } 131 132 default boolean z1() { 133 return value().z1(); 134 } 135 136 default void z1(boolean z1) { 137 variant((byte) 'Z'); 138 value().z1(z1); 139 } 140 141 default byte s8() { 142 return value().s8(); 143 } 144 145 default void s8(byte s8) { 146 variant((byte) 'B'); 147 value().s8(s8); 148 } 149 150 default char u16() { 151 return value().u16(); 152 } 153 154 default void u16(char u16) { 155 variant((byte) 'C'); 156 value().u16(u16); 157 } 158 159 default short s16() { 160 return value().s16(); 161 } 162 163 default void s16(short s16) { 164 variant((byte) 'S'); 165 value().s16(s16); 166 } 167 168 default int s32() { 169 return value().s32(); 170 } 171 172 default void s32(int s32) { 173 variant((byte) 'I'); 174 value().s32(s32); 175 } 176 177 default float f32() { 178 return value().f32(); 179 } 180 181 default void f32(float f32) { 182 variant((byte) 'F'); 183 value().f32(f32); 184 } 185 186 default long s64() { 187 return value().s64(); 188 } 189 190 default void s64(long s64) { 191 variant((byte) 'J'); 192 value().s64(s64); 193 } 194 195 default long u64() { 196 return value().u64(); 197 } 198 199 default void u64(long u64) { 200 variant((byte) '&'); 201 value().u64(u64); 202 } 203 204 default double f64() { 205 return value().f64(); 206 } 207 208 default void f64(double f64) { 209 variant((byte) 'D'); 210 value().f64(f64); 211 } 212 } 213 214 int argc(); 215 216 217 Arg arg(long idx); 218 219 int schemaLen(); 220 221 byte schemaBytes(long idx); 222 void schemaBytes(long idx, byte b); 223 224 Schema<ArgArray> schema = Schema.of(ArgArray.class, s->s 225 .arrayLen("argc").pad(12).array("arg", arg->arg 226 .fields("idx", "variant") 227 .pad(11/*(int)(16 - JAVA_INT.byteSize() - JAVA_BYTE.byteSize())*/) 228 .field("value", value->value 229 .fields("z1","s8","u16","s16","s32","u32","f32","s64","u64","f64") 230 .field("buf", buf->buf 231 .fields("address","bytes",/*"vendorPtr",*/"access") 232 .pad((int)(16 - JAVA_BYTE.byteSize()/* - JAVA_BYTE.byteSize()*/)) 233 ) 234 ) 235 ) 236 // .field("vendorPtr") 237 .arrayLen("schemaLen").array("schemaBytes") 238 ); 239 240 static String valueLayoutToSchemaString(ValueLayout valueLayout) { 241 String descriptor = valueLayout.carrier().descriptorString(); 242 String schema = switch (descriptor) { 243 case "Z" -> "Z"; 244 case "B" -> "S"; 245 case "C" -> "U"; 246 case "S" -> "S"; 247 case "I" -> "S"; 248 case "F" -> "F"; 249 case "D" -> "F"; 250 case "J" -> "S"; 251 default -> throw new IllegalStateException("Unexpected value: " + descriptor); 252 } + valueLayout.byteSize() * 8; 253 return (valueLayout.order().equals(ByteOrder.LITTLE_ENDIAN)) ? schema.toLowerCase() : schema; 254 } 255 256 static ArgArray create(Accelerator accelerator, KernelCallGraph kernelCallGraph, Object... args) { 257 String[] schemas = new String[args.length]; 258 StringBuilder argSchema = new StringBuilder(); 259 argSchema.append(args.length); 260 for (int i = 0; i < args.length; i++) { 261 Object argObject = args[i]; 262 schemas[i] = switch (argObject) { 263 case Boolean z1 -> "(?:z1)"; 264 case Byte s8 -> "(?:s8)"; 265 case Short s16 -> "(?:s16)"; 266 case Character u16 -> "(?:u16)"; 267 case Float f32 -> "(?:f32)"; 268 case Integer s32 -> "(?:s32)"; 269 case Long s64 -> "(?:s64)"; 270 case Double f64 -> "(?:f64)"; 271 case Buffer buffer -> "(?:" +SchemaBuilder.schema(buffer)+")"; 272 default -> throw new IllegalStateException("Unexpected value: " + argObject + " Did you pass an interface which is neither a Complete or Incomplete buffer"); 273 }; 274 if (i > 0) { 275 argSchema.append(","); 276 } 277 argSchema.append(schemas[i]); 278 } 279 String schemaStr = argSchema.toString(); 280 ArgArray argArray = schema.allocate(accelerator,args.length,schemaStr.length() + 1); 281 byte[] schemaStrBytes = schemaStr.getBytes(); 282 for (int i = 0; i < schemaStrBytes.length; i++) { 283 argArray.schemaBytes(i, schemaStrBytes[i]); 284 } 285 argArray.schemaBytes(schemaStrBytes.length, (byte) 0); 286 update(argArray,kernelCallGraph,args); 287 return argArray; 288 } 289 290 static void update(ArgArray argArray, KernelCallGraph kernelCallGraph, Object... args) { 291 Annotation[][] parameterAnnotations = kernelCallGraph.entrypoint.getMethod().getParameterAnnotations(); 292 for (int i = 0; i < args.length; i++) { 293 Object argObject = args[i]; 294 Arg arg = argArray.arg(i); // this should be invariant, but if we are called from create it will be 0 for all 295 arg.idx(i); 296 switch (argObject) { 297 case Boolean z1 -> arg.z1(z1); 298 case Byte s8 -> arg.s8(s8); 299 case Short s16 -> arg.s16(s16); 300 case Character u16 -> arg.u16(u16); 301 case Float f32 -> arg.f32(f32); 302 case Integer s32 -> arg.s32(s32); 303 case Long s64 -> arg.s64(s64); 304 case Double f64 -> arg.f64(f64); 305 case Buffer buffer -> { 306 Annotation[] annotations = parameterAnnotations[i]; 307 byte accessByte = UNKNOWN_BYTE; 308 if (annotations.length > 0) { 309 for (Annotation annotation : annotations) { 310 accessByte = switch (annotation) { 311 case RO ro-> Arg.Value.Buf.RO_BYTE; 312 case RW rw -> Arg.Value.Buf.RW_BYTE; 313 case WO wo -> Arg.Value.Buf.WO_BYTE; 314 default -> throw new IllegalStateException("Unexpected value: " + annotation); 315 }; 316 } 317 }else{ 318 throw new IllegalArgumentException("Argument " + i + " has no access annotations"); 319 } 320 MemorySegment segment = Buffer.getMemorySegment(buffer); 321 arg.variant((byte) '&'); 322 Arg.Value value = arg.value(); 323 Arg.Value.Buf buf = value.buf(); 324 buf.address(segment); 325 buf.bytes(segment.byteSize()); 326 buf.access(accessByte); 327 } 328 default -> throw new IllegalStateException("Unexpected value: " + argObject); 329 } 330 } 331 } 332 333 334 default String getSchemaBytes() { 335 byte[] bytes = new byte[schemaLen() + 1]; 336 for (int i = 0; i < schemaLen(); i++) { 337 bytes[i] = schemaBytes(i); 338 } 339 bytes[bytes.length - 1] = '0'; 340 return new String(bytes); 341 } 342 343 344 default String dump() { 345 StringBuilder dump = new StringBuilder(); 346 dump.append("SchemaBytes:").append(getSchemaBytes()).append("\n"); 347 for (int argIndex = 0; argIndex < argc(); argIndex++) { 348 Arg arg = arg(argIndex); 349 dump.append(arg.asString()).append("\n"); 350 } 351 return dump.toString(); 352 } 353 354 }