/*
 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation.  Oracle designates this
 * particular file as subject to the "Classpath" exception as provided
 * by Oracle in the LICENSE file that accompanied this code.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
 */

#include <sys/wait.h>
#include <chrono>
#include <thread>
#include "cuda_backend.h"

CudaBackend::CudaQueue::CudaQueue(Backend *backend)
        : Backend::Queue(backend), cuStream(), streamCreationThread() {
}
void CudaBackend::CudaQueue::init() {
    streamCreationThread = std::this_thread::get_id();
    if (backend->config->traceCalls) {
        std::cout << "init()"
                  << " thread=" << streamCreationThread
                  << std::endl;
    }

    WHERE{.f=__FILE__, .l=__LINE__,
          .e=cuStreamCreate(&cuStream, CU_STREAM_DEFAULT),
          .t="cuStreamCreate"
    }.report();

    if (backend->config->traceCalls) {
        std::cout << "exiting init()"
                  << " cuStream=" << cuStream
                  << " thread=" << streamCreationThread
                  << std::endl;
    }
}
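
// Note: in the CUDA Driver API a stream belongs to a CUcontext, not to a host
// thread; any thread that has that context current may legally use cuStream.
// The streamCreationThread comparisons in the methods below therefore only log
// a diagnostic when another thread touches the queue; they do not reject the call.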

// Block the host until all work previously queued on cuStream has completed.
void CudaBackend::CudaQueue::wait() {
    CUDA_CHECK(cuStreamSynchronize(cuStream), "cuStreamSynchronize");
}

void CudaBackend::CudaQueue::computeStart() {
    wait();    // should be a no-op
    release(); // also a no-op
}

void CudaBackend::CudaQueue::computeEnd() {

}

void CudaBackend::CudaQueue::release() {

}

CudaBackend::CudaQueue::~CudaQueue() {
    CUDA_CHECK(cuStreamDestroy(cuStream), "cuStreamDestroy");
}
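
// Note on the destructor above: cuStreamDestroy() returns immediately even if
// work is still queued on the stream; the driver releases the stream's
// resources once the device has finished that work.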

void CudaBackend::CudaQueue::copyToDevice(Buffer *buffer) {
    const auto *cudaBuffer = dynamic_cast<CudaBuffer *>(buffer);
    const std::thread::id thread_id = std::this_thread::get_id();
    if (thread_id != streamCreationThread) {
        std::cout << "copyToDevice()  thread=" << thread_id << " != " << streamCreationThread << std::endl;
    }
    if (backend->config->traceCalls) {
        std::cout << "copyToDevice() 0x"
                  << std::hex << cudaBuffer->bufferState->length << std::dec << "/"
                  << cudaBuffer->bufferState->length << " "
                  << "devptr=" << std::hex << static_cast<long>(cudaBuffer->devicePtr) << std::dec
                  << " thread=" << thread_id
                  << std::endl;
    }

    CUDA_CHECK(cuMemcpyHtoDAsync(cudaBuffer->devicePtr,
                                 cudaBuffer->bufferState->ptr,
                                 cudaBuffer->bufferState->length,
                                 dynamic_cast<CudaQueue *>(backend->queue)->cuStream), "cuMemcpyHtoDAsync");
}

void CudaBackend::CudaQueue::copyFromDevice(Buffer *buffer) {
    const auto *cudaBuffer = dynamic_cast<CudaBuffer *>(buffer);
    const std::thread::id thread_id = std::this_thread::get_id();
    if (thread_id != streamCreationThread) {
        std::cout << "copyFromDevice()  thread=" << thread_id << " != " << streamCreationThread << std::endl;
    }
    if (backend->config->traceCalls) {
        std::cout << "copyFromDevice() 0x"
                  << std::hex << cudaBuffer->bufferState->length << std::dec << "/"
                  << cudaBuffer->bufferState->length << " "
                  << "devptr=" << std::hex << static_cast<long>(cudaBuffer->devicePtr) << std::dec
                  << " thread=" << thread_id
                  << std::endl;
    }

    CUDA_CHECK(cuMemcpyDtoHAsync(cudaBuffer->bufferState->ptr,
                                 cudaBuffer->devicePtr,
                                 cudaBuffer->bufferState->length,
                                 dynamic_cast<CudaQueue *>(backend->queue)->cuStream),
               "cuMemcpyDtoHAsync");
}
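
// Note on the two copies above: cuMemcpyHtoDAsync()/cuMemcpyDtoHAsync() only
// enqueue the transfer on cuStream; completion is only guaranteed once the
// stream has been synchronized. In particular, data copied back with
// copyFromDevice() should not be read on the host before a wait().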

// TODO: Improve heuristics to decide a better block size, if possible.
// The following is just a rough number to fit into a modern NVIDIA GPU.
int CudaBackend::CudaQueue::estimateThreadsPerBlock(int dimensions) {
    switch (dimensions) {
        case 1: return 256;
        case 2: return 16;
        case 3: return 16;
        default: return 1;
    }
}
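
// One possible refinement of the heuristic above (a sketch, not wired into
// dispatch()): ask the driver for an occupancy-based block size with
// cuOccupancyMaxPotentialBlockSize and fall back to the fixed table when the
// query fails. The helper name below is hypothetical.
[[maybe_unused]] static int occupancyThreadsPerBlock(CUfunction function, int fallback) {
    int minGridSize = 0; // smallest grid size that can still reach full occupancy
    int blockSize = 0;   // block size suggested by the driver for this kernel
    const CUresult result = cuOccupancyMaxPotentialBlockSize(
            &minGridSize, &blockSize, function,
            /*blockSizeToDynamicSMemSize=*/nullptr,
            /*dynamicSMemSize=*/0,
            /*blockSizeLimit=*/0);
    return (result == CUDA_SUCCESS && blockSize > 0) ? blockSize : fallback;
}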

void CudaBackend::CudaQueue::dispatch(KernelContext *kernelContext, CompilationUnit::Kernel *kernel) {
    const auto cudaKernel = dynamic_cast<CudaModule::CudaKernel *>(kernel);

    int threadsPerBlockX;
    int threadsPerBlockY = 1;
    int threadsPerBlockZ = 1;

    // The local and global mesh dimensions match by design from the Java APIs
    const int dimensions = kernelContext->globalMesh.dimensions;
    if (kernelContext->localMesh.maxX > 0) {
        threadsPerBlockX = kernelContext->localMesh.maxX;
    } else {
        threadsPerBlockX = estimateThreadsPerBlock(dimensions);
    }
    if (kernelContext->localMesh.maxY > 0) {
        threadsPerBlockY = kernelContext->localMesh.maxY;
    } else if (dimensions > 1) {
        threadsPerBlockY = estimateThreadsPerBlock(dimensions);
    }
    if (kernelContext->localMesh.maxZ > 0) {
        threadsPerBlockZ = kernelContext->localMesh.maxZ;
    } else if (dimensions > 2) {
        threadsPerBlockZ = estimateThreadsPerBlock(dimensions);
    }

    int blocksPerGridX = (kernelContext->globalMesh.maxX + threadsPerBlockX - 1) / threadsPerBlockX;
    int blocksPerGridY = 1;
    int blocksPerGridZ = 1;

    if (dimensions > 1) {
        blocksPerGridY = (kernelContext->globalMesh.maxY + threadsPerBlockY - 1) / threadsPerBlockY;
    }
    if (dimensions > 2) {
        blocksPerGridZ = (kernelContext->globalMesh.maxZ + threadsPerBlockZ - 1) / threadsPerBlockZ;
    }
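
    // Worked example with hypothetical sizes: globalMesh.maxX == 1000 and
    // threadsPerBlockX == 256 give blocksPerGridX == (1000 + 255) / 256 == 4,
    // i.e. 1024 launched threads for 1000 elements; because of this rounding up,
    // any thread whose index is >= maxX has to be guarded inside the kernel.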

    // Debug information, enabled with HAT=INFO
    if (backend->config->info) {
        std::cout << "Dispatching the CUDA kernel" << std::endl;
        std::cout << "   \\_ BlocksPerGrid   = [" << blocksPerGridX << "," << blocksPerGridY << "," << blocksPerGridZ << "]" << std::endl;
        std::cout << "   \\_ ThreadsPerBlock = [" << threadsPerBlockX << "," << threadsPerBlockY << "," << threadsPerBlockZ << "]" << std::endl;
    }

    const std::thread::id thread_id = std::this_thread::get_id();
    if (thread_id != streamCreationThread) {
        std::cout << "dispatch()  thread=" << thread_id << " != " << streamCreationThread << std::endl;
    }

    const auto status = cuLaunchKernel(cudaKernel->function,
                                       blocksPerGridX, blocksPerGridY, blocksPerGridZ,       // grid dimensions
                                       threadsPerBlockX, threadsPerBlockY, threadsPerBlockZ, // block dimensions
                                       0,                                                    // dynamic shared memory (bytes)
                                       cuStream,                                             // stream to launch on
                                       cudaKernel->argslist,                                 // kernel parameters
                                       nullptr);                                             // extra launch options

    CUDA_CHECK(status, "cuLaunchKernel");
}
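
// Note: cuLaunchKernel() only enqueues the launch on cuStream; a CUDA_SUCCESS
// status here means the launch was accepted, not that the kernel has run.
// Errors raised while the kernel executes surface on a later synchronization,
// e.g. the cuStreamSynchronize() in wait().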