/*
 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation. Oracle designates this
 * particular file as subject to the "Classpath" exception as provided
 * by Oracle in the LICENSE file that accompanied this code.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
 */

#pragma once

#include <iostream>
#include <map>
#include <vector>
#include <cstdio>
#include <cstdlib>   // exit()
#include <cstring>
#include <unistd.h>
#include <sys/time.h>
#include <iomanip>
#include <bitset>
#include <stack>

#ifdef __APPLE__
#define SNPRINTF snprintf
#else
#include <malloc.h>
#if defined (_WIN32)
#include "windows.h"
#define SNPRINTF _snprintf
#else
#define SNPRINTF snprintf
#endif
#endif

typedef char s8_t;
typedef char byte;
typedef char boolean;
typedef char z1_t;
typedef unsigned char u8_t;
typedef short s16_t;
typedef unsigned short u16_t;
typedef unsigned int u32_t;
typedef int s32_t;
typedef float f32_t;
typedef double f64_t;
typedef long s64_t;
typedef unsigned long u64_t;

// Defined in another translation unit; presumably prints a hex dump of
// buflen bytes starting at ptr.
extern void hexdump(void *ptr, int buflen);

// hat iface buffer bits
// hat iface bffa bits
// 4a7 1face bffa b175

// Buffer access flags (see Buffer_s::access).
#define UNKNOWN_BYTE 0
#define RO_BYTE (1<<1)
#define WO_BYTE (1<<2)
#define RW_BYTE (RO_BYTE|WO_BYTE)

// Description of a buffer argument ('&' variant).
struct Buffer_s {
    void *memorySegment;   // Address of a Buffer/MemorySegment
    long sizeInBytes;      // The size of the memory segment in bytes
    u8_t access;           // see hat/buffer/ArgArray.java UNKNOWN_BYTE=0, RO_BYTE =1<<1,WO_BYTE =1<<2,RW_BYTE =RO_BYTE|WO_BYTE;
};

// Storage for one argument value; Arg_s::variant selects the live member.
union Value_u {
    boolean z1;       // 'Z'
    u8_t s8;          // 'B' (NOTE: declared unsigned despite the s8 name)
    u16_t u16;        // 'C'
    s16_t s16;        // 'S'
    u16_t x16;        // 'C' or 'S'
    s32_t s32;        // 'I'
    s32_t x32;        // 'I' or 'F'
    f32_t f32;        // 'F'
    f64_t f64;        // 'D'
    s64_t s64;        // 'J'
    s64_t x64;        // 'D' or 'J'
    Buffer_s buffer;  // '&'
};

// One argument entry in an ArgArray_s. The explicit pad fields fix the
// in-memory layout; presumably it mirrors the layout written on the Java side
// (hat/buffer/ArgArray.java) -- do not reorder or resize members without
// confirming against the producer.
struct Arg_s {
    u32_t idx;          // 0..argc
    u8_t variant;       // which variant 'I','Z','S','J','F', '&' implies Buffer/MemorySegment
    u8_t pad8[8];
    Value_u value;
    u8_t pad6[6];

    // Size in bytes of the scalar held in `value` for this variant.
    // Terminates the process (exit(1)) for unsupported variants, including '&'.
    size_t size() {
        switch (variant) {
            case 'I': case 'F': return sizeof(u32_t);
            case 'S': case 'C': return sizeof(u16_t);
            case 'D': case 'J': return sizeof(u64_t);
            case 'B': return sizeof(u8_t);
            case 'Z': return sizeof(boolean);   // previously fell into the error path
            default:
                std::cerr << "Bad variant " << variant << "arg::size" << std::endl;
                exit(1);
        }
    }
};

// State block stored in the last sizeof(BufferState_s) bytes of a buffer
// (see of()). magic1/magic2 bracket the block so ok() can detect a buffer
// that does not carry one.
struct BufferState_s {
    static const long MAGIC = 0x4a71facebffab175;
    static const int BIT_HOST_NEW = 0x00000004;
    static const int BIT_DEVICE_NEW = 0x00000008;
    static const int BIT_HOST_DIRTY = 0x00000001;
    static const int BIT_DEVICE_DIRTY = 0x00000002;

    long magic1;
    int bits;
    int unused;
    void *vendorPtr;
    long magic2;

    // True when both guard words hold MAGIC.
    bool ok() {
        return ((magic1 == MAGIC) && (magic2 == MAGIC));
    }

    // Replace all state bits.
    void assignBits(int bitBits) {
        bits = bitBits;
    }

    // Set the given bits, leaving others untouched.
    void setBits(int bitBits) {
        bits |= bitBits;
    }

    // Toggle the given bits.
    // e.g. bits = 0b0111 (7), bitsToReset = 0b0100 (4) -> bits = 0b0011 (3)
    void xorBits(int bitsToReset) {
        bits = bits ^ bitsToReset;
    }

    // Clear the given bits, leaving others untouched.
    // e.g. bits = 0b0111 (7), bitsToReset = 0b0100 (4) -> bits = 0b0011 (3)
    void resetBits(int bitsToReset) {
        bits = bits & ~bitsToReset;
    }

