1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 
 26 #pragma once
 27 
#include <unistd.h>
#include <sys/time.h>

#include <bitset>
#include <chrono>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <iostream>
#include <map>
#include <stack>
#include <string>
#include <vector>

#include "strutil.h"
#include "config.h"
 42 
// Platform shim.
// SNPRINTF expands to the platform's spelling of snprintf: _snprintf when _WIN32 is defined,
// snprintf everywhere else (including __APPLE__).
// <malloc.h> is included on every non-Apple platform; "windows.h" only under _WIN32.
// NOTE(review): MSVC's _snprintf does not NUL-terminate on truncation, unlike snprintf.
// NOTE(review): <unistd.h>/<sys/time.h> are included unconditionally above, so a _WIN32
// build presumably needs a POSIX layer — confirm.
#ifdef __APPLE__
#define SNPRINTF snprintf
#else
#include <malloc.h>
#if defined (_WIN32)
#include "windows.h"
#define SNPRINTF _snprintf
#else
#define SNPRINTF  snprintf
#endif
#endif
 54 
 55 typedef char s8_t;
 56 typedef char byte;
 57 typedef char boolean;
 58 typedef char z1_t;
 59 typedef unsigned char u8_t;
 60 typedef short s16_t;
 61 typedef unsigned short u16_t;
 62 typedef unsigned int u32_t;
 63 typedef int s32_t;
 64 typedef float f32_t;
 65 typedef double f64_t;
 66 typedef long s64_t;
 67 typedef unsigned long u64_t;
 68 
 69 extern void hexdump(void *ptr, int buflen);
 70 
/**
 * A char buffer of `len` chars that can be written to and read from a file.
 *
 * All members except the destructor's `virtual` are defined out of line; the behavior
 * summaries below are inferred from names and should be confirmed in the .cpp.
 *
 * Members:
 *  - len:    number of chars in `text`
 *  - text:   the buffer itself
 *  - isCopy: presumably records whether `text` is a private copy that ~Text releases — TODO confirm
 *
 * NOTE(review): the class holds a raw char* and has a user-declared virtual destructor but
 * declares no copy/move operations. If ~Text frees `text`, copying a Text (or Log) would
 * double-free it.
 */
class Text {
public:
    size_t len;
    char *text;
    bool isCopy;

    // Presumably wraps or copies `len` chars at `text`, depending on `isCopy` — TODO confirm.
    Text(size_t len, char *text, bool isCopy);

    // As above, with the length presumably taken from the C string — TODO confirm.
    Text(char *text, bool isCopy);

    // Presumably allocates a buffer of `len` chars — TODO confirm.
    explicit Text(size_t len);

    // Writes the buffer to `filename`.
    void write(const std::string &filename) const;

    // Reads `filename` into this object.
    void read(const std::string &filename);

    virtual ~Text();
};
 89 
// Text subclass used for logs. Both constructors are defined out of line; the destructor is
// defaulted, so all cleanup happens in ~Text.
class Log : public Text {
public:
    // Presumably creates an empty log buffer of `len` chars — confirm in the .cpp.
    explicit Log(size_t len);

    // Presumably wraps or copies the given C string — confirm in the .cpp.
    explicit Log(char *text);

    ~Log() override = default;
};
 98 
 99 #define UNKNOWN_BYTE 0
100 #define RO_BYTE (1<<1)
101 #define WO_BYTE (1<<2)
102 #define RW_BYTE (RO_BYTE|WO_BYTE)
103 
/**
 * Payload of a '&' (Buffer / MemorySegment) kernel argument, stored in Value_u::buffer.
 *
 * NOTE(review): presumably laid out to match the Java-side description in
 * hat/buffer/ArgArray.java — do not reorder or resize fields without updating both sides.
 */
struct Buffer_s {
    void *memorySegment; // Address of a Buffer/MemorySegment
    long sizeInBytes;    // Size of the memory segment in bytes, including the BufferState trailer at its end (see BufferState::of)
    u8_t access;         // Bitmask, see hat/buffer/ArgArray.java: UNKNOWN_BYTE=0, RO_BYTE=1<<1, WO_BYTE=1<<2, RW_BYTE=RO_BYTE|WO_BYTE
};
109 
/**
 * Value slot of a KernelArg.
 *
 * The active member is selected by KernelArg::variant, using the descriptor character noted
 * on each member. Readers must use the member that matches the variant; reading a wider
 * member would pick up unrelated bytes.
 */
