1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 #pragma once
 26 #define CUDA_TYPES
 27 #ifdef __APPLE__
 28 
 29 #define LongUnsignedNewline "%llu\n"
 30 #define Size_tNewline "%lu\n"
 31 #define LongHexNewline "(0x%llx)\n"
 32 #define alignedMalloc(size, alignment) memalign(alignment, size)
 33 #define SNPRINTF snprintf
 34 #else
 35 
 36 #include <malloc.h>
 37 
 38 #define LongHexNewline "(0x%lx)\n"
 39 #define LongUnsignedNewline "%lu\n"
 40 #define Size_tNewline "%lu\n"
 41 #if defined (_WIN32)
 42 #include "windows.h"
 43 #define alignedMalloc(size, alignment) _aligned_malloc(size, alignment)
 44 #define SNPRINTF _snprintf
 45 #else
 46 #define alignedMalloc(size, alignment) memalign(alignment, size)
 47 #define SNPRINTF  snprintf
 48 #endif
 49 #endif
 50 
 51 #include <iostream>
 52 #include <cuda.h>
 53 #include <builtin_types.h>
 54 //#include <cuda_runtime_api.h>
 55 #define CUDA_TYPES
 56 
 57 #include "shared.h"
 58 
 59 #include <fstream>
 60 
 61 #include<vector>
 62 #include <thread>
 63 
 64 struct WHERE{
 65     const char* f;
 66     int l;
 67     cudaError_enum e;
 68     const char* t;
 69     void report() const{
 70         if (e == CUDA_SUCCESS){
 71            // std::cout << t << "  OK at " << f << " line " << l << std::endl;
 72         }else {
 73             const char *buf;
 74             cuGetErrorName(e, &buf);
 75             std::cerr << t << " CUDA error = " << e << " " << buf <<std::endl<< "      " << f << " line " << l << std::endl;
 76             exit(-1);
 77         }
 78     }
 79 };
 80 class PtxSource: public Text  {
 81 public:
 82     PtxSource();
 83     PtxSource(size_t len);
 84     PtxSource(size_t len, char *text);
 85     PtxSource(char *text);
 86     ~PtxSource() = default;
 87     static PtxSource *nvcc(const char *cudaSource, size_t len);
 88 };
 89 class CudaSource:public Text  {
 90 public:
 91     CudaSource(size_t len, char *text, bool isCopy);
 92     CudaSource(size_t len);
 93     CudaSource(char* text);
 94     CudaSource();
 95     ~CudaSource() = default;
 96 };
 97 
 98 class CudaBackend : public Backend {
 99 public:
100 class CudaQueue: public Backend::Queue {
101     public:
102          std::thread::id streamCreationThread;
103         CUstream cuStream;
104         CudaQueue(Backend *backend);
105         void init();
106          void wait() override;
107 
108          void release() override;
109 
110          void computeStart() override;
111 
112          void computeEnd() override;
113 
114          void copyToDevice(Buffer *buffer) override;
115 
116          void copyFromDevice(Buffer *buffer) override;
117 
118         virtual void dispatch(KernelContext *kernelContext, CompilationUnit::Kernel *kernel) override;
119 
120         virtual ~CudaQueue();
121 
122 };
123 
124     class CudaBuffer : public Backend::Buffer {
125     public:
126         CUdeviceptr devicePtr;
127         CudaBuffer(Backend *backend, BufferState *bufferState);
128         virtual ~CudaBuffer();
129     };
130 
131     class CudaModule : public Backend::CompilationUnit {
132 
133     private:
134         CUmodule module;
135         CudaSource cudaSource;
136         PtxSource ptxSource;
137         Log log;
138 
139     public:
140         class CudaKernel : public Backend::CompilationUnit::Kernel {
141 
142         public:
143             bool setArg(KernelArg *arg) override;
144             bool setArg(KernelArg *arg, Buffer *buffer) override;
145             CudaKernel(Backend::CompilationUnit *program, char* name, CUfunction function);
146             ~CudaKernel() override;
147             static CudaKernel * of(long kernelHandle);
148             static CudaKernel * of(Backend::CompilationUnit::Kernel *kernel);
149 
150             CUfunction function;
151             void *argslist[100];
152         };
153         CudaModule(Backend *backend, char *cudaSrc,   char *log, bool ok, CUmodule module);
154         ~CudaModule();
155         static CudaModule * of(long moduleHandle);
156         static CudaModule * of(Backend::CompilationUnit *compilationUnit);
157         Kernel *getKernel(int nameLen, char *name);
158         CudaKernel *getCudaKernel(char *name);
159         CudaKernel *getCudaKernel(int nameLen, char *name);
160         bool programOK();
161     };
162 
163 private:
164     CUresult initStatus;
165     CUdevice device;
166     CUcontext context;
167 public:
168     void info();
169     CudaModule * compile(CudaSource *cudaSource);
170     CudaModule * compile(CudaSource &cudaSource);
171     PtxSource *nvcc(CudaSource *cudaSource);
172     Backend::CompilationUnit * compile(int len, char *source) override;
173     void computeStart() override;
174     void computeEnd() override;
175     CudaBuffer * getOrCreateBuffer(BufferState *bufferState) override;
176     bool getBufferFromDeviceIfDirty(void *memorySegment, long memorySegmentLength) override;
177 
178     CudaBackend(int mode);
179     CudaBackend();
180 
181     ~CudaBackend();
182     static CudaBackend * of(long backendHandle);
183     static CudaBackend * of(Backend *backend);
184 };
185 
186