1 /*
2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
#pragma once
#define CUDA_TYPES

// Platform portability shims for the CUDA backend:
//   LongUnsignedNewline, LongHexNewline, Size_tNewline - printf-style formats ending in '\n'
//   alignedMalloc(size, alignment)                     - aligned heap allocation
//   SNPRINTF                                           - the platform's snprintf spelling
// Branch precedence matches the original nesting: __APPLE__ first, then _WIN32, then the rest.
#if defined(__APPLE__)

#define LongUnsignedNewline "%llu\n"
#define Size_tNewline "%lu\n"
#define LongHexNewline "(0x%llx)\n"
// NOTE(review): macOS does not declare memalign (and ships no <malloc.h>), so any use of
// alignedMalloc on this path fails to compile; posix_memalign is the portable alternative.
// Presumably moot since current CUDA releases do not target macOS — confirm before relying on it.
#define alignedMalloc(size, alignment) memalign(alignment, size)
#define SNPRINTF snprintf

#elif defined(_WIN32)

#include <malloc.h>

#define LongHexNewline "(0x%lx)\n"
#define LongUnsignedNewline "%lu\n"
// Windows is LLP64: unsigned long is 32 bits but size_t is 64 bits on Win64, so "%lu"
// misreads size_t arguments. "%zu" is the standard length modifier for size_t.
#define Size_tNewline "%zu\n"
#include "windows.h"
// Memory obtained from _aligned_malloc must be released with _aligned_free, not free().
#define alignedMalloc(size, alignment) _aligned_malloc(size, alignment)
// NOTE(review): _snprintf does not NUL-terminate on truncation (unlike C99 snprintf) — callers
// must terminate the buffer themselves.
#define SNPRINTF _snprintf

#else

#include <malloc.h>

#define LongHexNewline "(0x%lx)\n"
#define LongUnsignedNewline "%lu\n"
// LP64: unsigned long and size_t have the same width, so "%lu" is correct here.
#define Size_tNewline "%lu\n"
#define alignedMalloc(size, alignment) memalign(alignment, size)
#define SNPRINTF snprintf

#endif
50
51 #include <iostream>
52 #include <cuda.h>
53 #include <builtin_types.h>
54
55 #include "shared.h"
56
57 #include <fstream>
58 #include <thread>
59
60 struct WHERE{
61 const char* f;
62 int l;
63 cudaError_enum e;
64 const char* t;
65 void report() const {
66 if (e != CUDA_SUCCESS){
67 const char *buf;
68 cuGetErrorName(e, &buf);
69 std::cerr << t << " CUDA error = " << e << " " << buf <<std::endl<< " " << f << " line " << l << std::endl;
70 exit(-1);
71 }
72 }
73 };
74
// Checks a CUDA driver API result and, on failure, reports it (via WHERE::report) and exits.
//   err          - expression yielding a CUresult; evaluated exactly once
//   functionName - C-string label printed at the start of the diagnostic
// Uses positional aggregate initialisation (WHERE's field order: f, l, e, t) instead of C++20
// designated initialisers, which MSVC rejects before /std:c++20 and GCC flags under -pedantic.
// The expansion remains a plain brace block (not do { } while (0)) so existing call sites
// written without a trailing semicolon keep compiling; consequently, brace it explicitly when
// it is the sole statement of an `if` that has an `else`.
#define CUDA_CHECK(err, functionName) { \
    WHERE{__FILE__, __LINE__, (err), (functionName)}.report(); \
}
82
83 class PtxSource final : public Text {
84 public:
85 PtxSource();
86 explicit PtxSource(size_t len);
87 PtxSource(size_t len, char *text);
88 PtxSource(size_t len, char *text, bool isCopy);
89 explicit PtxSource(char *text);
90 ~PtxSource() override = default;
91 };
92
93 class CudaSource final :public Text {
94 public:
95 CudaSource(size_t len, char *text, bool isCopy, bool lineinfo);
96 bool lineInfo() const;
97 explicit CudaSource(size_t len);
98 explicit CudaSource(char* text);
99 CudaSource();
100 ~CudaSource() override = default;
101 private:
102 bool _lineInfo = false;
103 };
104
105 class CudaBackend final : public Backend {
106 public:
107 class CudaQueue final : public Backend::Queue {
108 public:
109 std::thread::id streamCreationThread;
110 CUstream cuStream;
111 explicit CudaQueue(Backend *backend);
112 void init();
113 void wait() override;
114
115 void release() override;
116
117 void computeStart() override;
118
119 void computeEnd() override;
120
121 void copyToDevice(Buffer *buffer) override;
122
123 void copyFromDevice(Buffer *buffer) override;
124
125 int estimateThreadsPerBlock(int dimensions);
126
127 int estimateThreadsPerBlock(int dimensions, int globalSizePerDimension, int localSize);
128
129 void dispatch(KernelContext *kernelContext, CompilationUnit::Kernel *kernel) override;
130
131 ~CudaQueue() override;
132 };
133
134 class CudaBuffer final : public Buffer {
135 public:
136 CUdeviceptr devicePtr;
137 CudaBuffer(Backend *backend, BufferState *bufferState);
138 ~CudaBuffer() override;
139 };
140
141 class CudaModule final : public CompilationUnit {
142 CUmodule module;
143 CudaSource cudaSource;
144 PtxSource ptxSource;
145 Log log;
146
147 public:
148 class CudaKernel final : public Kernel {
149
150 public:
151 bool setArg(KernelArg *arg) override;
152 bool setArg(KernelArg *arg, Buffer *buffer) override;
153 CudaKernel(Backend::CompilationUnit *program, char* name, CUfunction function);
154 ~CudaKernel() override;
155 static CudaKernel * of(long kernelHandle);
156 static CudaKernel * of(Backend::CompilationUnit::Kernel *kernel);
157
158 CUfunction function;
159 void *argslist[100]{};
160 };
161 CudaModule(Backend *backend, char *cudaSrc, char *log, bool ok, CUmodule module);
162 ~CudaModule() override;
163 static CudaModule * of(long moduleHandle);
164 //static CudaModule * of(CompilationUnit *compilationUnit);
165 Kernel *getKernel(int nameLen, char *name) override;
166 CudaKernel *getCudaKernel(char *name);
167 CudaKernel *getCudaKernel(int nameLen, char *name);
168 bool programOK();
169 };
170
171 private:
172 CUresult initStatus;
173 CUdevice device;
174 CUcontext context;
175 public:
176 void showDeviceInfo() override;
177 CudaModule * compile(const CudaSource *cudaSource);
178 CudaModule * compile(const CudaSource &cudaSource);
179 CudaModule * compile(const PtxSource *ptxSource);
180 CudaModule * compile(const PtxSource &ptxSource);
181 static PtxSource *nvcc(const CudaSource *cudaSource);
182 CompilationUnit * compile(int len, char *source) override;
183 void computeStart() override;
184 void computeEnd() override;
185 CudaBuffer * getOrCreateBuffer(BufferState *bufferState) override;
186 bool getBufferFromDeviceIfDirty(void *memorySegment, long memorySegmentLength) override;
187
188 explicit CudaBackend(int mode);
189
190 ~CudaBackend() override;
191 static CudaBackend * of(long backendHandle);
192 static CudaBackend * of(Backend *backend);
193 };