union Value_u {
    boolean z1; // 'Z'
    u8_t s8; // 'B'   NOTE(review): declared unsigned despite the "s8" name
    u16_t u16; // 'C'
    s16_t s16; // 'S'
    u16_t x16; // 'C' or 'S'   (original note: this is never used)
    s32_t s32; // 'I'
    s32_t x32; // 'I' or 'F'   (original note: this is never used)
    f32_t f32; // 'F'
    f64_t f64; // 'D'
    s64_t s64; // 'J'
    s64_t x64; // 'D' or 'J'   (original note: this is never used)
    Buffer_s buffer; // '&' (Buffer / MemorySegment)
};
124 
125 struct KernelArg {
126     u32_t idx; // 0..argc
127     u8_t variant; // which variant 'I','Z','S','J','F', '&' implies Buffer/MemorySegment
128     u8_t pad8[8];
129     Value_u value;
130     u8_t pad6[6];
131 
132     size_t size() const {
133         size_t sz;
134         switch (variant) {
135             case 'I':
136             case 'F':
137                 sz = sizeof(u32_t);
138                 break;
139             case 'S':
140             case 'C':
141                 sz = sizeof(u16_t);
142                 break;
143             case 'D':
144             case 'J':
145                 return sizeof(u64_t);
146             case 'B':
147                 return sizeof(u8_t);
148             default:
149                 std::cerr << "Bad variant " << variant << "arg::size" << std::endl;
150                 exit(1);
151         }
152         return sz;
153     }
154 };
155 
156 struct BufferState {
157     static constexpr long MAGIC = 0x4a71facebffab175;   // This magic number is a delimiter to
158                                                         // check the length of the buffer as follows:
159                                                         // *(bufferStart+(bufferLen - sizeof(bufferState)) == MAGIC
160     static constexpr int NO_STATE = 0;
161     static constexpr int NEW_STATE = 1;
162     static constexpr int HOST_OWNED = 2;
163     static constexpr int DEVICE_OWNED = 3;
164     static constexpr int DEVICE_VALID_HOST_HAS_COPY = 4;
165     const static char *stateNames[]; // See below for out of line definition
166 
167     long magic1;
168     void *ptr;
169     long length;
170     int bits;
171     mutable int state;
172     void *vendorPtr;
173     long magic2;
174 
175     bool ok() const {
176         return ((magic1 == MAGIC) && (magic2 == MAGIC));
177     }
178 
179     void setState(int newState) {
180         state = newState;
181     }
182 
183     int getState() const {
184         return state;
185     }
186 
187     void dump(const char *msg) const {
188         if (ok()) {
189             printf("{%s,ptr:%016lx,length: %016lx,  state:%08x, vendorPtr:%016lx}\n", msg, (long) ptr, length, state,
190                    (long) vendorPtr);
191         } else {
192             printf("%s bad magic \n", msg);
193             printf("(magic1:%016lx,", magic1);
194             printf("{%s, ptr:%016lx, length: %016lx,  state:%08x, vendorPtr:%016lx}", msg, (long) ptr, length, state,
195                    (long) vendorPtr);
196             printf("magic2:%016lx)\n", magic2);
197         }
198     }
199 
200     static BufferState *of(void *ptr, size_t sizeInBytes) {
201         return reinterpret_cast<BufferState *>(static_cast<char *>(ptr) + sizeInBytes - sizeof(BufferState));
202     }
203 
204     static BufferState *of(const KernelArg *arg) {
205         // access?
206         BufferState *bufferState = of(
207             arg->value.buffer.memorySegment,
208             arg->value.buffer.sizeInBytes
209         );
210 
211         // Sanity check the buffers
212         // These sanity check finds errors passing memory segments which are not Buffers
213         if (bufferState->ptr != arg->value.buffer.memorySegment) {
214              std::cerr << "Error:  Unexpected initial state for buffer "
215                             << " idx=" << arg->idx
216                             << " bufferState->ptr=0x"<<std::hex<<((long)bufferState->ptr)<<std::dec
217                               << " bufferState->length=0x"<<std::hex<<((long)bufferState->length)<<std::dec
218                             << " arg->value.buffer.memorySegment=0x"<<std::hex<<((long)arg->value.buffer.memorySegment)<<std::dec
219                             << " state=" << bufferState->state << " '"
220                             << stateNames[bufferState->state] << "'"
221                             << " vendorPtr" << bufferState->vendorPtr << std::endl;
222             std::cerr << "The ptr (bufferState->ptr) does not appear to be a arg->value.buffer.memorySegment" << std::endl;
223 
224             // This is A bit brutal to stop the VM? We can throw an exception and handle it in the Java side?
225             std::exit(1);
226         }
227 
228         if ((bufferState->vendorPtr == nullptr) && (bufferState->state != NEW_STATE)) {
229             std::cerr << "Warning:  Unexpected initial state for buffer "
230                     << " idx=" << arg->idx
231                     << " state=" << bufferState->state << " '"
232                     << stateNames[bufferState->state] << "'"
233                     << " vendorPtr" << bufferState->vendorPtr << std::endl;
234             // This is A bit brutal to stop the VM? We can throw an exception and handle it in the Java side?
235             //std::exit(1);
236         }
237         // End of sanity checks
238         return bufferState;
239     }
240 };
241 
// Out-of-line definition of BufferState::stateNames, compiled only where shared_cpp is
// defined (presumably exactly one .cpp, giving the array a single definition).
// Entry i is the name of the BufferState state constant whose value is i, so the order here
// must match NO_STATE(0) .. DEVICE_VALID_HOST_HAS_COPY(4).
#ifdef shared_cpp
const char *BufferState::stateNames[] = {
    "NO_STATE",
    "NEW_STATE",
    "HOST_OWNED",
    "DEVICE_OWNED",
    "DEVICE_VALID_HOST_HAS_COPY"
};
#endif
251 
/**
 * Header of a caller-supplied argument array: `argc`, 12 bytes of explicit padding, then
 * `argc` KernelArg records. ArgSled reads an int schema length, followed by the schema
 * characters, directly after the last record.
 *
 * NOTE(review): `argv[0]` is a zero-length array (a GNU extension, not standard C++) used as
 * a flexible array member, so sizeof(ArgArray_s) excludes the records.
 */
