1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 
 26 #pragma once
 27 
#include <sys/time.h>
#include <unistd.h>

#include <bitset>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <map>
#include <stack>
#include <vector>
 39 
 40 #include "strutil.h"
 41 #include "config.h"
 42 
// Platform selection for SNPRINTF:
//   - macOS:    SNPRINTF is snprintf and <malloc.h> is not included.
//   - Windows:  <malloc.h> and "windows.h" are included, SNPRINTF is _snprintf.
//   - others:   <malloc.h> is included, SNPRINTF is snprintf.
#ifdef __APPLE__
#define SNPRINTF snprintf
#else
#include <malloc.h>
#if defined (_WIN32)
#include "windows.h"
// NOTE(review): _snprintf is documented as not null-terminating on truncation,
// unlike snprintf — confirm that callers of SNPRINTF cope with that.
#define SNPRINTF _snprintf
#else
#define SNPRINTF  snprintf
#endif
#endif
 54 
 55 #define ceil_div(x, y) ((x + y - 1) / y)
 56 
 57 typedef char s8_t;
 58 typedef char byte;
 59 typedef char boolean;
 60 typedef char z1_t;
 61 typedef unsigned char u8_t;
 62 typedef short s16_t;
 63 typedef unsigned short u16_t;
 64 typedef unsigned int u32_t;
 65 typedef int s32_t;
 66 typedef float f32_t;
 67 typedef double f64_t;
 68 typedef long s64_t;
 69 typedef unsigned long u64_t;
 70 
 71 extern void hexdump(void *ptr, int buflen);
 72 
// A length plus a raw character pointer, with file write/read helpers.
// Only declarations are visible in this header: the constructors, write(),
// read() and the destructor are all defined out of line.
class Text {
public:
    size_t len;   // length associated with `text`
    char *text;   // raw pointer to the character data
    bool isCopy;  // NOTE(review): presumably tells the destructor whether `text` is an owned copy — confirm in the definition

    // Builds a Text from an explicit length, pointer and copy flag.
    Text(size_t len, char *text, bool isCopy);

    // Builds a Text from a pointer and copy flag.
    // NOTE(review): presumably derives `len` from a NUL-terminated string — confirm.
    Text(char *text, bool isCopy);

    // Builds a Text from a length only.
    // NOTE(review): presumably allocates `len` bytes — confirm.
    explicit Text(size_t len);

    // Writes to the file named `filename`; does not modify this object (const).
    void write(const std::string &filename) const;

    // Reads from the file named `filename` into this object.
    void read(const std::string &filename);

    // Virtual so that derived classes (e.g. Log) are destroyed correctly through a Text*.
    virtual ~Text();
};
 91 
// A Text specialisation with no additional data members.
// Both constructors are defined out of line; the destructor is defaulted, so
// any cleanup is whatever Text's destructor performs.
class Log : public Text {
public:
    // Builds a Log from a length only.
    explicit Log(size_t len);

    // Builds a Log from an existing character pointer.
    explicit Log(char *text);

