/*
 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation.  Oracle designates this
 * particular file as subject to the "Classpath" exception as provided
 * by Oracle in the LICENSE file that accompanied this code.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
 */
 25 
#pragma once

#include <bitset>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <map>
#include <stack>
#include <vector>

#include <sys/time.h>
#include <unistd.h>

#ifdef __APPLE__
   #define SNPRINTF snprintf
#else
   #include <malloc.h>
   #if defined (_WIN32)
      #include "windows.h"
      #define SNPRINTF _snprintf
   #else
      #define SNPRINTF  snprintf
   #endif
#endif
 51 
 52 typedef char s8_t;
 53 typedef char byte;
 54 typedef char boolean;
 55 typedef char z1_t;
 56 typedef unsigned char u8_t;
 57 typedef short s16_t;
 58 typedef unsigned short u16_t;
 59 typedef unsigned int u32_t;
 60 typedef int s32_t;
 61 typedef float f32_t;
 62 typedef double f64_t;
 63 typedef long s64_t;
 64 typedef unsigned long u64_t;
 65 
 66 extern void hexdump(void *ptr, int buflen);
 67  // hat iface buffer bits
 68  // hat iface bffa   bits
 69  // 4a7 1face bffa   b175
 70 
 71  #define UNKNOWN_BYTE 0
 72  #define RO_BYTE (1<<1)
 73  #define WO_BYTE (1<<2)
 74  #define RW_BYTE (RO_BYTE|WO_BYTE)
 75 
 76  struct Buffer_s {
 77     void *memorySegment;   // Address of a Buffer/MemorySegment
 78     long sizeInBytes;     // The size of the memory segment in bytes
 79     u8_t access;          // see hat/buffer/ArgArray.java  UNKNOWN_BYTE=0, RO_BYTE =1<<1,WO_BYTE =1<<2,RW_BYTE =RO_BYTE|WO_BYTE;
 80 } ;
 81 
 82  union Value_u {
 83     boolean z1;  // 'Z'
 84     u8_t s8;  // 'B'
 85     u16_t u16;  // 'C'
 86     s16_t s16;  // 'S'
 87     u16_t x16;  // 'C' or 'S"
 88     s32_t s32;  // 'I'
 89     s32_t x32;  // 'I' or 'F'
 90     f32_t f32; // 'F'
 91     f64_t f64; // 'D'
 92     s64_t s64; // 'J'
 93     s64_t x64; // 'D' or 'J'
 94     Buffer_s buffer; // '&'
 95 } ;
 96 
 97  struct Arg_s {
 98     u32_t idx;          // 0..argc
 99     u8_t variant;      // which variant 'I','Z','S','J','F', '&' implies Buffer/MemorySegment
100     u8_t pad8[8];
101     Value_u value;
102     u8_t pad6[6];
103     size_t size(){
104        size_t sz;
105        switch(variant){
106           case 'I': case'F':sz= sizeof(u32_t);break;
107           case 'S': case 'C':sz= sizeof(u16_t);break;
108           case 'D':case 'J':return sizeof(u64_t);break;
109           case 'B':return sizeof (u8_t);break;
110         default:
111            std::cerr <<"Bad variant " <<variant << "arg::size" << std::endl;
112            exit(1);
113 
114       }
115 
116       return sz;
117       }
118 };
119 
120  struct BufferState_s{
121    static const long  MAGIC =0x4a71facebffab175;
122    static const int   BIT_HOST_NEW =0x00000004;
123    static const int   BIT_DEVICE_NEW =0x00000008;
124    static const int   BIT_HOST_DIRTY =0x00000001;
125    static const int   BIT_DEVICE_DIRTY =0x00000002;
126 
127 
128    long magic1;
129    int bits;
130    int unused;
131    void *vendorPtr;
132    long magic2;
133    bool ok(){
134       return ((magic1 == MAGIC) && (magic2 == MAGIC));
135    }
136 
137    void assignBits(int bitBits) {
138       bits=bitBits;
139    }
140    void setBits(int bitBits) {
141       bits|=bitBits;
142    }
143    void  xorBits(int bitsToReset) {
144       // say bits = 0b0111 (7) and bitz = 0b0100 (4)
145       int xored = bits^bitsToReset;  // xored = 0b0011 (3)
146       bits =  xored;
147    }
148     void  resetBits(int bitsToReset) {
149          // say bits = 0b0111 (7) and bitz = 0b0100 (4)
150          bits = bits&~bitsToReset;  // xored = 0b0011 (3)
151          //bits =  xored;
152       }
153    int getBits() {
154       return bits;
155    }
156    bool areBitsSet(int bitBits) {
157       return (bits&bitBits)==bitBits;
158    }
159    void setHostDirty(){
160       setBits(BIT_HOST_DIRTY);
161    }
162    bool isHostDirty(){
163       return  areBitsSet(BIT_HOST_DIRTY);
164    }
165    void clearHostDirty(){
166       resetBits(BIT_HOST_DIRTY);
167    }
168    bool isHostNew(){
169       return  areBitsSet(BIT_HOST_NEW);
170    }
171    void clearHostNew(){
172       resetBits(BIT_HOST_NEW);
173    }
174    bool isHostNewOrDirty() {
175       return areBitsSet(BIT_HOST_NEW|BIT_HOST_DIRTY);
176    }
177 
178    void setDeviceDirty(){
179       setBits(BIT_DEVICE_DIRTY);
180    }
181 
182    bool isDeviceDirty(){
183       return areBitsSet(BIT_DEVICE_DIRTY);
184    }
185    void clearDeviceDirty(){
186       resetBits(BIT_DEVICE_DIRTY);
187    }
188 
189 
190    void dump(const char *msg){
191      if (ok()){
192         printf("{%s, bits:%08x, unused:%08x, vendorPtr:%016lx}\n", msg, bits, unused, (long)vendorPtr);
193      }else{
194         printf("%s bad magic \n", msg);
195         printf("(magic1:%016lx,", magic1);
196         printf("{%s, bits:%08x, unused:%08x, vendorPtr:%016lx}", msg, bits, unused, (long)vendorPtr);
197         printf("magic2:%016lx)\n", magic2);
198      }
199    }
200    static BufferState_s* of(void *ptr, size_t sizeInBytes){
201       return (BufferState_s*) (((char*)ptr)+sizeInBytes-sizeof(BufferState_s));
202    }
203 
204      static BufferState_s* of(Arg_s *arg){ // access?
205         return BufferState_s::of(
206            arg->value.buffer.memorySegment,
207            arg->value.buffer.sizeInBytes
208            );
209       }
210 };
211 
212 struct ArgArray_s {
213     u32_t argc;
214     u8_t pad12[12];
215     Arg_s argv[0/*argc*/];
216 };
217 
218 class ArgSled {
219 private:
220     ArgArray_s *argArray;
221 public:
222     int argc() {
223         return argArray->argc;
224     }
225 
226     Arg_s *arg(int n) {
227         Arg_s *a = (argArray->argv + n);
228         return a;
229     }
230 
231     void hexdumpArg(int n) {
232         hexdump(arg(n), sizeof(Arg_s));
233     }
234 
235     void dumpArg(int n) {
236         Arg_s *a = arg(n);
237         int idx = (int) a->idx;
238         std::cout << "arg[" << idx << "]";
239         char variant = (char) a->variant;
240         switch (variant) {
241             case 'F':
242                 std::cout << " f32 " << a->value.f32 << std::endl;
243                 break;
244             case 'I':
245                 std::cout << " s32 " << a->value.s32 << std::endl;
246                 break;
247             case 'D':
248                 std::cout << " f64 " << a->value.f64 << std::endl;
249                 break;
250             case 'J':
251                 std::cout << " s64 " << a->value.s64 << std::endl;
252                 break;
253             case 'C':
254                 std::cout << " u16 " << a->value.u16 << std::endl;
255                 break;
256             case 'S':
257                 std::cout << " s16 " << a->value.s32 << std::endl;
258                 break;
259             case 'Z':
260                 std::cout << " z1 " << a->value.z1 << std::endl;
261                 break;
262             case '&':
263                 std::cout << " buffer {"
264                           << " void *address = 0x" << std::hex << (long) a->value.buffer.memorySegment << std::dec
265                           << ", long bytesSize= 0x" << std::hex << (long) a->value.buffer.sizeInBytes << std::dec
266                           << ", char access= 0x" << std::hex << (unsigned char) a->value.buffer.access << std::dec
267                           << "}" << std::endl;
268                 break;
269             default:
270                 std::cout << (char) variant << std::endl;
271                 break;
272         }
273     }
274 
275     void *afterArgsPtrPtr() {
276         Arg_s *a = arg(argc());
277         return (void *) a;
278    }
279 
280     int *schemaLenPtr() {
281         int *schemaLenP = (int *) ((char *) afterArgsPtrPtr() /*+ sizeof(void *) */);
282         return schemaLenP;
283     }
284 
285     int schemaLen() {
286         return *schemaLenPtr();
287     }
288 
289     char *schema() {
290         int *schemaLenP = ((int *) ((char *) afterArgsPtrPtr() /*+ sizeof(void *)*/) + 1);
291         return (char *) schemaLenP;
292     }
293 
294     ArgSled(ArgArray_s *argArray)
295             : argArray(argArray) {}
296 };
297 
298 
// Minimal stopwatch built on gettimeofday(), i.e. wall-clock time: a clock
// adjustment between start() and end() will distort the reading.
class Timer {
    // Zero-initialised so end() without a prior start() reads defined values.
    struct timeval startTV{};
    struct timeval endTV{};
public:
    // Result of the most recent end(), in microseconds; 0 until end() is called.
    unsigned long elapsed_us = 0;