struct ArgArray_s {
    u32_t argc;
    u8_t pad12[12];
    KernelArg argv[0/*argc*/];
};
257 
258 class ArgSled {
259 private:
260     ArgArray_s *argArray;
261 
262 public:
263     int argc() const {
264         return argArray->argc;
265     }
266 
267     KernelArg *arg(int n) const {
268         KernelArg *a = (argArray->argv + n);
269         return a;
270     }
271 
272     void hexdumpArg(int n) const {
273         hexdump(arg(n), sizeof(KernelArg));
274     }
275 
276     void dumpArg(int n) const {
277         KernelArg *a = arg(n);
278         int idx = (int) a->idx;
279         std::cout << "arg[" << idx << "]";
280         char variant = (char) a->variant;
281         switch (variant) {
282             case 'F':
283                 std::cout << " f32 " << a->value.f32 << std::endl;
284                 break;
285             case 'I':
286                 std::cout << " s32 " << a->value.s32 << std::endl;
287                 break;
288             case 'D':
289                 std::cout << " f64 " << a->value.f64 << std::endl;
290                 break;
291             case 'J':
292                 std::cout << " s64 " << a->value.s64 << std::endl;
293                 break;
294             case 'C':
295                 std::cout << " u16 " << a->value.u16 << std::endl;
296                 break;
297             case 'S':
298                 std::cout << " s16 " << a->value.s32 << std::endl;
299                 break;
300             case 'Z':
301                 std::cout << " z1 " << a->value.z1 << std::endl;
302                 break;
303             case '&':
304                 std::cout << " buffer {"
305                         << " void *address = 0x" << std::hex << (long) a->value.buffer.memorySegment << std::dec
306                         << ", long bytesSize= 0x" << std::hex << (long) a->value.buffer.sizeInBytes << std::dec
307                         << ", char access= 0x" << std::hex << (unsigned char) a->value.buffer.access << std::dec
308                         << "}" << std::endl;
309                 break;
310             default:
311                 std::cout << (char) variant << std::endl;
312                 break;
313         }
314     }
315 
316     void *afterArgsPtrPtr() const {
317         KernelArg *a = arg(argc());
318         return (void *) a;
319     }
320 
321     int *schemaLenPtr() const {
322         int *schemaLenP = (int *) ((char *) afterArgsPtrPtr() /*+ sizeof(void *) */);
323         return schemaLenP;
324     }
325 
326     int schemaLen() const {
327         return *schemaLenPtr();
328     }
329 
330     char *schema() const {
331         int *schemaLenP = ((int *) ((char *) afterArgsPtrPtr() /*+ sizeof(void *)*/) + 1);
332         return (char *) schemaLenP;
333     }
334 
335     explicit ArgSled(ArgArray_s *argArray)
336         : argArray(argArray) {
337     }
338 };
339 
340 
/**
 * Simple stopwatch with microsecond resolution.
 *
 * Call start(), then end(). end() stores the microseconds elapsed since start() in
 * `elapsed_us` and also returns that value.
 *
 * Uses std::chrono::steady_clock, which is monotonic. The previous gettimeofday() version
 * followed the wall clock, so a clock adjustment between start() and end() could produce a
 * negative interval that wrapped to a huge unsigned value.
 */
