/*
 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation. Oracle designates this
 * particular file as subject to the "Classpath" exception as provided
 * by Oracle in the LICENSE file that accompanied this code.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
24 */ 25 #pragma once 26 #define HIP_TYPES 27 #ifdef __APPLE__ 28 29 #define LongUnsignedNewline "%llu\n" 30 #define Size_tNewline "%lu\n" 31 #define LongHexNewline "(0x%llx)\n" 32 #define alignedMalloc(size, alignment) memalign(alignment, size) 33 #define SNPRINTF snprintf 34 #else 35 36 #include <malloc.h> 37 38 #define LongHexNewline "(0x%lx)\n" 39 #define LongUnsignedNewline "%lu\n" 40 #define Size_tNewline "%lu\n" 41 #if defined (_WIN32) 42 #include "windows.h" 43 #define alignedMalloc(size, alignment) _aligned_malloc(size, alignment) 44 #define SNPRINTF _snprintf 45 #else 46 #define alignedMalloc(size, alignment) memalign(alignment, size) 47 #define SNPRINTF snprintf 48 #endif 49 #endif 50 51 #include <iostream> 52 #include <hip/hip_runtime.h> 53 //#include <builtin_types.h> 54 55 #include "shared.h" 56 57 #include <fstream> 58 59 #include<vector> 60 #include <thread> 61 62 /* 63 struct WHERE{ 64 const char* f; 65 int l; 66 cudaError_enum e; 67 const char* t; 68 void report() const{ 69 if (e == CUDA_SUCCESS){ 70 // std::cout << t << " OK at " << f << " line " << l << std::endl; 71 }else { 72 const char *buf; 73 cuGetErrorName(e, &buf); 74 std::cerr << t << " CUDA error = " << e << " " << buf <<std::endl<< " " << f << " line " << l << std::endl; 75 exit(-1); 76 } 77 } 78 }; 79 80 */ 81 class PtxSource: public Text { 82 public: 83 PtxSource(); 84 PtxSource(size_t len); 85 PtxSource(size_t len, char *text); 86 PtxSource(char *text); 87 ~PtxSource() = default; 88 // static PtxSource *nvcc(const char *cudaSource, size_t len); 89 }; 90 91 class HipSource: public Text { 92 public: 93 HipSource(); 94 HipSource(size_t len); 95 HipSource(size_t len, char *text); 96 HipSource(char *text); 97 ~HipSource() = default; 98 99 }; 100 101 class HipBackend : public Backend { 102 public: 103 class HipQueue: public Backend::Queue { 104 public: 105 std::thread::id streamCreationThread; 106 //CUstream cuStream; 107 hipStream_t cuStream; 108 explicit HipQueue(Backend *backend); 109 
void init(); 110 void wait() override; 111 112 void release() override; 113 114 void computeStart() override; 115 116 void computeEnd() override; 117 118 void copyToDevice(Buffer *buffer) override; 119 120 void copyFromDevice(Buffer *buffer) override; 121 122 void dispatch(KernelContext *kernelContext, CompilationUnit::Kernel *kernel) override; 123 124 ~HipQueue() override; 125 126 }; 127 128 class HipBuffer : public Backend::Buffer { 129 public: 130 //CUdeviceptr devicePtr; 131 hipDevice_t devicePtr; 132 HipBuffer(Backend *backend, BufferState *bufferState); 133 ~HipBuffer() override; 134 }; 135 136 class HipProgram : public Backend::CompilationUnit { 137 class HipKernel : public Backend::CompilationUnit::Kernel { 138 139 140 private: 141 hipFunction_t kernel; 142 hipStream_t hipStream; 143 public: 144 HipKernel(Backend::CompilationUnit *program, char* name, hipFunction_t kernel); 145 146 ~HipKernel() override; 147 148 //long ndrange( void *argArray); 149 }; 150 151 private: 152 hipModule_t module; 153 HipSource hipSource; 154 PtxSource ptxSource; 155 Log log; 156 157 public: 158 HipProgram(Backend *backend, Backend::CompilationUnit::BuildInfo *buildInfo, hipModule_t module); 159 ~HipProgram(); 160 161 long getHipKernel(char *name); 162 long getHipKernel(int nameLen, char *name); 163 164 bool programOK(); 165 }; 166 167 private: 168 hipDevice_t device; 169 hipCtx_t context; 170 public: 171 void info(); 172 173 HipBackend(int mode); 174 HipBackend(); 175 ~HipBackend(); 176 177 int getMaxComputeUnits(); 178 179 }; 180