1 /*
2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25 package hat.backend.jextracted;
26
27 import hat.callgraph.KernelCallGraph;
28 import hat.codebuilders.C99HATKernelBuilder;
29 import hat.dialect.BinaryOpEnum;
30 import hat.phases.HATPhaseUtils;
31 import hat.types.BF16;
32 import hat.types.F16;
33 import jdk.incubator.code.Value;
34 import jdk.incubator.code.dialect.core.CoreOp;
35 import jdk.incubator.code.dialect.core.VarType;
36 import jdk.incubator.code.dialect.java.ClassType;
37 import jdk.incubator.code.dialect.java.JavaOp;
38 import jdk.incubator.code.dialect.java.PrimitiveType;
39 import optkl.IfaceValue;
40 import optkl.OpHelper;
41 import optkl.codebuilders.CodeBuilder;
42 import optkl.codebuilders.ScopedCodeBuilderContext;
43 import jdk.incubator.code.Op;
44
45 import java.util.HashMap;
46 import java.util.Map;
47 import java.util.Optional;
48 import java.util.stream.Stream;
49
50 import static hat.phases.HATPhaseUtils.isMathLib;
51 import static optkl.IfaceValue.Vector.getVectorShape;
52
53 public class OpenCLJExtractedHATKernelBuilder extends C99HATKernelBuilder<OpenCLJExtractedHATKernelBuilder> {
54
55 // Mapping between API function names and OpenCL intrinsics for the math operations
56 private static final Map<String, String> MATH_FUNCTIONS = new HashMap<>();
57 static {
58 MATH_FUNCTIONS.put("maxf", "max");
59 MATH_FUNCTIONS.put("maxd", "max");
60 MATH_FUNCTIONS.put("maxf16", "MAX_HAT");
61 MATH_FUNCTIONS.put("minf", "min");
62 MATH_FUNCTIONS.put("mind", "min");
63 MATH_FUNCTIONS.put("minf16", "MIN_HAT");
64
65 MATH_FUNCTIONS.put("expf", "exp");
66 MATH_FUNCTIONS.put("expd", "exp");
67 MATH_FUNCTIONS.put("expf16", "half_exp");
68
69 MATH_FUNCTIONS.put("cosf", "cos");
70 MATH_FUNCTIONS.put("cosd", "cos");
71 MATH_FUNCTIONS.put("sinf", "sin");
72 MATH_FUNCTIONS.put("sind", "sin");
73 MATH_FUNCTIONS.put("tanf", "tan");
74 MATH_FUNCTIONS.put("tand", "tan");
75
76 MATH_FUNCTIONS.put("native_cosf", "native_cos");
77 MATH_FUNCTIONS.put("native_sinf", "native_sin");
78 MATH_FUNCTIONS.put("native_tanf", "native_tan");
79 MATH_FUNCTIONS.put("native_expf", "native_exp");
80
81 MATH_FUNCTIONS.put("sqrtf", "sqrt");
82 MATH_FUNCTIONS.put("sqrtd", "sqrt");
83 }
84
85 protected OpenCLJExtractedHATKernelBuilder(KernelCallGraph kernelCallGraph, ScopedCodeBuilderContext scopedCodeBuilderContext) {
86 super(kernelCallGraph, scopedCodeBuilderContext);
87 }
88
89 @Override
90 public OpenCLJExtractedHATKernelBuilder defines() {
91 return self()
92 .hashDefine("HAT_OPENCL")
93 .hashIfndef("NULL", _ -> hashDefine("NULL", "0"))
94 .when(kernelCallGraph.isUsesAtomics(),_->pragma("OPENCL", "EXTENSION", "cl_khr_global_int32_base_atomics", ":", "enable"))
95 .when(kernelCallGraph.isUsesAtomics(),_->pragma("OPENCL", "EXTENSION", "cl_khr_local_int32_base_atomics", ":", "enable"))
96 /*.when(kernelCallGraph.usesFp16,_->*/.pragma("OPENCL", "EXTENSION", "cl_khr_fp16", ":", "enable")//) // Enable Half type
97 .hashDefine("HAT_FUNC", _ -> keyword(""))
98 .hashDefine("HAT_KERNEL", _ -> keyword("__kernel"))
99 .hashDefine("HAT_GLOBAL_MEM", _ -> keyword("__global"))
100 .hashDefine("HAT_LOCAL_MEM", _ -> keyword("__local"))
101 .when(kernelCallGraph.accessedKernelContextFields.contains("gix"), _->hashDefine("HAT_GIX", _ -> paren(_ -> id("get_global_id").paren(_ -> intConstZero()))))
102 .when(kernelCallGraph.accessedKernelContextFields.contains("giy"), _->hashDefine("HAT_GIY", _ -> paren(_ -> id("get_global_id").paren(_ -> intConstOne()))))
103 .when(kernelCallGraph.accessedKernelContextFields.contains("giz"), _->hashDefine("HAT_GIZ", _ -> paren(_ -> id("get_global_id").paren(_ -> intConstTwo()))))
104 .when(kernelCallGraph.accessedKernelContextFields.contains("lix"), _->hashDefine("HAT_LIX", _ -> paren(_ -> id("get_local_id").paren(_ -> intConstZero()))))
105 .when(kernelCallGraph.accessedKernelContextFields.contains("liy"), _->hashDefine("HAT_LIY", _ -> paren(_ -> id("get_local_id").paren(_ -> intConstOne()))))
106 .when(kernelCallGraph.accessedKernelContextFields.contains("liz"), _->hashDefine("HAT_LIZ", _ -> paren(_ -> id("get_local_id").paren(_ -> intConstTwo()))))
107 .when(kernelCallGraph.accessedKernelContextFields.contains("gsx"), _->hashDefine("HAT_GSX", _ -> paren(_ -> id("get_global_size").paren(_ -> intConstZero()))))
108 .when(kernelCallGraph.accessedKernelContextFields.contains("gsy"), _->hashDefine("HAT_GSY", _ -> paren(_ -> id("get_global_size").paren(_ -> intConstOne()))))
109 .when(kernelCallGraph.accessedKernelContextFields.contains("gsz"), _->hashDefine("HAT_GSZ", _ -> paren(_ -> id("get_global_size").paren(_ -> intConstTwo()))))
110 .when(kernelCallGraph.accessedKernelContextFields.contains("lsx"), _->hashDefine("HAT_LSX", _ -> paren(_ -> id("get_local_size").paren(_ -> intConstZero()))))
111 .when(kernelCallGraph.accessedKernelContextFields.contains("lsy"), _->hashDefine("HAT_LSY", _ -> paren(_ -> id("get_local_size").paren(_ -> intConstOne()))))
112 .when(kernelCallGraph.accessedKernelContextFields.contains("lsz"), _->hashDefine("HAT_LSZ", _ -> paren(_ -> id("get_local_size").paren(_ -> intConstTwo()))))
113 .when(kernelCallGraph.accessedKernelContextFields.contains("bix"), _->hashDefine("HAT_BIX", _ -> paren(_ -> id("get_group_id").paren(_ -> intConstZero()))))
114 .when(kernelCallGraph.accessedKernelContextFields.contains("biy"), _->hashDefine("HAT_BIY", _ -> paren(_ -> id("get_group_id").paren(_ -> intConstOne()))))
115 .when(kernelCallGraph.accessedKernelContextFields.contains("biz"), _->hashDefine("HAT_BIZ", _ -> paren(_ -> id("get_group_id").paren(_ -> intConstTwo()))))
116 .when(kernelCallGraph.accessedKernelContextFields.contains("bsx"), _->hashDefine("HAT_BSX", _ -> paren(_ -> id("get_num_groups").paren(_ -> intConstZero()))))
117 .when(kernelCallGraph.accessedKernelContextFields.contains("bsy"), _->hashDefine("HAT_BSY", _ -> paren(_ -> id("get_num_groups").paren(_ -> intConstOne()))))
118 .when(kernelCallGraph.accessedKernelContextFields.contains("bsz"), _->hashDefine("HAT_BSZ", _ -> paren(_ -> id("get_num_groups").paren(_ -> intConstTwo()))))
119 .when(!kernelCallGraph.accessedFP16Classes.isEmpty(), _->maxMacro("MAX_HAT"))
120 .when(!kernelCallGraph.accessedFP16Classes.isEmpty(), _->minMacro("MIN_HAT"))
121 .when(kernelCallGraph.isUsesBarrier(), _ ->hashDefine("HAT_BARRIER", _ -> id("barrier").oparen().id("CLK_LOCAL_MEM_FENCE").cparen()))
122 /*.when(callgraphState.usesFp16,_->*/.hashDefine("BFLOAT16", _ -> keyword("ushort"))//)
123 /*.when(callgraphState.usesFp16,_->*/.typedefSingleValueStruct("F16", "half")//)
124 /*.when(callgraphState.usesFp16,_->*/.typedefSingleValueStruct("BF16", "BFLOAT16")//)
125 /*.when(callgraphState.usesFp16,_->*/.unionBfloat16()//)
126 /*.when(callgraphState.usesFp16,_->*/.build_builtin_bfloat16ToFloat("bf16")//)
127 /*.when(callgraphState.usesFp16,_->*/.build_builtin_float2bfloat16("f")/*)*/;
128 }
129
130 @Override
131 public OpenCLJExtractedHATKernelBuilder atomicInc( Op.Result instanceResult, String name) {
132 return id("atomic_inc").paren(_ -> ampersand().recurse( instanceResult.op()).rarrow().id(name));
133 }
134
135 protected OpenCLJExtractedHATKernelBuilder vstore(int dims) {
136 return id("vstore" + dims);
137 }
138
139 protected OpenCLJExtractedHATKernelBuilder vload(int dims) {
140 return id("vload" + dims);
141 }
142
143
144 @Override
145 public OpenCLJExtractedHATKernelBuilder generateVectorLoad(Value source, Value index, IfaceValue.Vector.Shape vectorShape, boolean deviceAllocated) {
146 vload(vectorShape.lanes()).paren(_ -> {
147 intConstZero().comma().sp().ampersand();
148 recurseResultOrThrow(source);
149 either(deviceAllocated, CodeBuilder::dot, CodeBuilder::rarrow);
150 id("array").sbrace(_ -> recurseResultOrThrow(index));
151 });
152 return self();
153 }
154
155 @Override
156 public OpenCLJExtractedHATKernelBuilder hatVectorStoreOp(Value dest, Value index, IfaceValue.Vector.Shape vectorShape, boolean deviceAllocated, String name, Op op) {
157 return vstore(vectorShape.lanes()).paren(_-> {
158 // if the value to be stored is an operation, recurse on the operation
159 varName(name);
160 csp().intConstZero().csp().ampersand().recurseResultOrThrow(dest);
161 either(deviceAllocated, CodeBuilder::dot, CodeBuilder::rarrow);
162 id("array").sbrace(_ ->recurseResultOrThrow(index));
163 });
164 }
165
166 @Override
167 public OpenCLJExtractedHATKernelBuilder hatBinaryVectorOp(OpHelper.Invoke binOp) {
168 return paren(_-> {
169 recurseResultOrThrow(binOp.op().operands().get(0));
170 sp().id(BinaryOpEnum.of(binOp.op()).symbol()).sp();
171 recurseResultOrThrow(binOp.op().operands().get(1));
172 });
173 }
174
175 @Override
176 public OpenCLJExtractedHATKernelBuilder hatSelectStoreOp(OpHelper.Invoke invoke, HATPhaseUtils.InvokeVar invokeVar) {
177 if (invoke.op().operands().getFirst().declaringElement() instanceof JavaOp.ArrayAccessOp.ArrayLoadOp vLoadOp) {
178 recurse(vLoadOp);
179 } else {
180 id(invokeVar.name());
181 }
182 dot().id(HATPhaseUtils.mapLane(invokeVar.laneIdx())).assign();
183 String resolvedName = invokeVar.resolveName();
184 return either (resolvedName != null,
185 _-> varName(resolvedName),
186 _-> recurseResultOrThrow(invoke.op().operands().get(1))
187 );
188 }
189
190 @Override
191 public OpenCLJExtractedHATKernelBuilder hatF16ConvOp(JavaOp.InvokeOp invokeOp, Class<?> reduceFloatType) {
192 return paren(_-> f16OrBF16(reduceFloatType)).brace(_->
193 either (BF16.class.isAssignableFrom(reduceFloatType),
194 _-> builtin_float2bfloat16().paren(_-> recurseResultOrThrow(invokeOp.operands().getFirst())),
195 _-> recurseResultOrThrow(invokeOp.operands().getFirst())
196 ));
197 }
198
199 @Override
200 public OpenCLJExtractedHATKernelBuilder genVectorIdentifier(IfaceValue.Vector.Shape vectorShape) {
201 return paren(_-> id(vectorShape.codeType().toString() + vectorShape.lanes()));
202 }
203
204 @Override
205 public OpenCLJExtractedHATKernelBuilder hatF16ToFloatConvOp(OpHelper.Invoke invoke, Class<?> reducedFloatType, boolean wasFloat, boolean isF16Local) {
206 if (F16.class.isAssignableFrom(reducedFloatType)) {// half -> float
207 paren(_->f32Type());
208 } else if (BF16.class.isAssignableFrom(reducedFloatType)) {// bfloat16 -> float
209 builtin_bfloat16ToFloat();
210 }
211 parenWhen(BF16.class.isAssignableFrom(reducedFloatType),_-> {
212 recurseResultOrThrow(invoke.op().operands().getFirst());
213 if (!isF16Local) {
214 rarrow();
215 } else if (!wasFloat) {
216 dot();
217 } else{
218 throw new RuntimeException("Can we get here");
219 }
220 id("value");
221 });
222 return self();
223 }
224
225
226 @Override
227 protected String mapMathIntrinsic(String hatMathIntrinsicName) {
228 return MATH_FUNCTIONS.getOrDefault(hatMathIntrinsicName, hatMathIntrinsicName);
229 }
230
231 @Override
232 protected OpenCLJExtractedHATKernelBuilder varOpForNarrowType(CoreOp.VarOp varOp) {
233 // obtain the category:
234 Value first = varOp.operands().getFirst();
235 Class<?> narrowCategory;
236 if (first.declaringElement() instanceof JavaOp.InvokeOp invokeOp) {
237 // Find the category - This is the generic case, when ALL custom ops are removed
238 Stream<OpHelper.Invoke> stream = OpHelper.Invoke.stream(kernelCallGraph.lookup(), invokeOp);
239 Optional<OpHelper.Invoke> invoke = stream.findFirst();
240 narrowCategory = HATPhaseUtils.reduceFloatType(invoke);
241 if (narrowCategory == null && isMathLib(invoke)) {
242 narrowCategory = HATPhaseUtils.reduceFloatTypeFromReturnType(invoke);
243 }
244 } else {
245 throw new IllegalStateException("Expected an invoke, but found: " + first.declaringElement().getClass());
246 }
247 if (narrowCategory == null) {
248 throw new IllegalStateException("Narrow type can't be null: ");
249 }
250 f16OrBF16(narrowCategory).sp().assign(
251 _ -> id(varOp.varName()),
252 _ -> recurse(OpHelper.asResultOrThrow(varOp.operands().getFirst()).op()));
253 return self();
254 }
255
256 @Override
257 protected OpenCLJExtractedHATKernelBuilder varOpForVectors(CoreOp.VarOp varOp) {
258 // build VectorType
259 VarType resultType = varOp.resultType();
260 if (!(resultType.valueType() instanceof PrimitiveType)) {
261 IfaceValue.Vector.Shape vectorShape = null;
262 if (resultType.valueType() instanceof ClassType classType) {
263 vectorShape = getVectorShape(kernelCallGraph.lookup(), classType);
264 } else if (resultType.valueType() instanceof VarType varType) {
265 vectorShape = getVectorShape(kernelCallGraph.lookup(), varType.valueType());
266 }
267 if (vectorShape == null) {
268 // guarantee we don't have a null shape. Otherwise. we can't generate the correct code
269 throw new IllegalStateException("Could not find vector shape");
270 }
271 // Emit
272 type(vectorShape.codeType().toString() + vectorShape.lanes());
273 sp().varName(varOp).sp().equals().sp();
274 recurseResultOrThrow(varOp.operands().getFirst());
275 }
276 return self();
277 }
278
279 @Override
280 protected OpenCLJExtractedHATKernelBuilder varOpInit(CoreOp.VarOp varOp) {
281 suffix_t((ClassType) varOp.varValueType()).sp()
282 .assign(_ -> id(varOp.varName()),
283 _ -> recurse(OpHelper.asResultOrThrow(varOp.operands().getFirst()).op()));
284 return self();
285 }
286
287 @Override
288 protected OpenCLJExtractedHATKernelBuilder varOpLocalMemory(CoreOp.VarOp varOp) {
289 HAT_LOCAL_MEM().sp();
290 return varOpPrivateMemory(varOp);
291 }
292
293 @Override
294 protected OpenCLJExtractedHATKernelBuilder varOpPrivateMemory(CoreOp.VarOp varOp) {
295 VarType resultType = varOp.resultType();
296 if (resultType.valueType() instanceof VarType varType) {
297 suffix_t((ClassType) varType.valueType());
298 } else if (resultType.valueType() instanceof ClassType classType) {
299 suffix_t(classType);
300 }
301 return sp().varName(varOp);
302 }
303 }