1 /* 2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. Oracle designates this 8 * particular file as subject to the "Classpath" exception as provided 9 * by Oracle in the LICENSE file that accompanied this code. 10 * 11 * This code is distributed in the hope that it will be useful, but WITHOUT 12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 14 * version 2 for more details (a copy is included in the LICENSE file that 15 * accompanied this code). 16 * 17 * You should have received a copy of the GNU General Public License version 18 * 2 along with this work; if not, write to the Free Software Foundation, 19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 20 * 21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 22 * or visit www.oracle.com if you need additional information or have any 23 * questions. 24 */ 25 #pragma once 26 #define PTX_TYPES 27 #ifdef __APPLE__ 28 29 #define LongUnsignedNewline "%llu\n" 30 #define Size_tNewline "%lu\n" 31 #define LongHexNewline "(0x%llx)\n" 32 #define alignedMalloc(size, alignment) memalign(alignment, size) 33 #define SNPRINTF snprintf 34 #else 35 36 #include <malloc.h> 37 38 #define LongHexNewline "(0x%lx)\n" 39 #define LongUnsignedNewline "%lu\n" 40 #define Size_tNewline "%lu\n" 41 #if defined (_WIN32) 42 #include "windows.h" 43 #define alignedMalloc(size, alignment) _aligned_malloc(size, alignment) 44 #define SNPRINTF _snprintf 45 #else 46 #define alignedMalloc(size, alignment) memalign(alignment, size) 47 #define SNPRINTF snprintf 48 #endif 49 #endif 50 51 #include <iostream> 52 #include <cuda.h> 53 #include <builtin_types.h> 54 55 #define CUDA_TYPES 56 57 #include "shared.h" 58 59 #include <fstream> 60 61 #include<vector> 62 63 class Ptx { 64 public: 65 size_t len; 66 char *text; 67 68 Ptx(size_t len); 69 70 ~Ptx(); 71 72 static Ptx *nvcc(const char *ptxSource, size_t len); 73 }; 74 75 class PtxBackend : public Backend { 76 public: 77 class PtxConfig : public Backend::Config { 78 public: 79 boolean gpu; 80 }; 81 82 class PtxProgram : public Backend::Program { 83 class PtxKernel : public Backend::Program::Kernel { 84 class PtxBuffer : public Backend::Program::Kernel::Buffer { 85 public: 86 CUdeviceptr devicePtr; 87 88 PtxBuffer(Backend::Program::Kernel *kernel, Arg_s *arg); 89 90 void copyToDevice(); 91 92 void copyFromDevice(); 93 94 virtual ~PtxBuffer(); 95 }; 96 97 private: 98 CUfunction function; 99 cudaStream_t cudaStream; 100 public: 101 PtxKernel(Backend::Program *program, char* name, CUfunction function); 102 103 ~PtxKernel() override; 104 105 long ndrange( void *argArray); 106 }; 107 108 private: 109 CUmodule module; 110 Ptx *ptx; 111 112 public: 113 PtxProgram(Backend *backend, BuildInfo *buildInfo, Ptx *ptx, CUmodule module); 114 115 ~PtxProgram(); 116 117 long getKernel(int nameLen, char *name); 118 119 bool programOK(); 120 }; 121 122 private: 123 CUdevice device; 124 CUcontext context; 125 public: 126 127 PtxBackend(PtxConfig *config, int configSchemaLen, char *configSchema); 128 129 PtxBackend(); 130 131 ~PtxBackend(); 132 133 int getMaxComputeUnits(); 134 135 void info(); 136 137 long compileProgram(int len, char *source); 138 139 }; 140