class Timer {
    std::chrono::steady_clock::time_point startTP{};
    std::chrono::steady_clock::time_point endTP{};

public:
    unsigned long elapsed_us{};

    Timer() = default;

    void start() {
        startTP = std::chrono::steady_clock::now();
    }

    unsigned long end() {
        endTP = std::chrono::steady_clock::now();
        elapsed_us = static_cast<unsigned long>(
            std::chrono::duration_cast<std::chrono::microseconds>(endTP - startTP).count());
        return elapsed_us;
    }
};
361 
362 
363 //extern void hexdump(void *ptr, int buflen);
364 
// Holder for show(), which is declared here and defined out of line. show() presumably writes
// a human-readable description of `argArray` (presumably an ArgArray_s) to `out` — confirm.
class Sled {
public:
    static void show(std::ostream &out, void *argArray);
};
369 
/**
 * Index and size information for a kernel dispatch (passed to Backend::Queue::dispatch),
 * with x/y/z components.
 *
 * Field naming: <g|l|b><i|s><x|y|z> = global / local / block (group), index / size, axis.
 *
 * NOTE(review): plain ints with no constructor or initializers, so callers must set every
 * field. The layout is presumably mirrored by generated kernel code — confirm before
 * reordering fields.
 */
class KernelContext {
public:

    // Dimensions of the kernel (1D, 2D or 3D)
    int dimensions;

    // global index
    int gix;
    int giy;
    int giz;

    // global sizes
    int gsx;
    int gsy;
    int gsz;

    // local index
    int lix;
    int liy;
    int liz;

    // local size
    int lsx;
    int lsy;
    int lsz;

    // Group index
    int bix;
    int biy;
    int biz;

