1 /* 2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. Oracle designates this 8 * particular file as subject to the "Classpath" exception as provided 9 * by Oracle in the LICENSE file that accompanied this code. 10 * 11 * This code is distributed in the hope that it will be useful, but WITHOUT 12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 14 * version 2 for more details (a copy is included in the LICENSE file that 15 * accompanied this code). 16 * 17 * You should have received a copy of the GNU General Public License version 18 * 2 along with this work; if not, write to the Free Software Foundation, 19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 20 * 21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 22 * or visit www.oracle.com if you need additional information or have any 23 * questions. 24 */ 25 #pragma once 26 #define CUDA_TYPES 27 #ifdef __APPLE__ 28 29 #define LongUnsignedNewline "%llu\n" 30 #define Size_tNewline "%lu\n" 31 #define LongHexNewline "(0x%llx)\n" 32 #define alignedMalloc(size, alignment) memalign(alignment, size) 33 #define SNPRINTF snprintf 34 #else 35 36 #include <malloc.h> 37 38 #define LongHexNewline "(0x%lx)\n" 39 #define LongUnsignedNewline "%lu\n" 40 #define Size_tNewline "%lu\n" 41 #if defined (_WIN32) 42 #include "windows.h" 43 #define alignedMalloc(size, alignment) _aligned_malloc(size, alignment) 44 #define SNPRINTF _snprintf 45 #else 46 #define alignedMalloc(size, alignment) memalign(alignment, size) 47 #define SNPRINTF snprintf 48 #endif 49 #endif 50 51 #include <iostream> 52 #include <cuda.h> 53 #include <builtin_types.h> 54 55 #define CUDA_TYPES 56 57 #include "shared.h" 58 59 #include <fstream> 60 61 #include<vector> 62 63 //extern void __checkCudaErrors(CUresult err, const char *file, const int line); 64 65 //#define checkCudaErrors(err) __checkCudaErrors (err, __FILE__, __LINE__) 66 67 class Ptx { 68 public: 69 size_t len; 70 char *text; 71 72 Ptx(size_t len); 73 74 ~Ptx(); 75 76 static Ptx *nvcc(const char *cudaSource, size_t len); 77 }; 78 79 class CudaBackend : public Backend { 80 public: 81 class CudaProgram : public Backend::Program { 82 class CudaKernel : public Backend::Program::Kernel { 83 class CudaBuffer : public Backend::Program::Kernel::Buffer { 84 public: 85 CUdeviceptr devicePtr; 86 87 CudaBuffer(Backend::Program::Kernel *kernel, Arg_s *arg); 88 89 void copyToDevice(); 90 91 void copyFromDevice(); 92 93 virtual ~CudaBuffer(); 94 }; 95 96 private: 97 CUfunction function; 98 cudaStream_t cudaStream; 99 public: 100 CudaKernel(Backend::Program *program, char* name, CUfunction function); 101 102 ~CudaKernel() override; 103 104 long ndrange( void *argArray); 105 }; 106 107 private: 108 CUmodule module; 109 Ptx *ptx; 110 111 public: 112 CudaProgram(Backend *backend, BuildInfo *buildInfo, Ptx *ptx, CUmodule module); 113 114 ~CudaProgram(); 115 116 long getKernel(int nameLen, char *name); 117 118 bool programOK(); 119 }; 120 121 private: 122 CUdevice device; 123 CUcontext context; 124 public: 125 126 CudaBackend(int mode); 127 128 CudaBackend(); 129 130 ~CudaBackend(); 131 132 int getMaxComputeUnits(); 133 134 void info(); 135 136 long compileProgram(int len, char *source); 137 138 }; 139 extern "C" long getCudaBackend(int mode); 140