    int getBits() {
        return bits;
    }

    // True only when ALL of the given bits are set.
    bool areBitsSet(int bitBits) {
        return (bits & bitBits) == bitBits;
    }

    void setHostDirty() {
        setBits(BIT_HOST_DIRTY);
    }

    bool isHostDirty() {
        return areBitsSet(BIT_HOST_DIRTY);
    }

    void clearHostDirty() {
        resetBits(BIT_HOST_DIRTY);
    }

    bool isHostNew() {
        return areBitsSet(BIT_HOST_NEW);
    }

    void clearHostNew() {
        resetBits(BIT_HOST_NEW);
    }

    // True when HOST_NEW or HOST_DIRTY (or both) is set. areBitsSet() would
    // require both bits, which contradicts this method's name.
    bool isHostNewOrDirty() {
        return (bits & (BIT_HOST_NEW | BIT_HOST_DIRTY)) != 0;
    }

    void setDeviceDirty() {
        setBits(BIT_DEVICE_DIRTY);
    }

    bool isDeviceDirty() {
        return areBitsSet(BIT_DEVICE_DIRTY);
    }

    void clearDeviceDirty() {
        resetBits(BIT_DEVICE_DIRTY);
    }

    // Print the state to stdout, including the raw guard words if they are bad.
    void dump(const char *msg) {
        if (ok()) {
            printf("{%s, bits:%08x, unused:%08x, vendorPtr:%016lx}\n", msg, bits, unused, (long) vendorPtr);
        } else {
            printf("%s bad magic \n", msg);
            printf("(magic1:%016lx,", magic1);
            printf("{%s, bits:%08x, unused:%08x, vendorPtr:%016lx}", msg, bits, unused, (long) vendorPtr);
            printf("magic2:%016lx)\n", magic2);
        }
    }

    // Locate the state block occupying the final bytes of a buffer.
    static BufferState_s *of(void *ptr, size_t sizeInBytes) {
        return (BufferState_s *) (((char *) ptr) + sizeInBytes - sizeof(BufferState_s));
    }

    // Locate the state block of a '&' argument's buffer (variant is not checked).
    static BufferState_s *of(Arg_s *arg) {
        return BufferState_s::of(
                arg->value.buffer.memorySegment,
                arg->value.buffer.sizeInBytes
        );
    }
};

// Header of the marshalled argument block: argc, padding, then argc Arg_s
// entries (zero-length array extension). ArgSled reads an int schema length
// followed by the schema characters immediately after the last entry.
struct ArgArray_s {
    u32_t argc;
    u8_t pad12[12];
    Arg_s argv[0/*argc*/];
};

// Non-owning accessor over an ArgArray_s.
class ArgSled {
private:
    ArgArray_s *argArray;
public:
    int argc() {
        return argArray->argc;
    }

    // Pointer to entry n (no bounds check).
    Arg_s *arg(int n) {
        Arg_s *a = (argArray->argv + n);
        return a;
    }

    void hexdumpArg(int n) {
        hexdump(arg(n), sizeof(Arg_s));
    }

    // Print entry n to std::cout, decoding the value according to its variant.
    void dumpArg(int n) {
        Arg_s *a = arg(n);
        int idx = (int) a->idx;
        std::cout << "arg[" << idx << "]";
        char variant = (char) a->variant;
        switch (variant) {
            case 'F':
                std::cout << " f32 " << a->value.f32 << std::endl;
                break;
            case 'I':
                std::cout << " s32 " << a->value.s32 << std::endl;
                break;
            case 'D':
                std::cout << " f64 " << a->value.f64 << std::endl;
                break;
            case 'J':
                std::cout << " s64 " << a->value.s64 << std::endl;
                break;
            case 'C':
                std::cout << " u16 " << a->value.u16 << std::endl;
                break;
            case 'S':
                // Read the 16-bit member only; s32 would include two bytes
                // that are not part of the short.
                std::cout << " s16 " << a->value.s16 << std::endl;
                break;
            case 'B':
                // JVM 'B' is a signed byte; widen so it prints as a number.
                std::cout << " s8 " << static_cast<int>(static_cast<signed char>(a->value.s8)) << std::endl;
                break;
            case 'Z':
                // z1 is a char; widen so 0/1 print as digits, not control chars.
                std::cout << " z1 " << static_cast<int>(a->value.z1) << std::endl;
                break;
            case '&':
                std::cout << " buffer {"
                          << " void *address = 0x" << std::hex << (long) a->value.buffer.memorySegment << std::dec
                          << ", long bytesSize= 0x" << std::hex << (long) a->value.buffer.sizeInBytes << std::dec
                          // widen: an unsigned char would be written as a raw byte, ignoring std::hex
                          << ", char access= 0x" << std::hex << static_cast<unsigned int>(a->value.buffer.access) << std::dec
                          << "}" << std::endl;
                break;
            default:
                std::cout << (char) variant << std::endl;
                break;
        }
    }

    // Address just past the last Arg_s entry.
    void *afterArgsPtrPtr() {
        Arg_s *a = arg(argc());
        return (void *) a;
    }

