1 /* 2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. Oracle designates this 8 * particular file as subject to the "Classpath" exception as provided 9 * by Oracle in the LICENSE file that accompanied this code. 10 * 11 * This code is distributed in the hope that it will be useful, but WITHOUT 12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 14 * version 2 for more details (a copy is included in the LICENSE file that 15 * accompanied this code). 16 * 17 * You should have received a copy of the GNU General Public License version 18 * 2 along with this work; if not, write to the Free Software Foundation, 19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 20 * 21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 22 * or visit www.oracle.com if you need additional information or have any 23 * questions. 24 */ 25 #pragma once 26 #define PTX_TYPES 27 #ifdef __APPLE__ 28 29 #define LongUnsignedNewline "%llu\n" 30 #define Size_tNewline "%lu\n" 31 #define LongHexNewline "(0x%llx)\n" 32 #define alignedMalloc(size, alignment) memalign(alignment, size) 33 #define SNPRINTF snprintf 34 #else 35 36 #include <malloc.h> 37 38 #define LongHexNewline "(0x%lx)\n" 39 #define LongUnsignedNewline "%lu\n" 40 #define Size_tNewline "%lu\n" 41 #if defined (_WIN32) 42 #include "windows.h" 43 #define alignedMalloc(size, alignment) _aligned_malloc(size, alignment) 44 #define SNPRINTF _snprintf 45 #else 46 #define alignedMalloc(size, alignment) memalign(alignment, size) 47 #define SNPRINTF snprintf 48 #endif 49 #endif 50 51 #include <iostream> 52 #include <cuda.h> 53 #include <builtin_types.h> 54 55 #define CUDA_TYPES 56 57 #include "shared.h" 58 59 #include <fstream> 60 61 #include<vector> 62 63 class Ptx { 64 public: 65 size_t len; 66 char *text; 67 68 Ptx(size_t len); 69 70 ~Ptx(); 71 72 static Ptx *nvcc(const char *ptxSource, size_t len); 73 }; 74 75 class PtxBackend : public Backend { 76 public: 77 78 class PtxProgram : public Backend::Program { 79 class PtxKernel : public Backend::Program::Kernel { 80 class PtxBuffer : public Backend::Program::Kernel::Buffer { 81 public: 82 CUdeviceptr devicePtr; 83 84 PtxBuffer(Backend::Program::Kernel *kernel, Arg_s *arg); 85 86 void copyToDevice(); 87 88 void copyFromDevice(); 89 90 virtual ~PtxBuffer(); 91 }; 92 93 private: 94 CUfunction function; 95 cudaStream_t cudaStream; 96 public: 97 PtxKernel(Backend::Program *program, char* name, CUfunction function); 98 99 ~PtxKernel() override; 100 101 long ndrange( void *argArray); 102 }; 103 104 private: 105 CUmodule module; 106 Ptx *ptx; 107 108 public: 109 PtxProgram(Backend *backend, BuildInfo *buildInfo, Ptx *ptx, CUmodule module); 110 111 ~PtxProgram(); 112 113 long getKernel(int nameLen, char *name); 114 115 bool programOK(); 116 }; 117 118 private: 119 CUdevice device; 120 CUcontext context; 121 public: 122 123 PtxBackend(int mode); 124 125 ~PtxBackend(); 126 127 int getMaxComputeUnits(); 128 129 void info(); 130 131 long compileProgram(int len, char *source); 132 bool getBufferFromDeviceIfDirty(void *memorySegment, long memorySegmentLength); 133 134 }; 135 extern "C" long getPtxBackend(int mode); 136