1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 #pragma once
 26 #define CUDA_TYPES
 27 #ifdef __APPLE__
 28 
 29 #define LongUnsignedNewline "%llu\n"
 30 #define Size_tNewline "%lu\n"
 31 #define LongHexNewline "(0x%llx)\n"
 32 #define alignedMalloc(size, alignment) memalign(alignment, size)
 33 #define SNPRINTF snprintf
 34 #else
 35 
 36 #include <malloc.h>
 37 
 38 #define LongHexNewline "(0x%lx)\n"
 39 #define LongUnsignedNewline "%lu\n"
 40 #define Size_tNewline "%lu\n"
 41 #if defined (_WIN32)
 42 #include "windows.h"
 43 #define alignedMalloc(size, alignment) _aligned_malloc(size, alignment)
 44 #define SNPRINTF _snprintf
 45 #else
 46 #define alignedMalloc(size, alignment) memalign(alignment, size)
 47 #define SNPRINTF  snprintf
 48 #endif
 49 #endif
 50 
 51 #include <iostream>
 52 #include <cuda.h>
 53 #include <builtin_types.h>
 54 
 55 #include "shared.h"
 56 
 57 #include <fstream>
 58 #include <thread>
 59 
 60 struct WHERE{
 61     const char* f;
 62     int l;
 63     cudaError_enum e;
 64     const char* t;
 65     void report() const {
 66         if (e != CUDA_SUCCESS){
 67             const char *buf;
 68             cuGetErrorName(e, &buf);
 69             std::cerr << t << " CUDA error = " << e << " " << buf <<std::endl<< "      " << f << " line " << l << std::endl;
 70             exit(-1);
 71         }
 72     }
 73 };
 74 
 75 #define CUDA_CHECK(err, functionName) { \
 76     WHERE{.f =__FILE__, \
 77           .l=__LINE__, \
 78           .e = err, \
 79           .t = functionName \
 80          }.report(); \
 81 }
 82 
 83 class PtxSource final : public Text  {
 84 public:
 85     PtxSource();
 86     explicit PtxSource(size_t len);
 87     PtxSource(size_t len, char *text);
 88     PtxSource(size_t len, char *text, bool isCopy);
 89     explicit PtxSource(char *text);
 90     ~PtxSource() override = default;
 91 };
 92 
 93 class CudaSource final :public Text  {
 94 public:
 95     CudaSource(size_t len, char *text, bool isCopy, bool lineinfo);
 96     bool lineInfo() const;
 97     explicit CudaSource(size_t len);
 98     explicit CudaSource(char* text);
 99     CudaSource();
100     ~CudaSource() override = default;
101 private:
102     bool _lineInfo = false;
103 };
104 
105 class CudaBackend final : public Backend {
106 public:
107 class CudaQueue final : public Backend::Queue {
108     public:
109         std::thread::id streamCreationThread;
110         CUstream cuStream;
111         explicit CudaQueue(Backend *backend);
112         void init();
113         void wait() override;
114 
115          void release() override;
116 
117          void computeStart() override;
118 
119          void computeEnd() override;
120 
121          void copyToDevice(Buffer *buffer) override;
122 
123          void copyFromDevice(Buffer *buffer) override;
124 
125         int estimateThreadsPerBlock(int dimensions);
126 
127         int estimateThreadsPerBlock(int dimensions, int globalSizePerDimension, int localSize);
128 
129         void dispatch(KernelContext *kernelContext, CompilationUnit::Kernel *kernel) override;
130 
131         ~CudaQueue() override;
132 };
133 
134     class CudaBuffer final : public Buffer {
135     public:
136         CUdeviceptr devicePtr;
137         CudaBuffer(Backend *backend, BufferState *bufferState);
138         ~CudaBuffer() override;
139     };
140 
141     class CudaModule final : public CompilationUnit {
142         CUmodule module;
143         CudaSource cudaSource;
144         PtxSource ptxSource;
145         Log log;
146 
147     public:
148         class CudaKernel final : public Kernel {
149 
150         public:
151             bool setArg(KernelArg *arg) override;
152             bool setArg(KernelArg *arg, Buffer *buffer) override;
153             CudaKernel(Backend::CompilationUnit *program, char* name, CUfunction function);
154             ~CudaKernel() override;
155             static CudaKernel * of(long kernelHandle);
156             static CudaKernel * of(Backend::CompilationUnit::Kernel *kernel);
157 
158             CUfunction function;
159             void *argslist[100]{};
160         };
161         CudaModule(Backend *backend, char *cudaSrc,   char *log, bool ok, CUmodule module);
162         ~CudaModule() override;
163         static CudaModule * of(long moduleHandle);
164         //static CudaModule * of(CompilationUnit *compilationUnit);
165         Kernel *getKernel(int nameLen, char *name) override;
166         CudaKernel *getCudaKernel(char *name);
167         CudaKernel *getCudaKernel(int nameLen, char *name);
168         bool programOK();
169     };
170 
171 private:
172     CUresult initStatus;
173     CUdevice device;
174     CUcontext context;
175 public:
176     void showDeviceInfo() override;
177     CudaModule * compile(const CudaSource *cudaSource);
178     CudaModule * compile(const CudaSource &cudaSource);
179     CudaModule * compile(const PtxSource *ptxSource);
180     CudaModule * compile(const PtxSource &ptxSource);
181     static PtxSource *nvcc(const CudaSource *cudaSource);
182     CompilationUnit * compile(int len, char *source) override;
183     void computeStart() override;
184     void computeEnd() override;
185     CudaBuffer * getOrCreateBuffer(BufferState *bufferState) override;
186     bool getBufferFromDeviceIfDirty(void *memorySegment, long memorySegmentLength) override;
187 
188     explicit CudaBackend(int mode);
189 
190     ~CudaBackend() override;
191     static CudaBackend * of(long backendHandle);
192     static CudaBackend * of(Backend *backend);
193 };