1 /*
2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25 package hat.backend.jextracted;
26
27 import hat.codebuilders.C99HATKernelBuilder;
28 import hat.codebuilders.CodeBuilder;
29 import hat.codebuilders.ScopedCodeBuilderContext;
30 import hat.dialect.HATF16ConvOp;
31 import hat.dialect.HATF16ToFloatConvOp;
32 import hat.dialect.HATVectorBinaryOp;
33 import hat.dialect.HATVectorLoadOp;
34 import hat.dialect.HATVectorOfOp;
35 import hat.dialect.HATVectorSelectLoadOp;
36 import hat.dialect.HATVectorSelectStoreOp;
37 import hat.dialect.HATVectorStoreView;
38 import hat.dialect.HATVectorVarOp;
39 import jdk.incubator.code.Op;
40 import jdk.incubator.code.Value;
41
42 public class OpenCLHatKernelBuilder extends C99HATKernelBuilder<OpenCLHatKernelBuilder> {
43
44 @Override
45 public OpenCLHatKernelBuilder defines() {
46 return self()
47 .hashDefine("HAT_OPENCL")
48 // .hashIfdef("HAT_OPENCL", _ ->
49 // indent(_ -> self()
50 .hashIfndef("NULL", _ -> hashDefine("NULL", "0"))
51 .pragma("OPENCL", "EXTENSION", "cl_khr_global_int32_base_atomics", ":", "enable")
52 .pragma("OPENCL", "EXTENSION", "cl_khr_local_int32_base_atomics", ":", "enable")
53 .pragma("OPENCL", "EXTENSION", "cl_khr_fp16", ":", "enable") // Enable Half type
54 .hashDefine("HAT_FUNC", _ -> keyword("inline"))
55 .hashDefine("HAT_KERNEL", _ -> keyword("__kernel"))
56 .hashDefine("HAT_GLOBAL_MEM", _ -> keyword("__global"))
57 .hashDefine("HAT_LOCAL_MEM", _ -> keyword("__local"))
58 .hashDefine("HAT_GIX", _ -> paren(_ -> identifier("get_global_id").paren(_ -> intConstZero())))
59 .hashDefine("HAT_GIY", _ -> paren(_ -> identifier("get_global_id").paren(_ -> intConstOne())))
60 .hashDefine("HAT_GIZ", _ -> paren(_ -> identifier("get_global_id").paren(_ -> intConstTwo())))
61 .hashDefine("HAT_LIX", _ -> paren(_ -> identifier("get_local_id").paren(_ -> intConstZero())))
62 .hashDefine("HAT_LIY", _ -> paren(_ -> identifier("get_local_id").paren(_ -> intConstOne())))
63 .hashDefine("HAT_LIZ", _ -> paren(_ -> identifier("get_local_id").paren(_ -> intConstTwo())))
64 .hashDefine("HAT_GSX", _ -> paren(_ -> identifier("get_global_size").paren(_ -> intConstZero())))
65 .hashDefine("HAT_GSY", _ -> paren(_ -> identifier("get_global_size").paren(_ -> intConstOne())))
66 .hashDefine("HAT_GSZ", _ -> paren(_ -> identifier("get_global_size").paren(_ -> intConstTwo())))
67 .hashDefine("HAT_LSX", _ -> paren(_ -> identifier("get_local_size").paren(_ -> intConstZero())))
68 .hashDefine("HAT_LSY", _ -> paren(_ -> identifier("get_local_size").paren(_ -> intConstOne())))
69 .hashDefine("HAT_LSZ", _ -> paren(_ -> identifier("get_local_size").paren(_ -> intConstTwo())))
70 .hashDefine("HAT_BIX", _ -> paren(_ -> identifier("get_group_id").paren(_ -> intConstZero())))
71 .hashDefine("HAT_BIY", _ -> paren(_ -> identifier("get_group_id").paren(_ -> intConstOne())))
72 .hashDefine("HAT_BIZ", _ -> paren(_ -> identifier("get_group_id").paren(_ -> intConstTwo())))
73 .hashDefine("HAT_BARRIER", _ -> identifier("barrier").oparen().identifier("CLK_LOCAL_MEM_FENCE").cparen());
74 // )
75 // );
76 }
77
78 @Override
79 public OpenCLHatKernelBuilder atomicInc(ScopedCodeBuilderContext buildContext, Op.Result instanceResult, String name) {
80 return identifier("atomic_inc").paren(_ -> ampersand().recurse(buildContext, instanceResult.op()).rarrow().identifier(name));
81 }
82
83 @Override
84 public OpenCLHatKernelBuilder hatVectorStoreOp(ScopedCodeBuilderContext buildContext, HATVectorStoreView hatVectorStoreView) {
85 Value dest = hatVectorStoreView.operands().get(0);
86 Value index = hatVectorStoreView.operands().get(2);
87
88 identifier("vstore" + hatVectorStoreView.vectorN())
89 .oparen()
90 .varName(hatVectorStoreView)
91 .comma()
92 .space()
93 .intConstZero()
94 .comma()
95 .space()
96 .ampersand();
97
98 if (dest instanceof Op.Result r) {
99 recurse(buildContext, r.op());
100 }
101 either(hatVectorStoreView.isSharedOrPrivate(), CodeBuilder::dot, CodeBuilder::rarrow);
102 identifier("array").osbrace();
103
104 if (index instanceof Op.Result r) {
105 recurse(buildContext, r.op());
106 }
107
108 csbrace().cparen();
109 return self();
110 }
111
112 @Override
113 public OpenCLHatKernelBuilder hatBinaryVectorOp(ScopedCodeBuilderContext buildContext, HATVectorBinaryOp hatVectorBinaryOp) {
114
115 oparen();
116 Value op1 = hatVectorBinaryOp.operands().get(0);
117 Value op2 = hatVectorBinaryOp.operands().get(1);
118
119 if (op1 instanceof Op.Result r) {
120 recurse(buildContext, r.op());
121 }
122 space().identifier(hatVectorBinaryOp.operationType().symbol()).space();
123
124 if (op2 instanceof Op.Result r) {
125 recurse(buildContext, r.op());
126 }
127 cparen();
128 return self();
129 }
130
131 @Override
132 public OpenCLHatKernelBuilder hatVectorLoadOp(ScopedCodeBuilderContext buildContext, HATVectorLoadOp hatVectorLoadOp) {
133 Value source = hatVectorLoadOp.operands().get(0);
134 Value index = hatVectorLoadOp.operands().get(1);
135
136 identifier("vload" + hatVectorLoadOp.vectorN())
137 .oparen()
138 .intConstZero()
139 .comma()
140 .space()
141 .ampersand();
142
143 if (source instanceof Op.Result r) {
144 recurse(buildContext, r.op());
145 }
146
147 either(hatVectorLoadOp.isSharedOrPrivate(), CodeBuilder::dot, CodeBuilder::rarrow);
148 identifier("array").osbrace();
149 if (index instanceof Op.Result r) {
150 recurse(buildContext, r.op());
151 }
152 csbrace().cparen();
153 return self();
154 }
155
156 @Override
157 public OpenCLHatKernelBuilder hatSelectLoadOp(ScopedCodeBuilderContext buildContext, HATVectorSelectLoadOp hatVSelectLoadOp) {
158 identifier(hatVSelectLoadOp.varName())
159 .dot()
160 .identifier(hatVSelectLoadOp.mapLane());
161 return self();
162 }
163
164 @Override
165 public OpenCLHatKernelBuilder hatSelectStoreOp(ScopedCodeBuilderContext buildContext, HATVectorSelectStoreOp hatVSelectStoreOp) {
166 identifier(hatVSelectStoreOp.varName())
167 .dot()
168 .identifier(hatVSelectStoreOp.mapLane())
169 .space().equals().space();
170 if (hatVSelectStoreOp.resultValue() != null) {
171 // We have detected a direct resolved result (resolved name)
172 varName(hatVSelectStoreOp.resultValue());
173 } else {
174 // otherwise, we traverse to resolve the expression
175 Value storeValue = hatVSelectStoreOp.operands().get(1);
176 if (storeValue instanceof Op.Result r) {
177 recurse(buildContext, r.op());
178 }
179 }
180 return self();
181 }
182
183 @Override
184 public OpenCLHatKernelBuilder hatF16ConvOp(ScopedCodeBuilderContext buildContext, HATF16ConvOp hatF16ConvOp) {
185 oparen().typeName("half").cparen();
186 Value initValue = hatF16ConvOp.operands().getFirst();
187 if (initValue instanceof Op.Result r) {
188 recurse(buildContext, r.op());
189 }
190 return self();
191 }
192
193 @Override
194 public OpenCLHatKernelBuilder hatVectorVarOp(ScopedCodeBuilderContext buildContext, HATVectorVarOp hatVectorVarOp) {
195 typeName(hatVectorVarOp.buildType())
196 .space()
197 .varName(hatVectorVarOp)
198 .space().equals().space();
199
200 Value operand = hatVectorVarOp.operands().getFirst();
201 if (operand instanceof Op.Result r) {
202 recurse(buildContext, r.op());
203 }
204 return self();
205 }
206
207 @Override
208 public OpenCLHatKernelBuilder genVectorIdentifier(ScopedCodeBuilderContext builderContext, HATVectorOfOp hatVectorOfOp) {
209 oparen().identifier(hatVectorOfOp.buildType()).cparen().oparen();
210 return self();
211 }
212
213 @Override
214 public OpenCLHatKernelBuilder hatF16ToFloatConvOp(ScopedCodeBuilderContext builderContext, HATF16ToFloatConvOp hatF16ToFloatConvOp) {
215 oparen().halfType().cparen();
216 identifier(hatF16ToFloatConvOp.varName());
217 return self();
218 }
219
220 }