1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 #pragma once
 26 #define CUDA_TYPES
 27 #ifdef __APPLE__
 28 
 29 #define LongUnsignedNewline "%llu\n"
 30 #define Size_tNewline "%lu\n"
 31 #define LongHexNewline "(0x%llx)\n"
 32 #define alignedMalloc(size, alignment) memalign(alignment, size)
 33 #define SNPRINTF snprintf
 34 #else
 35 
 36 #include <malloc.h>
 37 
 38 #define LongHexNewline "(0x%lx)\n"
 39 #define LongUnsignedNewline "%lu\n"
 40 #define Size_tNewline "%lu\n"
 41 #if defined (_WIN32)
 42 #include "windows.h"
 43 #define alignedMalloc(size, alignment) _aligned_malloc(size, alignment)
 44 #define SNPRINTF _snprintf
 45 #else
 46 #define alignedMalloc(size, alignment) memalign(alignment, size)
 47 #define SNPRINTF  snprintf
 48 #endif
 49 #endif
 50 
 51 #include <iostream>
 52 #include <cuda.h>
 53 #include <builtin_types.h>
 54 
 55 #define CUDA_TYPES
 56 
 57 #include "shared.h"
 58 
 59 #include <fstream>
 60 
 61 #include<vector>
 62 
 63 //extern void __checkCudaErrors(CUresult err, const char *file, const int line);
 64 
 65 //#define checkCudaErrors(err)  __checkCudaErrors (err, __FILE__, __LINE__)
 66 
 67 class Ptx {
 68 public:
 69     size_t len;
 70     char *text;
 71 
 72     Ptx(size_t len);
 73 
 74     ~Ptx();
 75 
 76     static Ptx *nvcc(const char *cudaSource, size_t len);
 77 };
 78 
 79 class CudaBackend : public Backend {
 80 public:
 81     class CudaConfig : public Backend::Config {
 82     public:
 83         boolean gpu;
 84     };
 85 
 86     class CudaProgram : public Backend::Program {
 87         class CudaKernel : public Backend::Program::Kernel {
 88             class CudaBuffer : public Backend::Program::Kernel::Buffer {
 89             public:
 90                 CUdeviceptr devicePtr;
 91 
 92                 CudaBuffer(Backend::Program::Kernel *kernel, Arg_s *arg);
 93 
 94                 void copyToDevice();
 95 
 96                 void copyFromDevice();
 97 
 98                 virtual ~CudaBuffer();
 99             };
100 
101         private:
102             CUfunction function;
103             cudaStream_t cudaStream;
104         public:
105             CudaKernel(Backend::Program *program, char* name, CUfunction function);
106 
107             ~CudaKernel() override;
108 
109             long ndrange( void *argArray);
110         };
111 
112     private:
113         CUmodule module;
114         Ptx *ptx;
115 
116     public:
117         CudaProgram(Backend *backend, BuildInfo *buildInfo, Ptx *ptx, CUmodule module);
118 
119         ~CudaProgram();
120 
121         long getKernel(int nameLen, char *name);
122 
123         bool programOK();
124     };
125 
126 private:
127     CUdevice device;
128     CUcontext context;
129 public:
130 
131     CudaBackend(CudaConfig *config, int configSchemaLen, char *configSchema);
132 
133     CudaBackend();
134 
135     ~CudaBackend();
136 
137     int getMaxComputeUnits();
138 
139     void info();
140 
141     long compileProgram(int len, char *source);
142 
143     //static const char *errorMsg(CUresult status);
144 
145 };
146