1 /* 2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. Oracle designates this 8 * particular file as subject to the "Classpath" exception as provided 9 * by Oracle in the LICENSE file that accompanied this code. 10 * 11 * This code is distributed in the hope that it will be useful, but WITHOUT 12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 14 * version 2 for more details (a copy is included in the LICENSE file that 15 * accompanied this code). 16 * 17 * You should have received a copy of the GNU General Public License version 18 * 2 along with this work; if not, write to the Free Software Foundation, 19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 20 * 21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 22 * or visit www.oracle.com if you need additional information or have any 23 * questions. 24 */ 25 #pragma once 26 #define HIP_TYPES 27 #ifdef __APPLE__ 28 29 #define LongUnsignedNewline "%llu\n" 30 #define Size_tNewline "%lu\n" 31 #define LongHexNewline "(0x%llx)\n" 32 #define alignedMalloc(size, alignment) memalign(alignment, size) 33 #define SNPRINTF snprintf 34 #else 35 36 #include <malloc.h> 37 38 #define LongHexNewline "(0x%lx)\n" 39 #define LongUnsignedNewline "%lu\n" 40 #define Size_tNewline "%lu\n" 41 #if defined (_WIN32) 42 #include "windows.h" 43 #define alignedMalloc(size, alignment) _aligned_malloc(size, alignment) 44 #define SNPRINTF _snprintf 45 #else 46 #define alignedMalloc(size, alignment) memalign(alignment, size) 47 #define SNPRINTF snprintf 48 #endif 49 #endif 50 51 #include <iostream> 52 #include <hip/hip_runtime.h> 53 #include <builtin_types.h> 54 55 #include "shared.h" 56 57 #include <fstream> 58 59 #include<vector> 60 #include <thread> 61 62 /* 63 struct WHERE{ 64 const char* f; 65 int l; 66 cudaError_enum e; 67 const char* t; 68 void report() const{ 69 if (e == CUDA_SUCCESS){ 70 // std::cout << t << " OK at " << f << " line " << l << std::endl; 71 }else { 72 const char *buf; 73 cuGetErrorName(e, &buf); 74 std::cerr << t << " CUDA error = " << e << " " << buf <<std::endl<< " " << f << " line " << l << std::endl; 75 exit(-1); 76 } 77 } 78 }; 79 80 */ 81 class PtxSource: public Text { 82 public: 83 PtxSource(); 84 PtxSource(size_t len); 85 PtxSource(size_t len, char *text); 86 PtxSource(char *text); 87 ~PtxSource() = default; 88 // static PtxSource *nvcc(const char *cudaSource, size_t len); 89 }; 90 91 class HipSource: public Text { 92 public: 93 HipSource(); 94 HipSource(size_t len); 95 HipSource(size_t len, char *text); 96 HipSource(char *text); 97 ~HipSource() = default; 98 99 }; 100 101 class HipBackend : public Backend { 102 public: 103 class HipQueue: public Backend::Queue { 104 public: 105 std::thread::id streamCreationThread; 106 CUstream cuStream; 107 HipQueue(Backend *backend); 108 void init(); 109 void wait() override; 110 111 void release() override; 112 113 void computeStart() override; 114 115 void computeEnd() override; 116 117 void copyToDevice(Buffer *buffer) override; 118 119 void copyFromDevice(Buffer *buffer) override; 120 121 virtual void dispatch(KernelContext *kernelContext, CompilationUnit::Kernel *kernel) override; 122 123 virtual ~HipQueue(); 124 125 }; 126 127 class HipBuffer : public Backend::Buffer { 128 public: 129 CUdeviceptr devicePtr; 130 HipBuffer(Backend *backend, BufferState *bufferState); 131 virtual ~CudaBuffer(); 132 }; 133 134 class HipProgram : public Backend::CompilationUnit { 135 class HipKernel : public Backend::CompilationUnit::Kernel { 136 137 138 private: 139 hipFunction_t kernel; 140 hipStream_t hipStream; 141 public: 142 HIPKernel(Backend::CompilationUnit *program, char* name, hipFunction_t kernel); 143 144 ~HIPKernel() override; 145 146 long ndrange( void *argArray); 147 }; 148 149 private: 150 HipModule_t module; 151 HipSource hipSource; 152 PtxSource ptxSource; 153 Log log; 154 155 public: 156 HIPProgram(Backend *backend, BuildInfo *buildInfo, hipModule_t module); 157 ~HIPProgram(); 158 159 long getHipKernel(char *name); 160 long getHipKernel(int nameLen, char *name); 161 162 bool programOK(); 163 }; 164 165 private: 166 hipDevice_t device; 167 hipCtx_t context; 168 public: 169 void info(); 170 171 HIPBackend(in mode); 172 HIPBackend(); 173 ~HIPBackend(); 174 175 int getMaxComputeUnits(); 176 177 }; 178