1 /*
2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25 #pragma once
#define HIP_TYPES

// Platform-specific printf format strings, aligned allocation and snprintf.
//
// Size_tNewline uses the C99/C++11 `z` length modifier so it matches size_t on
// every data model. "%lu" was wrong on 64-bit Windows (LLP64), where
// `unsigned long` is 32 bits but size_t is 64 bits.
#ifdef __APPLE__

#define LongUnsignedNewline "%llu\n"
#define Size_tNewline "%zu\n"
#define LongHexNewline "(0x%llx)\n"
// NOTE(review): macOS does not appear to declare memalign(); this branch may
// never have been built. Consider posix_memalign/aligned_alloc — TODO confirm.
#define alignedMalloc(size, alignment) memalign(alignment, size)
#define SNPRINTF snprintf
#else

#include <malloc.h>

#define LongHexNewline "(0x%lx)\n"
#define LongUnsignedNewline "%lu\n"
#define Size_tNewline "%zu\n"
#if defined (_WIN32)
#include "windows.h"
// Memory obtained from _aligned_malloc must be released with _aligned_free,
// not free().
#define alignedMalloc(size, alignment) _aligned_malloc(size, alignment)
// NOTE(review): unlike C99 snprintf, _snprintf does not NUL-terminate on
// truncation and returns -1 instead of the required length.
#define SNPRINTF _snprintf
#else
#define alignedMalloc(size, alignment) memalign(alignment, size)
#define SNPRINTF snprintf
#endif
#endif
50
51 #include <iostream>
52 #include <hip/hip_runtime.h>
53 //#include <builtin_types.h>
54
55 #include "shared.h"
56
57 #include <fstream>
58
59 #include<vector>
60 #include <thread>
61
62 /*
63 struct WHERE{
64 const char* f;
65 int l;
66 cudaError_enum e;
67 const char* t;
68 void report() const{
69 if (e == CUDA_SUCCESS){
70 // std::cout << t << " OK at " << f << " line " << l << std::endl;
71 }else {
72 const char *buf;
73 cuGetErrorName(e, &buf);
74 std::cerr << t << " CUDA error = " << e << " " << buf <<std::endl<< " " << f << " line " << l << std::endl;
75 exit(-1);
76 }
77 }
78 };
79
80 */
81 class PtxSource: public Text {
82 public:
83 PtxSource();
84 PtxSource(size_t len);
85 PtxSource(size_t len, char *text);
86 PtxSource(char *text);
87 ~PtxSource() = default;
88 // static PtxSource *nvcc(const char *cudaSource, size_t len);
89 };
90
91 class HipSource: public Text {
92 public:
93 HipSource();
94 HipSource(size_t len);
95 HipSource(size_t len, char *text);
96 HipSource(char *text);
97 ~HipSource() = default;
98
99 };
100
101 class HipBackend : public Backend {
102 public:
103 class HipQueue: public Backend::Queue {
104 public:
105 std::thread::id streamCreationThread;
106 //CUstream cuStream;
107 hipStream_t cuStream;
108 explicit HipQueue(Backend *backend);
109 void init();
110 void wait() override;
111
112 void release() override;
113
114 void computeStart() override;
115
116 void computeEnd() override;
117
118 void copyToDevice(Buffer *buffer) override;
119
120 void copyFromDevice(Buffer *buffer) override;
121
122 void dispatch(KernelContext *kernelContext, CompilationUnit::Kernel *kernel) override;
123
124 ~HipQueue() override;
125
126 };
127
128 class HipBuffer : public Backend::Buffer {
129 public:
130 //CUdeviceptr devicePtr;
131 hipDevice_t devicePtr;
132 HipBuffer(Backend *backend, BufferState *bufferState);
133 ~HipBuffer() override;
134 };
135
136 class HipProgram : public Backend::CompilationUnit {
137 class HipKernel : public Backend::CompilationUnit::Kernel {
138
139
140 private:
141 hipFunction_t kernel;
142 hipStream_t hipStream;
143 public:
144 HipKernel(Backend::CompilationUnit *program, char* name, hipFunction_t kernel);
145
146 ~HipKernel() override;
147
148 //long ndrange( void *argArray);
149 };
150
151 private:
152 hipModule_t module;
153 HipSource hipSource;
154 PtxSource ptxSource;
155 Log log;
156
157 public:
158 HipProgram(Backend *backend, Backend::CompilationUnit::BuildInfo *buildInfo, hipModule_t module);
159 ~HipProgram();
160
161 long getHipKernel(char *name);
162 long getHipKernel(int nameLen, char *name);
163
164 bool programOK();
165 };
166
167 private:
168 hipDevice_t device;
169 hipCtx_t context;
170 public:
171 void info();
172
173 HipBackend(int mode);
174 HipBackend();
175 ~HipBackend();
176
177 int getMaxComputeUnits();
178
179 };
180