1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 #include <fstream>
 26 #define shared_cpp
 27 
 28 #include "shared.h"
 29 
 30 #define INFO 0
 31 
 32 
 33 void hexdump(void *ptr, int buflen) {
 34     auto *buf = static_cast<unsigned char *>(ptr);
 35     int i, j;
 36     for (i = 0; i < buflen; i += 16) {
 37         printf("%06x: ", i);
 38         for (j = 0; j < 16; j++)
 39             if (i + j < buflen)
 40                 printf("%02x ", buf[i + j]);
 41             else
 42                 printf("   ");
 43         printf(" ");
 44         for (j = 0; j < 16; j++)
 45             if (i + j < buflen)
 46                 printf("%c", isprint(buf[i + j]) ? buf[i + j] : '.');
 47         printf("\n");
 48     }
 49 }
 50 
 51 void Sled::show(std::ostream &out, void *argArray) {
 52     ArgSled argSled(static_cast<ArgArray_s *>(argArray));
 53     for (int i = 0; i < argSled.argc(); i++) {
 54         KernelArg *arg = argSled.arg(i);
 55         switch (arg->variant) {
 56             case '&': {
 57                 out << "Buf: of " << arg->value.buffer.sizeInBytes << " bytes " << std::endl;
 58                 break;
 59             }
 60             case 'B': {
 61                 out << "S8:" << arg->value.s8 << std::endl;
 62                 break;
 63             }
 64             case 'Z': {
 65                 out << "Z:" << arg->value.z1 << std::endl;
 66                 break;
 67             }
 68             case 'C': {
 69                 out << "U16:" << arg->value.u16 << std::endl;
 70                 break;
 71             }
 72             case 'S': {
 73                 out << "S16:" << arg->value.s16 << std::endl;
 74                 break;
 75             }
 76             case 'I': {
 77                 out << "S32:" << arg->value.s32 << std::endl;
 78                 break;
 79             }
 80             case 'F': {
 81                 out << "F32:" << arg->value.f32 << std::endl;
 82                 break;
 83             }
 84             case 'J': {
 85                 out << "S64:" << arg->value.s64 << std::endl;
 86                 break;
 87             }
 88             case 'D': {
 89                 out << "F64:" << arg->value.f64 << std::endl;
 90                 break;
 91             }
 92             default: {
 93                 std::cerr << "unexpected variant (shared.cpp) '" << static_cast<char>(arg->variant) << "'" << std::endl;
 94                 exit(1);
 95             }
 96         }
 97     }
 98     out << "schema len = " << argSled.schemaLen() << std::endl;
 99 
100     out << "schema = " << argSled.schema() << std::endl;
101 }
102 
103 
104 extern "C" void showDeviceInfo(long backendHandle) {
105         std::cout << "DEBUGGGGGGG through backendHandle to backend.showDeviceInfo()" << std::endl;
106     if (INFO) {
107         std::cout << "trampolining through backendHandle to backend.showDeviceInfo()" << std::endl;
108     }
109     auto *backend = reinterpret_cast<Backend *>(backendHandle);
110     backend->showDeviceInfo();
111 }
112 
113 extern "C" void computeStart(long backendHandle) {
114     if (INFO) {
115         std::cout << "trampolining through backendHandle to backend.computeStart()" << std::endl;
116     }
117     auto *backend = reinterpret_cast<Backend *>(backendHandle);
118     backend->computeStart();
119 }
120 
121 extern "C" void computeEnd(long backendHandle) {
122     if (INFO) {
123         std::cout << "trampolining through backendHandle to backend.computeEnd()" << std::endl;
124     }
125     auto *backend = reinterpret_cast<Backend *>(backendHandle);
126     backend->computeEnd();
127 }
128 
129 extern "C" void releaseBackend(long backendHandle) {
130     auto *backend = reinterpret_cast<Backend *>(backendHandle);
131     delete backend;
132 }
133 
134 extern "C" long compile(long backendHandle, int len, char *source) {
135     if (INFO) {
136         std::cout << "trampolining through backendHandle to backend.compile() "
137                 << std::hex << backendHandle << std::dec << std::endl;
138     }
139     auto *backend = reinterpret_cast<Backend *>(backendHandle);
140     long compilationUnitHandle = reinterpret_cast<long>(backend->compile(len, source));
141     if (INFO) {
142         std::cout << "compilationUnitHandle = " << std::hex << compilationUnitHandle << std::dec << std::endl;
143     }
144     return compilationUnitHandle;
145 }
146 
147 extern "C" long getKernel(long compilationUnitHandle, int nameLen, char *name) {
148     if (INFO) {
149         std::cout << "trampolining through programHandle to compilationUnit.getKernel()"
150                 << std::hex << compilationUnitHandle << std::dec << std::endl;
151     }
152     auto compilationUnit = reinterpret_cast<Backend::CompilationUnit *>(compilationUnitHandle);
153     return reinterpret_cast<long>(compilationUnit->getKernel(nameLen, name));
154 }
155 
156 extern "C" long ndrange(long kernelHandle, void *argArray) {
157     if (INFO) {
158         std::cout << "trampolining through kernelHandle to kernel.ndrange(...) " << std::endl;
159     }
160     auto kernel = reinterpret_cast<Backend::CompilationUnit::Kernel *>(kernelHandle);
161     kernel->ndrange(argArray);
162     return (long) 0;
163 }
164 
165 extern "C" void releaseKernel(long kernelHandle) {
166     if (INFO) {
167         std::cout << "trampolining through to releaseKernel " << std::endl;
168     }
169     auto kernel = reinterpret_cast<Backend::CompilationUnit::Kernel *>(kernelHandle);
170     delete kernel;
171 }
172 
173 extern "C" void releaseCompilationUnit(long compilationUnitHandle) {
174     if (INFO) {
175         std::cout << "trampolining through to releaseCompilationUnit " << std::endl;
176     }
177     auto compilationUnit = reinterpret_cast<Backend::CompilationUnit *>(compilationUnitHandle);
178     delete compilationUnit;
179 }
180 
181 extern "C" bool compilationUnitOK(long compilationUnitHandle) {
182     if (INFO) {
183         std::cout << "trampolining through to compilationUnitHandleOK " << std::endl;
184     }
185     auto compilationUnit = reinterpret_cast<Backend::CompilationUnit *>(compilationUnitHandle);
186     return compilationUnit->compilationUnitOK();
187 }
188 
189 extern "C" bool getBufferFromDeviceIfDirty(long backendHandle, long memorySegmentHandle, long memorySegmentLength) {
190     if (INFO) {
191         std::cout << "trampolining through to getBuffer " << std::endl;
192     }
193     auto backend = reinterpret_cast<Backend *>(backendHandle);
194     auto memorySegment = reinterpret_cast<void *>(memorySegmentHandle);
195     return backend->getBufferFromDeviceIfDirty(memorySegment, memorySegmentLength);
196 }
197 
198 
199 Backend::Config::Config(int configBits):BasicConfig(configBits) {
200 
201 }
202 
203 Backend::Config::~Config() = default;
204 
205 Backend::Queue::Queue(Backend *backend)
206     : backend(backend) {
207 }
208 
209 Backend::Queue::~Queue() = default;
210 
211 Text::Text(size_t len, char *text, bool isCopy)
212     : len(len), text(text), isCopy(isCopy) {
213     // std::cout << "in Text len="<<len<<" isCopy="<<isCopy << std::endl;
214 }
215 
216 Text::Text(char *text, bool isCopy)
217     : len(std::strlen(text)), text(text), isCopy(isCopy) {
218     // std::cout << "in Text len="<<len<<" isCopy="<<isCopy << std::endl;
219 }
220 
221 Text::Text(size_t len)
222     : len(len), text(len > 0 ? new char[len] : nullptr), isCopy(true) {
223     //  std::cout << "in Text len="<<len<<" isCopy="<<isCopy << std::endl;
224 }
225 
226 void Text::write(const std::string &filename) const {
227     std::ofstream out;
228     out.open(filename, std::ofstream::trunc);
229     out.write(text, len);
230     out.close();
231 }
232 
233 void Text::read(const std::string &filename) {
234     if (isCopy && text) {
235         delete[] text;
236     }
237     text = nullptr;
238     isCopy = false;
239     // std::cout << "reading from " << filename << std::endl;
240 
241     std::ifstream ptxStream;
242     ptxStream.open(filename);
243 
244 
245     ptxStream.seekg(0, std::ios::end);
246     len = ptxStream.tellg();
247     ptxStream.seekg(0, std::ios::beg);
248 
249     if (len > 0) {
250         text = new char[len];
251         isCopy = true;
252         //std::cerr << "about to read  " << len << std::endl;
253         ptxStream.read(text, len);
254         ptxStream.close();
255         //std::cerr << "read  " << len << std::endl;
256         text[len - 1] = '\0';
257         //std::cerr << "read text " << text << std::endl;
258     }
259 }
260 
261 Text::~Text() {
262     if (isCopy && text) {
263         delete[] text;
264     }
265     text = nullptr;
266     isCopy = false;
267     len = 0;
268 }
269 
270 Log::Log(const size_t len)
271     : Text(len) {
272 }
273 
274 Log::Log(char *text)
275     : Text(text, false) {
276 }
277 
278 long Backend::CompilationUnit::Kernel::ndrange(void *argArray) {
279     if (compilationUnit->backend->config->traceCalls) {
280         std::cout << "kernelContext(\"" << name << "\"){" << std::endl;
281     }
282     ArgSled argSled(static_cast<ArgArray_s *>(argArray));
283     auto *profilableQueue = dynamic_cast<ProfilableQueue *>(compilationUnit->backend->queue);
284     if (profilableQueue != nullptr) {
285         profilableQueue->marker(ProfilableQueue::EnterKernelDispatchBits, name);
286     }
287     if (compilationUnit->backend->config->trace) {
288         Sled::show(std::cout, argArray);
289     }
290     KernelContext *kernelContext = nullptr;
291     for (int i = 0; i < argSled.argc(); i++) {
292         KernelArg *arg = argSled.arg(i);
293         switch (arg->variant) {
294             case '&': {
295                 if (arg->idx == 0) {
296                     // This does not have to be the case all the time. We should be able to pass the kernel context in any argument we want.
297                     kernelContext = static_cast<KernelContext *>(arg->value.buffer.memorySegment);
298                 }
299                 bool readAccessor  = arg->value.buffer.access == RO_BYTE || arg->value.buffer.access == RW_BYTE || arg->value.buffer.access == UNKNOWN_BYTE;
300                 if (compilationUnit->backend->config->trace) {
301                     std::cout << "arg[" << i << "] = " << std::hex << (int) (arg->value.buffer.access);
302                     switch (arg->value.buffer.access) {
303                         case RO_BYTE:
304                             std::cout << " RO";
305                             break;
306                         case WO_BYTE:
307                             std::cout << " WO";
308                             break;
309                         case RW_BYTE:
310                             std::cout << " RW";
311                             break;
312                     }
313                     std::cout << std::endl;
314                 }
315 
316                 BufferState *bufferState = BufferState::of(arg);
317 
318                 Buffer *buffer = compilationUnit->backend->getOrCreateBuffer(bufferState);
319 
320                 bool kernelReadsFromThisArg =  arg->value.buffer.access == RW_BYTE
321                                             || arg->value.buffer.access == RO_BYTE;
322 
323                 bool copyToDevice = readAccessor;
324                 if (!compilationUnit->backend->config->alwaysCopy) {
325                     copyToDevice = (bufferState->state == BufferState::NEW_STATE)
326                                      || ((bufferState->state == BufferState::HOST_OWNED));
327                 }
328 
329                 if (compilationUnit->backend->config->showWhy) {
330                     std::cout << "config.alwaysCopy=" << compilationUnit->backend->config->alwaysCopy
331                             << " | arg.RW=" << (arg->value.buffer.access == RW_BYTE)
332                             << " | arg.RO=" << (arg->value.buffer.access == RO_BYTE)
333                             << " | kernel.needsToRead=" << kernelReadsFromThisArg
334                             << " | Buffer state = " << BufferState::stateNames[bufferState->state]
335                             << " so "
336                             << std::endl;
337                 }
338                 if (copyToDevice) {
339                     compilationUnit->backend->queue->copyToDevice(buffer);
340                     bufferState->state = BufferState::DEVICE_OWNED;
341                     if (compilationUnit->backend->config->traceCopies) {
342                         std::cout << "copying arg " << arg->idx << " host->device " << std::endl;
343                     }
344                 } else {
345                     if (compilationUnit->backend->config->traceSkippedCopies) {
346                         std::cout << "NOT copying arg " << arg->idx << " host->device " << std::endl;
347                     }
348                 }
349                 setArg(arg, buffer);
350                 if (compilationUnit->backend->config->trace) {
351                     std::cout << "set buffer arg " << arg->idx << std::endl;
352                 }
353                 break;
354             }
355             case 'B':
356             case 'S':
357             case 'C':
358             case 'I':
359             case 'F':
360             case 'J':
361             case 'D': {
362                 setArg(arg);
363                 if (compilationUnit->backend->config->trace) {
364                     std::cerr << "set " << arg->variant << " " << arg->idx << std::endl;
365                 }
366                 break;
367             }
368             default: {
369                 std::cerr << "unexpected variant setting args in OpenCLkernel::kernelContext " << (char) arg->variant <<
370                         std::endl;
371                 exit(1);
372             }
373         }
374     }
375 
376     if (kernelContext == nullptr) {
377         std::cerr << "Looks like we recieved a kernel dispatch with xero args kernel='" << name << "'" << std::endl;
378         exit(1);
379     }
380 
381     if (compilationUnit->backend->config->trace) {
382         std::cout << "kernelContext = <" << kernelContext->gsx << "," << kernelContext->gsy << "," << kernelContext->gsz << ">" << std::endl;
383     }
384 
385     compilationUnit->backend->queue->dispatch(kernelContext, this);
386 
387     for (int i = 0; i < argSled.argc(); i++) {
388         // note i = 1... we never need to copy back the KernelContext
389         KernelArg *arg = argSled.arg(i);
390         if (arg->variant == '&') {
391             BufferState *bufferState = BufferState::of(arg);
392 
393             bool kernelWroteToThisArg = (arg->value.buffer.access == WO_BYTE) || (arg->value.buffer.access == RW_BYTE);
394             if (compilationUnit->backend->config->showWhy) {
395                 std::cout <<
396                         "config.alwaysCopy=" << compilationUnit->backend->config->alwaysCopy
397                         << " | arg.WO=" << (arg->value.buffer.access == WO_BYTE)
398                         << " | arg.RW=" << (arg->value.buffer.access == RW_BYTE)
399                         << " | kernel.wroteToThisArg=" << kernelWroteToThisArg
400                         << "Buffer state = " << BufferState::stateNames[bufferState->state]
401                         << " so "
402                         << std::endl;
403             }
404 
405             auto *buffer = static_cast<Buffer *>(bufferState->vendorPtr);
406             if (kernelWroteToThisArg && compilationUnit->backend->config->alwaysCopy) {
407                 compilationUnit->backend->queue->copyFromDevice(buffer);
408                 bufferState->state = BufferState::HOST_OWNED;
409                 if (compilationUnit->backend->config->traceCopies || compilationUnit->backend->config->traceEnqueues) {
410                     std::cout << "copying arg " << arg->idx << " device->host " << std::endl;
411                 }
412             } else {
413                 if (compilationUnit->backend->config->traceSkippedCopies) {
414                     std::cout << "NOT copying arg " << arg->idx << " device->host " << std::endl;
415                 }
416             }
417         }
418     }
419     if (profilableQueue != nullptr) {
420         profilableQueue->marker(Backend::ProfilableQueue::LeaveKernelDispatchBits, name);
421     }
422     compilationUnit->backend->queue->wait();
423     compilationUnit->backend->queue->release();
424     if (compilationUnit->backend->config->traceCalls) {
425         std::cout << "\"" << name << "\"}" << std::endl;
426     }
427     return 0;
428 }