1 /*
2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25 #pragma once
// Marker other headers can test to see that the CUDA backend types are in scope.
#define CUDA_TYPES

// printf-style format for a size_t followed by a newline (same on every platform here).
// NOTE(review): "%lu" matches size_t only where size_t is unsigned long; on 64-bit
// Windows size_t is wider than long - confirm how callers use this before changing it.
#define Size_tNewline "%lu\n"

// Formats for 64-bit values: Apple builds use the long long spellings, everything
// else uses the long spellings.
#if defined(__APPLE__)
#define LongUnsignedNewline "%llu\n"
#define LongHexNewline "(0x%llx)\n"
#else

#include <malloc.h>

#define LongUnsignedNewline "%lu\n"
#define LongHexNewline "(0x%lx)\n"
#endif

// alignedMalloc(size, alignment) and SNPRINTF map onto the platform's native calls.
// Only a non-Apple Windows build takes the first branch, mirroring the original nesting.
#if defined(_WIN32) && !defined(__APPLE__)
#include "windows.h"
// NOTE(review): the CRT documents that _aligned_malloc memory must be released with
// _aligned_free, and that _snprintf does not always null-terminate - verify callers.
#define alignedMalloc(size, alignment) _aligned_malloc(size, alignment)
#define SNPRINTF _snprintf
#else
// memalign takes (alignment, size), hence the swapped argument order.
// NOTE(review): no header declaring memalign is included on the Apple path - confirm
// alignedMalloc is unused there or declared by another header.
#define alignedMalloc(size, alignment) memalign(alignment, size)
#define SNPRINTF snprintf
#endif
50
51 #include <iostream>
52 #include <cuda.h>
53 #include <builtin_types.h>
54
55 #include "shared.h"
56
57 #include <fstream>
58 #include <thread>
59
60 struct WHERE{
61 const char* f;
62 int l;
63 cudaError_enum e;
64 const char* t;
65 void report() const {
66 if (e != CUDA_SUCCESS){
67 const char *buf;
68 cuGetErrorName(e, &buf);
69 std::cerr << t << " CUDA error = " << e << " " << buf <<std::endl<< " " << f << " line " << l << std::endl;
70 exit(-1);
71 }
72 }
73 };
74
// CUDA_CHECK(err, functionName)
//   err          - status returned by a CUDA driver API call
//   functionName - label included in the diagnostic
//
// Builds a temporary WHERE carrying __FILE__/__LINE__ of the call site and calls
// report() on it; report() returns when the status is CUDA_SUCCESS and otherwise
// prints a diagnostic and exits.
//
// The initialiser is positional (WHERE's field order is f, l, e, t) rather than
// designated, because designated initialisers are only standard from C++20 onwards.
// Both arguments are parenthesised so any expression can be passed safely.
// The expansion is deliberately left as a plain { } block so existing call sites
// compile whether or not they put a semicolon after the macro.
#define CUDA_CHECK(err, functionName) {  \
    WHERE{__FILE__,                      \
          __LINE__,                      \
          (err),                         \
          (functionName)                 \
    }.report();                          \
}
82
83 class PtxSource final : public Text {
84 public:
85 PtxSource();
86 explicit PtxSource(size_t len);
87 PtxSource(size_t len, char *text);
88 PtxSource(size_t len, char *text, bool isCopy);
89 explicit PtxSource(char *text);
90 ~PtxSource() override = default;
91 };
92
93 class CudaSource final :public Text {
94 public:
95 CudaSource(size_t len, char *text, bool isCopy, bool lineinfo);
96 bool lineInfo() const;
97 explicit CudaSource(size_t len);
98 explicit CudaSource(char* text);
99 CudaSource();
100 ~CudaSource() override = default;
101 private:
102 bool _lineInfo = false;
103 };
104
105 class CudaBackend final : public Backend {
106 public:
107 class CudaQueue final : public Backend::Queue {
108 public:
109 std::thread::id streamCreationThread;
110 CUstream cuStream;
111 explicit CudaQueue(Backend *backend);
112 void init();
113 void wait() override;
114
115 void release() override;
116
117 void computeStart() override;
118
119 void computeEnd() override;
120
121 void copyToDevice(Buffer *buffer) override;
122
123 void copyFromDevice(Buffer *buffer) override;
124
125 int estimateThreadsPerBlock(int dimensions);
126
127 void dispatch(KernelContext *kernelContext, CompilationUnit::Kernel *kernel) override;
128
129 ~CudaQueue() override;
130 };
131
132 class CudaBuffer final : public Buffer {
133 public:
134 CUdeviceptr devicePtr;
135 CudaBuffer(Backend *backend, BufferState *bufferState);
136 ~CudaBuffer() override;
137 };
138
139 class CudaModule final : public CompilationUnit {
140 CUmodule module;
141 CudaSource cudaSource;
142 PtxSource ptxSource;
143 Log log;
144
145 public:
146 class CudaKernel final : public Kernel {
147
148 public:
149 bool setArg(KernelArg *arg) override;
150 bool setArg(KernelArg *arg, Buffer *buffer) override;
151 CudaKernel(Backend::CompilationUnit *program, char* name, CUfunction function);
152 ~CudaKernel() override;
153 static CudaKernel * of(long kernelHandle);
154 static CudaKernel * of(Backend::CompilationUnit::Kernel *kernel);
155
156 CUfunction function;
157 void *argslist[100]{};
158 };
159 CudaModule(Backend *backend, char *cudaSrc, char *log, bool ok, CUmodule module);
160 ~CudaModule() override;
161 static CudaModule * of(long moduleHandle);
162 //static CudaModule * of(CompilationUnit *compilationUnit);
163 Kernel *getKernel(int nameLen, char *name) override;
164 CudaKernel *getCudaKernel(char *name);
165 CudaKernel *getCudaKernel(int nameLen, char *name);
166 bool programOK();
167 };
168
169 private:
170 CUresult initStatus;
171 CUdevice device;
172 CUcontext context;
173 public:
174 void info() override;
175 CudaModule * compile(const CudaSource *cudaSource);
176 CudaModule * compile(const CudaSource &cudaSource);
177 CudaModule * compile(const PtxSource *ptxSource);
178 CudaModule * compile(const PtxSource &ptxSource);
179 static PtxSource *nvcc(const CudaSource *cudaSource);
180 CompilationUnit * compile(int len, char *source) override;
181 void computeStart() override;
182 void computeEnd() override;
183 CudaBuffer * getOrCreateBuffer(BufferState *bufferState) override;
184 bool getBufferFromDeviceIfDirty(void *memorySegment, long memorySegmentLength) override;
185
186 explicit CudaBackend(int mode);
187
188 ~CudaBackend() override;
189 static CudaBackend * of(long backendHandle);
190 static CudaBackend * of(Backend *backend);
191 };