    // Records the starting instant.
    void start() {
        gettimeofday(&startTV, nullptr);
    }

    // Records the ending instant and returns the microseconds since start().
    // The value is also stored in elapsed_us.
    unsigned long end() {
        gettimeofday(&endTV, nullptr);
        // The microsecond difference may be negative when the second rolled
        // over, so combine both parts in signed arithmetic before converting.
        const long long seconds = static_cast<long long>(endTV.tv_sec) - startTV.tv_sec;
        const long long micros = static_cast<long long>(endTV.tv_usec) - startTV.tv_usec;
        elapsed_us = static_cast<unsigned long>(seconds * 1000000LL + micros);
        return elapsed_us;
    }
};
315 
316 
// Outcome of a program build: the source text, the build log and a success flag.
//
// The object takes ownership of both buffers and releases them with delete[],
// so each must be either null or allocated with new char[].
class BuildInfo {
public:
    char *src;
    char *log;
    bool ok;

    BuildInfo(char *src, char *log, bool ok)
            : src(src), log(log), ok(ok) {
    }

    // Copying would leave two objects freeing the same buffers, so it is disabled.
    BuildInfo(const BuildInfo &) = delete;
    BuildInfo &operator=(const BuildInfo &) = delete;

    ~BuildInfo() {
        // delete[] on a null pointer is a no-op, so no guard is needed.
        delete[] src;
        delete[] log;
    }
};
337 
338 
339 //extern "C" void dumpArgArray(void *ptr);
340 
341 
// Repeated declaration of hexdump (it is also declared near the top of this
// header); harmless, and kept so this section does not depend on that one.
extern void hexdump(void *ptr, int buflen);

