1 /* 2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. Oracle designates this 8 * particular file as subject to the "Classpath" exception as provided 9 * by Oracle in the LICENSE file that accompanied this code. 10 * 11 * This code is distributed in the hope that it will be useful, but WITHOUT 12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 14 * version 2 for more details (a copy is included in the LICENSE file that 15 * accompanied this code). 16 * 17 * You should have received a copy of the GNU General Public License version 18 * 2 along with this work; if not, write to the Free Software Foundation, 19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 20 * 21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 22 * or visit www.oracle.com if you need additional information or have any 23 * questions. 24 */ 25 #pragma once 26 #define CUDA_TYPES 27 #ifdef __APPLE__ 28 29 #define LongUnsignedNewline "%llu\n" 30 #define Size_tNewline "%lu\n" 31 #define LongHexNewline "(0x%llx)\n" 32 #define alignedMalloc(size, alignment) memalign(alignment, size) 33 #define SNPRINTF snprintf 34 #else 35 36 #include <malloc.h> 37 38 #define LongHexNewline "(0x%lx)\n" 39 #define LongUnsignedNewline "%lu\n" 40 #define Size_tNewline "%lu\n" 41 #if defined (_WIN32) 42 #include "windows.h" 43 #define alignedMalloc(size, alignment) _aligned_malloc(size, alignment) 44 #define SNPRINTF _snprintf 45 #else 46 #define alignedMalloc(size, alignment) memalign(alignment, size) 47 #define SNPRINTF snprintf 48 #endif 49 #endif 50 51 #include <iostream> 52 #include <cuda.h> 53 #include <builtin_types.h> 54 55 #include "shared.h" 56 57 #include <fstream> 58 #include <thread> 59 60 struct WHERE{ 61 const char* f; 62 int l; 63 cudaError_enum e; 64 const char* t; 65 void report() const{ 66 if (e == CUDA_SUCCESS){ 67 // std::cout << t << " OK at " << f << " line " << l << std::endl; 68 }else { 69 const char *buf; 70 cuGetErrorName(e, &buf); 71 std::cerr << t << " CUDA error = " << e << " " << buf <<std::endl<< " " << f << " line " << l << std::endl; 72 exit(-1); 73 } 74 } 75 }; 76 77 class PtxSource final : public Text { 78 public: 79 PtxSource(); 80 explicit PtxSource(size_t len); 81 PtxSource(size_t len, char *text); 82 PtxSource(size_t len, char *text, bool isCopy); 83 explicit PtxSource(char *text); 84 ~PtxSource() override = default; 85 }; 86 87 class CudaSource final :public Text { 88 public: 89 CudaSource(size_t len, char *text, bool isCopy); 90 explicit CudaSource(size_t len); 91 explicit CudaSource(char* text); 92 CudaSource(); 93 ~CudaSource() override = default; 94 }; 95 96 class CudaBackend final : public Backend { 97 public: 98 class CudaQueue final : public Backend::Queue { 99 public: 100 std::thread::id streamCreationThread; 101 CUstream cuStream; 102 explicit CudaQueue(Backend *backend); 103 void init(); 104 void wait() override; 105 106 void release() override; 107 108 void computeStart() override; 109 110 void computeEnd() override; 111 112 void copyToDevice(Buffer *buffer) override; 113 114 void copyFromDevice(Buffer *buffer) override; 115 116 void dispatch(KernelContext *kernelContext, CompilationUnit::Kernel *kernel) override; 117 118 ~CudaQueue() override; 119 120 }; 121 122 class CudaBuffer final : public Buffer { 123 public: 124 CUdeviceptr devicePtr; 125 CudaBuffer(Backend *backend, BufferState *bufferState); 126 ~CudaBuffer() override; 127 }; 128 129 class CudaModule final : public CompilationUnit { 130 CUmodule module; 131 CudaSource cudaSource; 132 PtxSource ptxSource; 133 Log log; 134 135 public: 136 class CudaKernel final : public Kernel { 137 138 public: 139 bool setArg(KernelArg *arg) override; 140 bool setArg(KernelArg *arg, Buffer *buffer) override; 141 CudaKernel(Backend::CompilationUnit *program, char* name, CUfunction function); 142 ~CudaKernel() override; 143 static CudaKernel * of(long kernelHandle); 144 static CudaKernel * of(Backend::CompilationUnit::Kernel *kernel); 145 146 CUfunction function; 147 void *argslist[100]{}; 148 }; 149 CudaModule(Backend *backend, char *cudaSrc, char *log, bool ok, CUmodule module); 150 ~CudaModule() override; 151 static CudaModule * of(long moduleHandle); 152 //static CudaModule * of(CompilationUnit *compilationUnit); 153 Kernel *getKernel(int nameLen, char *name) override; 154 CudaKernel *getCudaKernel(char *name); 155 CudaKernel *getCudaKernel(int nameLen, char *name); 156 bool programOK(); 157 }; 158 159 private: 160 CUresult initStatus; 161 CUdevice device; 162 CUcontext context; 163 public: 164 void info() override; 165 CudaModule * compile(const CudaSource *cudaSource); 166 CudaModule * compile(const CudaSource &cudaSource); 167 CudaModule * compile(const PtxSource *ptxSource); 168 CudaModule * compile(const PtxSource &ptxSource); 169 static PtxSource *nvcc(const CudaSource *cudaSource); 170 CompilationUnit * compile(int len, char *source) override; 171 void computeStart() override; 172 void computeEnd() override; 173 CudaBuffer * getOrCreateBuffer(BufferState *bufferState) override; 174 bool getBufferFromDeviceIfDirty(void *memorySegment, long memorySegmentLength) override; 175 176 explicit CudaBackend(int mode); 177 178 ~CudaBackend() override; 179 static CudaBackend * of(long backendHandle); 180 static CudaBackend * of(Backend *backend); 181 }; 182 183