    ~Log() override = default;
};
100 
// Access-mode flag values: RO_BYTE is bit 1, WO_BYTE is bit 2, and RW_BYTE is
// the bitwise OR of the two. UNKNOWN_BYTE (0) has no bit set.
#define UNKNOWN_BYTE 0
#define RO_BYTE 0x2
#define WO_BYTE 0x4
#define RW_BYTE (RO_BYTE | WO_BYTE)
105 
// Describes a buffer argument: where the memory segment lives, how large it is
// and how it is accessed. Used as the `buffer` member of Value_u.
// Field order and types form a binary layout; do not reorder.
struct Buffer_s {
    void *memorySegment; // Address of a Buffer/MemorySegment
    long sizeInBytes;    // The size of the memory segment in bytes
    u8_t access;         // see hat/buffer/ArgArray.java  UNKNOWN_BYTE=0, RO_BYTE =1<<1,WO_BYTE =1<<2,RW_BYTE =RO_BYTE|WO_BYTE;
};
111 
// Storage for one argument value. All members share the same bytes; the
// quoted character after each member is the variant tag it corresponds to.
union Value_u {
    boolean z1; // 'Z'
    u8_t s8; // 'B'   NOTE(review): named s8 but declared with the unsigned u8_t type — confirm intent
    u16_t u16; // 'C'
    s16_t s16; // 'S'
    u16_t x16; // 'C' or 'S'   // this is never used
    s32_t s32; // 'I'
    s32_t x32; // 'I' or 'F'   // this is never used
    f32_t f32; // 'F'
    f64_t f64; // 'D'
    s64_t s64; // 'J'
    s64_t x64; // 'D' or 'J'   // this is never used
    Buffer_s buffer; // '&'
};
126 
127 struct KernelArg {
128     u32_t idx; // 0..argc
129     u8_t variant; // which variant 'I','Z','S','J','F', '&' implies Buffer/MemorySegment
130     u8_t pad8[8];
131     Value_u value;
132     u8_t pad6[6];
133 
134     size_t size() const {
135         size_t sz;
136         switch (variant) {
137             case 'I':
138             case 'F':
139                 sz = sizeof(u32_t);
140                 break;
141             case 'S':
142             case 'C':
143                 sz = sizeof(u16_t);
144                 break;
145             case 'D':
146             case 'J':
147                 return sizeof(u64_t);
148             case 'B':
149                 return sizeof(u8_t);
150             default:
151                 std::cerr << "Bad variant " << variant << "arg::size" << std::endl;
152                 exit(1);
153         }
154         return sz;
155     }
156 };
157 
158 struct BufferState {
159     static constexpr long MAGIC = 0x4a71facebffab175;   // This magic number is a delimiter to
160                                                         // check the length of the buffer as follows:
161                                                         // *(bufferStart+(bufferLen - sizeof(bufferState)) == MAGIC
162     static constexpr int NO_STATE = 0;
163     static constexpr int NEW_STATE = 1;
164     static constexpr int HOST_OWNED = 2;
165     static constexpr int DEVICE_OWNED = 3;
166     static constexpr int DEVICE_VALID_HOST_HAS_COPY = 4;
167     const static char *stateNames[]; // See below for out of line definition
168 
169     long magic1;
170     void *ptr;
171     long length;
172     int bits;
173     mutable int state;
174     void *vendorPtr;
175     long magic2;
176 
177     bool ok() const {
178         return ((magic1 == MAGIC) && (magic2 == MAGIC));
179     }
180 
181     void setState(int newState) {
182         state = newState;
183     }
184 
185     int getState() const {
186         return state;
187     }
188 
189     void dump(const char *msg) const {
190         if (ok()) {
191             printf("{%s,ptr:%016lx,length: %016lx,  state:%08x, vendorPtr:%016lx}\n", msg, (long) ptr, length, state,
192                    (long) vendorPtr);
193         } else {
194             printf("%s bad magic \n", msg);
195             printf("(magic1:%016lx,", magic1);
196             printf("{%s, ptr:%016lx, length: %016lx,  state:%08x, vendorPtr:%016lx}", msg, (long) ptr, length, state,
197                    (long) vendorPtr);
198             printf("magic2:%016lx)\n", magic2);
199         }
200     }
201 
202     static BufferState *of(void *ptr, size_t sizeInBytes) {
203         return reinterpret_cast<BufferState *>(static_cast<char *>(ptr) + sizeInBytes - sizeof(BufferState));
204     }
205 
206     static BufferState *of(const KernelArg *arg) {
207         // access?
208         BufferState *bufferState = of(
209             arg->value.buffer.memorySegment,
210             arg->value.buffer.sizeInBytes
211         );
212 
213         // Sanity check the buffers
214         // These sanity check finds errors passing memory segments which are not Buffers
215         if (bufferState->ptr != arg->value.buffer.memorySegment) {
216              std::cerr << "Error:  Unexpected initial state for buffer "
217                             << " idx=" << arg->idx
218                             << " bufferState->ptr=0x"<<std::hex<<((long)bufferState->ptr)<<std::dec
219                               << " bufferState->length=0x"<<std::hex<<((long)bufferState->length)<<std::dec
220                             << " arg->value.buffer.memorySegment=0x"<<std::hex<<((long)arg->value.buffer.memorySegment)<<std::dec
221                             << " state=" << bufferState->state << " '"
222                             << stateNames[bufferState->state] << "'"
223                             << " vendorPtr" << bufferState->vendorPtr << std::endl;
224             std::cerr << "The ptr (bufferState->ptr) does not appear to be a arg->value.buffer.memorySegment" << std::endl;
225 
226             // This is A bit brutal to stop the VM? We can throw an exception and handle it in the Java side?
227             std::exit(1);
228         }
229 
230         if ((bufferState->vendorPtr == nullptr) && (bufferState->state != NEW_STATE)) {
231             std::cerr << "Warning:  Unexpected initial state for buffer "
232                     << " idx=" << arg->idx
233                     << " state=" << bufferState->state << " '"
234                     << stateNames[bufferState->state] << "'"
235                     << " vendorPtr" << bufferState->vendorPtr << std::endl;
236             // This is A bit brutal to stop the VM? We can throw an exception and handle it in the Java side?
237             //std::exit(1);
238         }
239         // End of sanity checks
240         return bufferState;
241     }
242 };
243 
// Out-of-line definition of BufferState::stateNames, compiled only when the
// shared_cpp macro is defined.
// NOTE(review): presumably exactly one translation unit defines shared_cpp so
// that the array has a single definition — confirm.
// The entries are looked up by state value, so their order must match the
// BufferState constants NO_STATE (0) through DEVICE_VALID_HOST_HAS_COPY (4).
#ifdef shared_cpp
const char *BufferState::stateNames[] = {
    "NO_STATE",
    "NEW_STATE",
    "HOST_OWNED",
    "DEVICE_OWNED",
    "DEVICE_VALID_HOST_HAS_COPY"
};
#endif
253 
// Header of an argument array: a count followed directly in memory by that
// many KernelArg records. Field order and padding form a binary layout.
struct ArgArray_s {
    u32_t argc;      // number of KernelArg records that follow
    u8_t pad12[12];  // explicit padding between the count and the first argument
    // Zero-length trailing array (a compiler extension, not standard C++):
    // argv[i] addresses the i-th record stored after this header.
    KernelArg argv[0/*argc*/];
};
259 
260 class ArgSled {
261 private:
262     ArgArray_s *argArray;
263 
264 public:
265     int argc() const {
266         return argArray->argc;
267     }
268 
269     KernelArg *arg(int n) const {
270         KernelArg *a = (argArray->argv + n);
271         return a;
272     }
273 
274     void hexdumpArg(int n) const {
275         hexdump(arg(n), sizeof(KernelArg));
276     }
277 
278     void dumpArg(int n) const {
279         KernelArg *a = arg(n);
280         int idx = (int) a->idx;
281         std::cout << "arg[" << idx << "]";
282         char variant = (char) a->variant;
283         switch (variant) {
284             case 'F':
285                 std::cout << " f32 " << a->value.f32 << std::endl;
286                 break;
287             case 'I':
288                 std::cout << " s32 " << a->value.s32 << std::endl;
289                 break;
290             case 'D':
291                 std::cout << " f64 " << a->value.f64 << std::endl;
292                 break;
293             case 'J':
294                 std::cout << " s64 " << a->value.s64 << std::endl;
295                 break;
296             case 'C':
297                 std::cout << " u16 " << a->value.u16 << std::endl;
298                 break;
299             case 'S':
300                 std::cout << " s16 " << a->value.s32 << std::endl;
301                 break;
302             case 'Z':
303                 std::cout << " z1 " << a->value.z1 << std::endl;
304                 break;
305             case '&':
306                 std::cout << " buffer {"
307                         << " void *address = 0x" << std::hex << (long) a->value.buffer.memorySegment << std::dec
308                         << ", long bytesSize= 0x" << std::hex << (long) a->value.buffer.sizeInBytes << std::dec
309                         << ", char access= 0x" << std::hex << (unsigned char) a->value.buffer.access << std::dec
310                         << "}" << std::endl;
311                 break;
312             default:
313                 std::cout << (char) variant << std::endl;
314                 break;
315         }
316     }
317 
318     void *afterArgsPtrPtr() const {
319         KernelArg *a = arg(argc());
320         return (void *) a;
321     }
322 
323     int *schemaLenPtr() const {
324         int *schemaLenP = (int *) ((char *) afterArgsPtrPtr() /*+ sizeof(void *) */);
325         return schemaLenP;
326     }
327 
328     int schemaLen() const {
329         return *schemaLenPtr();
330     }
331 
332     char *schema() const {
333         int *schemaLenP = ((int *) ((char *) afterArgsPtrPtr() /*+ sizeof(void *)*/) + 1);
334         return (char *) schemaLenP;
335     }
336 
337     explicit ArgSled(ArgArray_s *argArray)
338         : argArray(argArray) {
339     }
340 };
341 
342 
// Simple stopwatch built on gettimeofday(): call start(), then end() to get
// the elapsed time in microseconds.
class Timer {
    struct timeval startTV{};
    struct timeval endTV{};

public:
    // Microseconds measured by the most recent end() call; 0 until then.
    unsigned long elapsed_us{};