    // Block sizes
    // NOTE(review): since lsx..lsz already hold the local size, these may be the number of
    // groups per axis — confirm with the backends that fill them.
    int bsx;
    int bsy;
    int bsz;
};
406 
407 class Backend {
408 public:
409     class Config final : public BasicConfig {
410     public:
411         explicit Config(int mode);
412 
413         ~Config() override;
414     };
415 
416     class Buffer {
417     public:
418         Backend *backend;
419         BufferState *bufferState;
420 
421         Buffer(Backend *backend, BufferState *bufferState)
422             : backend(backend), bufferState(bufferState) {
423         }
424 
425         virtual ~Buffer() = default;
426     };
427 
428     class CompilationUnit {
429     public:
430         class Kernel {
431         public:
432             char *name;
433 
434             CompilationUnit *compilationUnit;
435 
436             virtual bool setArg(KernelArg *arg, Buffer *openCLBuffer) = 0;
437 
438             virtual bool setArg(KernelArg *arg) = 0;
439 
440             virtual long ndrange(void *argArray) final;
441 
442             Kernel(CompilationUnit *compilationUnit, char *name)
443                 : name(strutil::clone(name)), compilationUnit(compilationUnit) {
444             }
445 
446             virtual ~Kernel() {
447                 delete[] name;
448             }
449         };
450 
451     public:
452         Backend *backend;
453         char *src;
454         char *log;
455         bool ok;
456 
457         virtual Kernel *getKernel(int nameLen, char *name) = 0;
458 
459         virtual bool compilationUnitOK() final {
460             return ok;
461         }
462 
463         CompilationUnit(Backend *backend, char *src, char *log, bool ok)
464             : backend(backend), src(src), log(log), ok(ok) {
465         }
466 
467         virtual ~CompilationUnit() {
468             delete[] src;
469             delete[] log;
470         };
471     };
472 
473     class Queue {
474     public:
475         Backend *backend;
476 
477         explicit Queue(Backend *backend);
478 
479         virtual void wait() = 0;
480 
481         virtual void release() = 0;
482 
483         virtual void computeStart() = 0;
484 
485         virtual void computeEnd() = 0;
486 
487         virtual void copyToDevice(Buffer *buffer) =0;
488 
489         virtual void copyFromDevice(Buffer *buffer) =0;
490 
491         virtual void dispatch(KernelContext *kernelContext, CompilationUnit::Kernel *kernel) = 0;
492 
493         virtual ~Queue();
494     };
495 
496     class ProfilableQueue : public Queue {
497     public:
498         static constexpr int START_BIT_IDX = 20;
499         static constexpr int CopyToDeviceBits = 1 << START_BIT_IDX;
500         static constexpr int CopyFromDeviceBits = 1 << 21;
501         static constexpr int NDRangeBits = 1 << 22;
502         static constexpr int StartComputeBits = 1 << 23;
503         static constexpr int EndComputeBits = 1 << 24;
504         static constexpr int EnterKernelDispatchBits = 1 << 25;
505         static constexpr int LeaveKernelDispatchBits = 1 << 26;
506         static constexpr int HasConstCharPtrArgBits = 1 << 27;
507         static constexpr int hasIntArgBits = 1 << 28;
508         static constexpr int END_BIT_IDX = 27;
509 
510         size_t eventMax;
511         size_t eventc;
512         int *eventInfoBits;
513         const char **eventInfoConstCharPtrArgs;
514 
515         virtual void showEvents(int width) = 0;
516 
517         virtual void inc(int bits) = 0;
518 
519         virtual void inc(int bits, const char *arg) = 0;
520 
521         virtual void marker(int bits) = 0;
522 
523         virtual void marker(int bits, const char *arg) = 0;
524 
525 
526         virtual void markAsStartComputeAndInc() = 0;
527 
528         virtual void markAsEndComputeAndInc() = 0;
529 
530         virtual void markAsEnterKernelDispatchAndInc() = 0;
531 
532         virtual void markAsLeaveKernelDispatchAndInc() = 0;
533 
534         ProfilableQueue(Backend *backend, int eventMax)
535             : Queue(backend),
536               eventMax(eventMax),
537               eventInfoBits(new int[eventMax]),
538               eventInfoConstCharPtrArgs(new const char *[eventMax]),
539               eventc(0) {
540         }
541 
542         ~ProfilableQueue() override {
543             delete[]eventInfoBits;
544             delete[]eventInfoConstCharPtrArgs;
545         }
546     };
547 
548     Config *config;
549     Queue *queue;
550 
551     Backend(Config *config, Queue *queue)
552         : config(config), queue(queue) {
553     }
554 
555     virtual Buffer *getOrCreateBuffer(BufferState *bufferState) = 0;
556 
557     virtual void showDeviceInfo() = 0;
558 
559     virtual void computeStart() = 0;
560 
561     virtual void computeEnd() = 0;
562 
563     virtual CompilationUnit *compile(int len, char *source) = 0;
564 
565     virtual bool getBufferFromDeviceIfDirty(void *memorySegment, long memorySegmentLength) = 0;
566 
567     virtual ~Backend() = default;
568 };
569 
570 template<typename T>
571 T *bufferOf(const char *name) {
572     size_t lenIncludingBufferState = sizeof(T);
573     size_t lenExcludingBufferState = lenIncludingBufferState - sizeof(BufferState);
574     T *buffer = reinterpret_cast<T *>(new unsigned char[lenIncludingBufferState]);
575     auto *bufferState = reinterpret_cast<BufferState *>(reinterpret_cast<char *>(buffer) + lenExcludingBufferState);
576     bufferState->magic1 = bufferState->magic2 = BufferState::MAGIC;
577     bufferState->ptr = buffer;
578     bufferState->length = sizeof(T) - sizeof(BufferState);
579     bufferState->state = BufferState::NEW_STATE;
580     bufferState->vendorPtr = nullptr;
581     bufferState->dump(name);
582     return buffer;
583 }