1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package hat.backend.jextracted;
 26 
 27 import hat.codebuilders.C99HATKernelBuilder;
 28 import hat.codebuilders.CodeBuilder;
 29 import hat.codebuilders.ScopedCodeBuilderContext;
 30 import hat.dialect.HATF16ConvOp;
 31 import hat.dialect.HATF16ToFloatConvOp;
 32 import hat.dialect.HATVectorBinaryOp;
 33 import hat.dialect.HATVectorLoadOp;
 34 import hat.dialect.HATVectorOfOp;
 35 import hat.dialect.HATVectorSelectLoadOp;
 36 import hat.dialect.HATVectorSelectStoreOp;
 37 import hat.dialect.HATVectorStoreView;
 38 import hat.dialect.HATVectorVarOp;
 39 import jdk.incubator.code.Op;
 40 import jdk.incubator.code.Value;
 41 
 42 public class OpenCLHatKernelBuilder extends C99HATKernelBuilder<OpenCLHatKernelBuilder> {
 43 
 44     @Override
 45     public OpenCLHatKernelBuilder defines() {
 46         return self()
 47                 .hashDefine("HAT_OPENCL")
 48                 //  .hashIfdef("HAT_OPENCL", _ ->
 49                 //        indent(_ -> self()
 50                 .hashIfndef("NULL", _ -> hashDefine("NULL", "0"))
 51                 .pragma("OPENCL", "EXTENSION", "cl_khr_global_int32_base_atomics", ":", "enable")
 52                 .pragma("OPENCL", "EXTENSION", "cl_khr_local_int32_base_atomics", ":", "enable")
 53                 .pragma("OPENCL", "EXTENSION", "cl_khr_fp16", ":", "enable")                      // Enable Half type
 54                 .hashDefine("HAT_FUNC", _ -> keyword("inline"))
 55                 .hashDefine("HAT_KERNEL", _ -> keyword("__kernel"))
 56                 .hashDefine("HAT_GLOBAL_MEM", _ -> keyword("__global"))
 57                 .hashDefine("HAT_LOCAL_MEM", _ -> keyword("__local"))
 58                 .hashDefine("HAT_GIX", _ -> paren(_ -> identifier("get_global_id").paren(_ -> intConstZero())))
 59                 .hashDefine("HAT_GIY", _ -> paren(_ -> identifier("get_global_id").paren(_ -> intConstOne())))
 60                 .hashDefine("HAT_GIZ", _ -> paren(_ -> identifier("get_global_id").paren(_ -> intConstTwo())))
 61                 .hashDefine("HAT_LIX", _ -> paren(_ -> identifier("get_local_id").paren(_ -> intConstZero())))
 62                 .hashDefine("HAT_LIY", _ -> paren(_ -> identifier("get_local_id").paren(_ -> intConstOne())))
 63                 .hashDefine("HAT_LIZ", _ -> paren(_ -> identifier("get_local_id").paren(_ -> intConstTwo())))
 64                 .hashDefine("HAT_GSX", _ -> paren(_ -> identifier("get_global_size").paren(_ -> intConstZero())))
 65                 .hashDefine("HAT_GSY", _ -> paren(_ -> identifier("get_global_size").paren(_ -> intConstOne())))
 66                 .hashDefine("HAT_GSZ", _ -> paren(_ -> identifier("get_global_size").paren(_ -> intConstTwo())))
 67                 .hashDefine("HAT_LSX", _ -> paren(_ -> identifier("get_local_size").paren(_ -> intConstZero())))
 68                 .hashDefine("HAT_LSY", _ -> paren(_ -> identifier("get_local_size").paren(_ -> intConstOne())))
 69                 .hashDefine("HAT_LSZ", _ -> paren(_ -> identifier("get_local_size").paren(_ -> intConstTwo())))
 70                 .hashDefine("HAT_BIX", _ -> paren(_ -> identifier("get_group_id").paren(_ -> intConstZero())))
 71                 .hashDefine("HAT_BIY", _ -> paren(_ -> identifier("get_group_id").paren(_ -> intConstOne())))
 72                 .hashDefine("HAT_BIZ", _ -> paren(_ -> identifier("get_group_id").paren(_ -> intConstTwo())))
 73                 .hashDefine("HAT_BARRIER", _ -> identifier("barrier").oparen().identifier("CLK_LOCAL_MEM_FENCE").cparen());
 74         //         )
 75         // );
 76     }
 77 
 78     @Override
 79     public OpenCLHatKernelBuilder atomicInc(ScopedCodeBuilderContext buildContext, Op.Result instanceResult, String name) {
 80         return identifier("atomic_inc").paren(_ -> ampersand().recurse(buildContext, instanceResult.op()).rarrow().identifier(name));
 81     }
 82 
 83     @Override
 84     public OpenCLHatKernelBuilder hatVectorStoreOp(ScopedCodeBuilderContext buildContext, HATVectorStoreView hatVectorStoreView) {
 85         Value dest = hatVectorStoreView.operands().get(0);
 86         Value index = hatVectorStoreView.operands().get(2);
 87 
 88         identifier("vstore" + hatVectorStoreView.vectorN())
 89                 .oparen()
 90                 .varName(hatVectorStoreView)
 91                 .comma()
 92                 .space()
 93                 .intConstZero()
 94                 .comma()
 95                 .space()
 96                 .ampersand();
 97 
 98         if (dest instanceof Op.Result r) {
 99             recurse(buildContext, r.op());
100         }
101         either(hatVectorStoreView.isSharedOrPrivate(), CodeBuilder::dot, CodeBuilder::rarrow);
102         identifier("array").osbrace();
103 
104         if (index instanceof Op.Result r) {
105             recurse(buildContext, r.op());
106         }
107 
108         csbrace().cparen();
109         return self();
110     }
111 
112     @Override
113     public OpenCLHatKernelBuilder hatBinaryVectorOp(ScopedCodeBuilderContext buildContext, HATVectorBinaryOp hatVectorBinaryOp) {
114 
115         oparen();
116         Value op1 = hatVectorBinaryOp.operands().get(0);
117         Value op2 = hatVectorBinaryOp.operands().get(1);
118 
119         if (op1 instanceof Op.Result r) {
120             recurse(buildContext, r.op());
121         }
122         space().identifier(hatVectorBinaryOp.operationType().symbol()).space();
123 
124         if (op2 instanceof Op.Result r) {
125             recurse(buildContext, r.op());
126         }
127         cparen();
128         return self();
129     }
130 
131     @Override
132     public OpenCLHatKernelBuilder hatVectorLoadOp(ScopedCodeBuilderContext buildContext, HATVectorLoadOp hatVectorLoadOp) {
133         Value source = hatVectorLoadOp.operands().get(0);
134         Value index = hatVectorLoadOp.operands().get(1);
135 
136         identifier("vload" + hatVectorLoadOp.vectorN())
137                 .oparen()
138                 .intConstZero()
139                 .comma()
140                 .space()
141                 .ampersand();
142 
143         if (source instanceof Op.Result r) {
144             recurse(buildContext, r.op());
145         }
146 
147         either(hatVectorLoadOp.isSharedOrPrivate(), CodeBuilder::dot, CodeBuilder::rarrow);
148         identifier("array").osbrace();
149         if (index instanceof Op.Result r) {
150             recurse(buildContext, r.op());
151         }
152         csbrace().cparen();
153         return self();
154     }
155 
156     @Override
157     public OpenCLHatKernelBuilder hatSelectLoadOp(ScopedCodeBuilderContext buildContext, HATVectorSelectLoadOp hatVSelectLoadOp) {
158         identifier(hatVSelectLoadOp.varName())
159                 .dot()
160                 .identifier(hatVSelectLoadOp.mapLane());
161         return self();
162     }
163 
164     @Override
165     public OpenCLHatKernelBuilder hatSelectStoreOp(ScopedCodeBuilderContext buildContext, HATVectorSelectStoreOp hatVSelectStoreOp) {
166         identifier(hatVSelectStoreOp.varName())
167                 .dot()
168                 .identifier(hatVSelectStoreOp.mapLane())
169                 .space().equals().space();
170         if (hatVSelectStoreOp.resultValue() != null) {
171             // We have detected a direct resolved result (resolved name)
172             varName(hatVSelectStoreOp.resultValue());
173         } else {
174             // otherwise, we traverse to resolve the expression
175             Value storeValue = hatVSelectStoreOp.operands().get(1);
176             if (storeValue instanceof Op.Result r) {
177                 recurse(buildContext, r.op());
178             }
179         }
180         return self();
181     }
182 
183     @Override
184     public OpenCLHatKernelBuilder hatF16ConvOp(ScopedCodeBuilderContext buildContext, HATF16ConvOp hatF16ConvOp) {
185         oparen().typeName("half").cparen();
186         Value initValue = hatF16ConvOp.operands().getFirst();
187         if (initValue instanceof Op.Result r) {
188             recurse(buildContext, r.op());
189         }
190         return self();
191     }
192 
193     @Override
194     public OpenCLHatKernelBuilder hatVectorVarOp(ScopedCodeBuilderContext buildContext, HATVectorVarOp hatVectorVarOp) {
195         typeName(hatVectorVarOp.buildType())
196                 .space()
197                 .varName(hatVectorVarOp)
198                 .space().equals().space();
199 
200         Value operand = hatVectorVarOp.operands().getFirst();
201         if (operand instanceof Op.Result r) {
202             recurse(buildContext, r.op());
203         }
204         return self();
205     }
206 
207     @Override
208     public OpenCLHatKernelBuilder genVectorIdentifier(ScopedCodeBuilderContext builderContext, HATVectorOfOp hatVectorOfOp) {
209         oparen().identifier(hatVectorOfOp.buildType()).cparen().oparen();
210         return self();
211     }
212 
213     @Override
214     public OpenCLHatKernelBuilder hatF16ToFloatConvOp(ScopedCodeBuilderContext builderContext, HATF16ToFloatConvOp hatF16ToFloatConvOp) {
215         oparen().halfType().cparen();
216         identifier(hatF16ToFloatConvOp.varName());
217         return self();
218     }
219 
220 }