    Timer() = default;

    // Records the current time as the start of the interval.
    void start() {
        gettimeofday(&startTV, nullptr);
    }

    // Records the current time as the end of the interval, stores the
    // difference from the start in elapsed_us and returns it.
    unsigned long end() {
        gettimeofday(&endTV, nullptr);
        const auto wholeSeconds = endTV.tv_sec - startTV.tv_sec;
        const auto leftoverMicros = endTV.tv_usec - startTV.tv_usec;
        elapsed_us = wholeSeconds * 1000000; // sec to us
        elapsed_us += leftoverMicros;        // may be negative; the sum is still correct
        return elapsed_us;
    }
};
363 
364 
365 //extern void hexdump(void *ptr, int buflen);
366 
// Namespace-like class holding a single static helper, defined out of line.
class Sled {
public:
    // Writes a description of `argArray` to `out`.
    // NOTE(review): `argArray` is an untyped pointer; presumably it addresses an ArgArray_s — confirm in the definition.
    static void show(std::ostream &out, void *argArray);
};
371 
// Plain data record describing the launch geometry of a kernel. Each group of
// three fields holds the x, y and z components of one quantity.
// NOTE(review): field order and types presumably mirror a layout shared with
// other code — confirm before reordering or retyping anything.
class KernelContext {
public:

    // Dimensions of the kernel (1D, 2D or 3D)
    int dimensions;

    // global index
    int gix;
    int giy;
    int giz;

    // global sizes
    int gsx;
    int gsy;
    int gsz;

    // local index
    int lix;
    int liy;
    int liz;

    // local size
    int lsx;
    int lsy;
    int lsz;

    // Group index
    int bix;
    int biy;
    int biz;

    // Block sizes
    int bsx;
    int bsy;
    int bsz;

    // Tile Size
    int tlx;
    int tly;
    int tlz;

    // Warp sizes
    // NOTE(review): declared bool although described as sizes; every other
    // size in this record is an int — confirm this is intentional.
    bool wsx;
    bool wsy;
    bool wsz;
};
418 
419 class Backend {
420 public:
421     class Config final : public BasicConfig {
422     public:
423         explicit Config(int mode);
424 
425         ~Config() override;
426     };
427 
428     class Buffer {
429     public:
430         Backend *backend;
431         BufferState *bufferState;
432 
433         Buffer(Backend *backend, BufferState *bufferState)
434             : backend(backend), bufferState(bufferState) {
435         }
436 
437         virtual ~Buffer() = default;
438     };
439 
440     class CompilationUnit {
441     public:
442         class Kernel {
443         public:
444             char *name;
445 
446             CompilationUnit *compilationUnit;
447 
448             virtual bool setArg(KernelArg *arg, Buffer *openCLBuffer) = 0;
449 
450             virtual bool setArg(KernelArg *arg) = 0;
451 
452             virtual long ndrange(void *argArray) final;
453 
454             Kernel(CompilationUnit *compilationUnit, char *name)
455                 : name(strutil::clone(name)), compilationUnit(compilationUnit) {
456             }
457 
458             virtual ~Kernel() {
459                 delete[] name;
460             }
461         };
462 
463     public:
464         Backend *backend;
465         char *src;
466         char *log;
467         bool ok;
468 
469         virtual Kernel *getKernel(int nameLen, char *name) = 0;
470 
471         virtual bool compilationUnitOK() final {
472             return ok;
473         }
474 
475         CompilationUnit(Backend *backend, char *src, char *log, bool ok)
476             : backend(backend), src(src), log(log), ok(ok) {
477         }
478 
479         virtual ~CompilationUnit() {
480             delete[] src;
481             delete[] log;
482         };
483     };
484 
485     class Queue {
486     public:
487         Backend *backend;
488 
489         explicit Queue(Backend *backend);
490 
491         virtual void wait() = 0;
492 
493         virtual void release() = 0;
494 
495         virtual void computeStart() = 0;
496 
497         virtual void computeEnd() = 0;
498 
499         virtual void copyToDevice(Buffer *buffer) =0;
500 
501         virtual void copyFromDevice(Buffer *buffer) =0;
502 
503         virtual void dispatch(KernelContext *kernelContext, CompilationUnit::Kernel *kernel) = 0;
504 
505         virtual ~Queue();
506     };
507 
508     class ProfilableQueue : public Queue {
509     public:
510         static constexpr int START_BIT_IDX = 20;
511         static constexpr int CopyToDeviceBits = 1 << START_BIT_IDX;
512         static constexpr int CopyFromDeviceBits = 1 << 21;
513         static constexpr int NDRangeBits = 1 << 22;
514         static constexpr int StartComputeBits = 1 << 23;
515         static constexpr int EndComputeBits = 1 << 24;
516         static constexpr int EnterKernelDispatchBits = 1 << 25;
517         static constexpr int LeaveKernelDispatchBits = 1 << 26;
518         static constexpr int HasConstCharPtrArgBits = 1 << 27;
519         static constexpr int hasIntArgBits = 1 << 28;
520         static constexpr int END_BIT_IDX = 27;
521 
522         size_t eventMax;
523         size_t eventc;
524         int *eventInfoBits;
525         const char **eventInfoConstCharPtrArgs;
526 
527         virtual void showEvents(int width) = 0;
528 
529         virtual void inc(int bits) = 0;
530 
531         virtual void inc(int bits, const char *arg) = 0;
532 
533         virtual void marker(int bits) = 0;
534 
535         virtual void marker(int bits, const char *arg) = 0;
536 
537 
538         virtual void markAsStartComputeAndInc() = 0;
539 
540         virtual void markAsEndComputeAndInc() = 0;
541 
542         virtual void markAsEnterKernelDispatchAndInc() = 0;
543 
544         virtual void markAsLeaveKernelDispatchAndInc() = 0;
545 
546         ProfilableQueue(Backend *backend, int eventMax)
547             : Queue(backend),
548               eventMax(eventMax),
549               eventInfoBits(new int[eventMax]),
550               eventInfoConstCharPtrArgs(new const char *[eventMax]),
551               eventc(0) {
552         }
553 
554         ~ProfilableQueue() override {
555             delete[]eventInfoBits;
556             delete[]eventInfoConstCharPtrArgs;
557         }
558     };
559 
560     Config *config;
561     Queue *queue;
562 
563     Backend(Config *config, Queue *queue)
564         : config(config), queue(queue) {
565     }
566 
567     virtual Buffer *getOrCreateBuffer(BufferState *bufferState) = 0;
568 
569     virtual void shortDeviceInfo() = 0;
570 
571     virtual void showDeviceInfo() = 0;
572 
573     virtual void computeStart() = 0;
574 
575     virtual void computeEnd() = 0;
576 
577     virtual CompilationUnit *compile(int len, char *source) = 0;
578 
579     virtual bool getBufferFromDeviceIfDirty(void *memorySegment, long memorySegmentLength) = 0;
580 
581     virtual ~Backend() = default;
582 };
583 
584 template<typename T>
585 T *bufferOf(const char *name) {
586     size_t lenIncludingBufferState = sizeof(T);
587     size_t lenExcludingBufferState = lenIncludingBufferState - sizeof(BufferState);
588     T *buffer = reinterpret_cast<T *>(new unsigned char[lenIncludingBufferState]);
589     auto *bufferState = reinterpret_cast<BufferState *>(reinterpret_cast<char *>(buffer) + lenExcludingBufferState);
590     bufferState->magic1 = bufferState->magic2 = BufferState::MAGIC;
591     bufferState->ptr = buffer;
592     bufferState->length = sizeof(T) - sizeof(BufferState);
593     bufferState->state = BufferState::NEW_STATE;
594     bufferState->vendorPtr = nullptr;
595     bufferState->dump(name);
596     return buffer;
597 }