    // The int schema length stored right after the last entry.
    int *schemaLenPtr() {
        int *schemaLenP = (int *) ((char *) afterArgsPtrPtr() /*+ sizeof(void *) */);
        return schemaLenP;
    }

    int schemaLen() {
        return *schemaLenPtr();
    }

    // Schema characters, which follow the schema length int.
    char *schema() {
        int *schemaLenP = ((int *) ((char *) afterArgsPtrPtr() /*+ sizeof(void *)*/) + 1);
        return (char *) schemaLenP;
    }

    ArgSled(ArgArray_s *argArray)
            : argArray(argArray) {}
};

// Wall-clock stopwatch based on gettimeofday, with microsecond resolution.
class Timer {
    struct timeval startTV{}, endTV{};
public:
    unsigned long elapsed_us = 0;

    void start() {
        gettimeofday(&startTV, NULL);
    }

    // Stops the timer; stores and returns microseconds since start().
    unsigned long end() {
        gettimeofday(&endTV, NULL);
        elapsed_us = (endTV.tv_sec - startTV.tv_sec) * 1000000; // sec to us
        elapsed_us += (endTV.tv_usec - startTV.tv_usec);
        return elapsed_us;
    }
};

// Result of a program build. Owns src and log (released with delete[]).
class BuildInfo {
public:
    char *src;
    char *log;
    bool ok;

    BuildInfo(char *src, char *log, bool ok)
            : src(src), log(log), ok(ok) {
    }

    // Copying would make two objects delete[] the same buffers.
    BuildInfo(const BuildInfo &) = delete;
    BuildInfo &operator=(const BuildInfo &) = delete;

    ~BuildInfo() {
        delete[] src;   // delete[] of nullptr is a no-op
        delete[] log;
    }
};

//extern "C" void dumpArgArray(void *ptr);

class Sled {
public:
    static void show(std::ostream &out, void *argArray);
};

class NDRange {
public:
    int x;
    int maxX;
};

// Abstract compute backend; concrete backends implement the pure virtuals.
class Backend {
public:

    class Program {
    public:
        class Kernel {
        public:
            // A kernel argument buffer; subclasses implement host<->device copies.
            class Buffer {
            public:
                Kernel *kernel;
                Arg_s *arg;

                virtual void copyToDevice() = 0;

                virtual void copyFromDevice() = 0;

                Buffer(Kernel *kernel, Arg_s *arg)
                        : kernel(kernel), arg(arg) {
                }

                virtual ~Buffer() {}
            };

            char *name;// strduped! (owned; released in ~Kernel)

            Program *program;

            virtual long ndrange(void *argArray) = 0;

            // Returns a new[]-allocated, NUL-terminated copy of name,
            // or nullptr when name is nullptr.
            static char *copy(char *name) {
                if (name == nullptr) {
                    return nullptr;
                }
                size_t len = ::strlen(name);
                char *buf = new char[len + 1];
                memcpy(buf, name, len + 1);   // includes the terminator
                return buf;
            }

            // Initializers listed in declaration order (name, then program).
            Kernel(Program *program, char *name)
                    : name(copy(name)), program(program) {
            }

            // Owns name; copying would double-free it.
            Kernel(const Kernel &) = delete;
            Kernel &operator=(const Kernel &) = delete;

            virtual ~Kernel() {
                delete[] name;
            }
        };

    public:
        Backend *backend;
        BuildInfo *buildInfo;   // owned; deleted in ~Program

        virtual long getKernel(int nameLen, char *name) = 0;

        virtual bool programOK() = 0;

        Program(Backend *backend, BuildInfo *buildInfo)
                : backend(backend), buildInfo(buildInfo) {
        }

        // Owns buildInfo; copying would double-delete it.
        Program(const Program &) = delete;
        Program &operator=(const Program &) = delete;

        virtual ~Program() {
            delete buildInfo;   // delete of nullptr is a no-op
        }
    };

    int mode;

    explicit Backend(int mode)
            : mode(mode) {}

    virtual void info() = 0;

    virtual void computeStart() = 0;

    virtual void computeEnd() = 0;

    virtual int getMaxComputeUnits() = 0;

    virtual long compileProgram(int len, char *source) = 0;

    virtual bool getBufferFromDeviceIfDirty(void *memorySegment, long memorySegmentLength) = 0;

    virtual ~Backend() {}
};

// C entry points. The long handles presumably carry Backend/Program/Kernel
// pointers -- confirm against the implementing translation unit.
extern "C" void info(long backendHandle);
extern "C" int getMaxComputeUnits(long backendHandle);
extern "C" long compileProgram(long backendHandle, int len, char *source);
extern "C" long getKernel(long programHandle, int len, char *name);
extern "C" void releaseBackend(long backendHandle);
extern "C" void releaseProgram(long programHandle);
extern "C" bool programOK(long programHandle);
extern "C" void releaseKernel(long kernelHandle);
extern "C" long ndrange(long kernelHandle, void *argArray);
extern "C" void computeStart(long backendHandle);
extern "C" void computeEnd(long backendHandle);
extern "C" bool getBufferFromDeviceIfDirty(long backendHandle, long memorySegmentHandle, long memorySegmentLength);