// Namespace-like holder for argument-array diagnostics.
class Sled {
public:
    // Defined in another translation unit. Presumably writes a readable
    // rendering of the ArgArray_s at argArray to out - confirm against the definition.
    static void show(std::ostream &out, void *argArray);
};
348 
349 
// One-dimensional range descriptor with two public ints.
// Members are deliberately left without initialisers so the type stays a
// trivial aggregate, as it was.
class NDRange {
public:
    int x;     // presumably the current position within the range - confirm with users of this type
    int maxX;  // presumably the extent of the range - confirm with users of this type
};
355 
356 class Backend {
357 public:
358 
359     class Program {
360     public:
361         class Kernel {
362         public:
363             class Buffer {
364             public:
365                 Kernel *kernel;
366                 Arg_s *arg;
367 
368                 virtual void copyToDevice() = 0;
369 
370                 virtual void copyFromDevice() = 0;
371 
372                 Buffer(Kernel *kernel, Arg_s *arg)
373                         : kernel(kernel), arg(arg) {
374                 }
375 
376                 virtual ~Buffer() {}
377             };
378 
379             char *name;// strduped!
380 
381             Program *program;
382 
383             virtual long ndrange(void *argArray) = 0;
384             static char *copy(char *name){
385                 size_t len =::strlen(name);
386                 char *buf = new char[len+1];
387                 memcpy(buf, name, len);
388                 buf[len]='\0';
389                 return buf;
390             }
391             Kernel(Program *program, char *name)
392                     : program(program), name(copy(name)) {
393             }
394 
395             virtual ~Kernel() {
396                 if (name) {
397                     delete[] name;
398                 }
399             }
400         };
401 
402     public:
403         Backend *backend;
404         BuildInfo *buildInfo;
405 
406         virtual long getKernel(int nameLen, char *name) = 0;
407 
408         virtual bool programOK() = 0;
409 
410         Program(Backend *backend, BuildInfo *buildInfo)
411                 : backend(backend), buildInfo(buildInfo) {
412         }
413 
414         virtual ~Program() {
415             if (buildInfo != nullptr) {
416                 delete buildInfo;
417             }
418         };
419 
420     };
421     int mode;
422 
423     Backend(int mode)
424             : mode(mode){}
425 
426     virtual void info() = 0;
427 
428      virtual void computeStart() = 0;
429       virtual void computeEnd() = 0;
430 
431     virtual int getMaxComputeUnits() = 0;
432 
433     virtual long compileProgram(int len, char *source) = 0;
434 
435     virtual bool getBufferFromDeviceIfDirty(void *memorySegment, long memorySegmentLength)=0;
436 
437     virtual ~Backend() {};
438 };
439 
// C-linkage entry points, declared here and defined elsewhere. Names and
// parameters mirror the virtual methods of Backend, Backend::Program and
// Backend::Program::Kernel, with a leading `long` handle in place of `this`.
// Presumably each handle is a pointer to the corresponding object - confirm
// against the definitions before relying on that.
extern "C" void info(long backendHandle);
extern "C" int getMaxComputeUnits(long backendHandle);
extern "C" long compileProgram(long backendHandle, int len, char *source);
extern "C" long getKernel(long programHandle, int len, char *name);
extern "C" void releaseBackend(long backendHandle);
extern "C" void releaseProgram(long programHandle);
extern "C" bool programOK(long programHandle);
extern "C" void releaseKernel(long kernelHandle);
extern "C" long ndrange(long kernelHandle, void *argArray);
extern "C" void computeStart(long backendHandle);
extern "C" void computeEnd(long backendHandle);
// Note: the memory segment is passed as a long here, whereas
// Backend::getBufferFromDeviceIfDirty takes it as a void *.
extern "C" bool getBufferFromDeviceIfDirty(long backendHandle, long memorySegmentHandle, long memorySegmentLength);
452