1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package hat.backend.jextracted;
 26 
 27 import hat.callgraph.KernelCallGraph;
 28 import hat.codebuilders.C99HATKernelBuilder;
 29 import hat.dialect.BinaryOpEnum;
 30 import hat.phases.HATPhaseUtils;
 31 import hat.types.BF16;
 32 import hat.types.F16;
 33 import jdk.incubator.code.Value;
 34 import jdk.incubator.code.dialect.core.CoreOp;
 35 import jdk.incubator.code.dialect.core.VarType;
 36 import jdk.incubator.code.dialect.java.ClassType;
 37 import jdk.incubator.code.dialect.java.JavaOp;
 38 import jdk.incubator.code.dialect.java.PrimitiveType;
 39 import optkl.IfaceValue;
 40 import optkl.OpHelper;
 41 import optkl.codebuilders.CodeBuilder;
 42 import optkl.codebuilders.ScopedCodeBuilderContext;
 43 import jdk.incubator.code.Op;
 44 
 45 import java.util.HashMap;
 46 import java.util.Map;
 47 import java.util.Optional;
 48 import java.util.stream.Stream;
 49 
 50 import static hat.phases.HATPhaseUtils.isMathLib;
 51 import static optkl.IfaceValue.Vector.getVectorShape;
 52 
 53 public class OpenCLJExtractedHATKernelBuilder extends C99HATKernelBuilder<OpenCLJExtractedHATKernelBuilder> {
 54 
 55     // Mapping between API function names and OpenCL intrinsics for the math operations
 56     private static final Map<String, String> MATH_FUNCTIONS = new HashMap<>();
 57     static {
 58         MATH_FUNCTIONS.put("maxf", "max");
 59         MATH_FUNCTIONS.put("maxd", "max");
 60         MATH_FUNCTIONS.put("maxf16", "MAX_HAT");
 61         MATH_FUNCTIONS.put("minf", "min");
 62         MATH_FUNCTIONS.put("mind", "min");
 63         MATH_FUNCTIONS.put("minf16", "MIN_HAT");
 64 
 65         MATH_FUNCTIONS.put("expf", "exp");
 66         MATH_FUNCTIONS.put("expd", "exp");
 67         MATH_FUNCTIONS.put("expf16", "half_exp");
 68 
 69         MATH_FUNCTIONS.put("cosf", "cos");
 70         MATH_FUNCTIONS.put("cosd", "cos");
 71         MATH_FUNCTIONS.put("sinf", "sin");
 72         MATH_FUNCTIONS.put("sind", "sin");
 73         MATH_FUNCTIONS.put("tanf", "tan");
 74         MATH_FUNCTIONS.put("tand", "tan");
 75 
 76         MATH_FUNCTIONS.put("native_cosf", "native_cos");
 77         MATH_FUNCTIONS.put("native_sinf", "native_sin");
 78         MATH_FUNCTIONS.put("native_tanf", "native_tan");
 79         MATH_FUNCTIONS.put("native_expf", "native_exp");
 80 
 81         MATH_FUNCTIONS.put("sqrtf", "sqrt");
 82         MATH_FUNCTIONS.put("sqrtd", "sqrt");
 83     }
 84 
 85     protected OpenCLJExtractedHATKernelBuilder(KernelCallGraph kernelCallGraph, ScopedCodeBuilderContext scopedCodeBuilderContext) {
 86         super(kernelCallGraph, scopedCodeBuilderContext);
 87     }
 88 
 89     @Override
 90     public OpenCLJExtractedHATKernelBuilder defines() {
 91         return self()
 92                 .hashDefine("HAT_OPENCL")
 93                 .hashIfndef("NULL", _ -> hashDefine("NULL", "0"))
 94                 .when(kernelCallGraph.isUsesAtomics(),_->pragma("OPENCL", "EXTENSION", "cl_khr_global_int32_base_atomics", ":", "enable"))
 95                 .when(kernelCallGraph.isUsesAtomics(),_->pragma("OPENCL", "EXTENSION", "cl_khr_local_int32_base_atomics", ":", "enable"))
 96                 /*.when(kernelCallGraph.usesFp16,_->*/.pragma("OPENCL", "EXTENSION", "cl_khr_fp16", ":", "enable")//)                      // Enable Half type
 97                 .hashDefine("HAT_FUNC", _ -> keyword(""))
 98                 .hashDefine("HAT_KERNEL", _ -> keyword("__kernel"))
 99                 .hashDefine("HAT_GLOBAL_MEM", _ -> keyword("__global"))
100                 .hashDefine("HAT_LOCAL_MEM", _ -> keyword("__local"))
101                 .when(kernelCallGraph.accessedKernelContextFields.contains("gix"), _->hashDefine("HAT_GIX", _ -> paren(_ -> id("get_global_id").paren(_ -> intConstZero()))))
102                 .when(kernelCallGraph.accessedKernelContextFields.contains("giy"), _->hashDefine("HAT_GIY", _ -> paren(_ -> id("get_global_id").paren(_ -> intConstOne()))))
103                 .when(kernelCallGraph.accessedKernelContextFields.contains("giz"), _->hashDefine("HAT_GIZ", _ -> paren(_ -> id("get_global_id").paren(_ -> intConstTwo()))))
104                 .when(kernelCallGraph.accessedKernelContextFields.contains("lix"), _->hashDefine("HAT_LIX", _ -> paren(_ -> id("get_local_id").paren(_ -> intConstZero()))))
105                 .when(kernelCallGraph.accessedKernelContextFields.contains("liy"), _->hashDefine("HAT_LIY", _ -> paren(_ -> id("get_local_id").paren(_ -> intConstOne()))))
106                 .when(kernelCallGraph.accessedKernelContextFields.contains("liz"), _->hashDefine("HAT_LIZ", _ -> paren(_ -> id("get_local_id").paren(_ -> intConstTwo()))))
107                 .when(kernelCallGraph.accessedKernelContextFields.contains("gsx"), _->hashDefine("HAT_GSX", _ -> paren(_ -> id("get_global_size").paren(_ -> intConstZero()))))
108                 .when(kernelCallGraph.accessedKernelContextFields.contains("gsy"), _->hashDefine("HAT_GSY", _ -> paren(_ -> id("get_global_size").paren(_ -> intConstOne()))))
109                 .when(kernelCallGraph.accessedKernelContextFields.contains("gsz"), _->hashDefine("HAT_GSZ", _ -> paren(_ -> id("get_global_size").paren(_ -> intConstTwo()))))
110                 .when(kernelCallGraph.accessedKernelContextFields.contains("lsx"), _->hashDefine("HAT_LSX", _ -> paren(_ -> id("get_local_size").paren(_ -> intConstZero()))))
111                 .when(kernelCallGraph.accessedKernelContextFields.contains("lsy"), _->hashDefine("HAT_LSY", _ -> paren(_ -> id("get_local_size").paren(_ -> intConstOne()))))
112                 .when(kernelCallGraph.accessedKernelContextFields.contains("lsz"), _->hashDefine("HAT_LSZ", _ -> paren(_ -> id("get_local_size").paren(_ -> intConstTwo()))))
113                 .when(kernelCallGraph.accessedKernelContextFields.contains("bix"), _->hashDefine("HAT_BIX", _ -> paren(_ -> id("get_group_id").paren(_ -> intConstZero()))))
114                 .when(kernelCallGraph.accessedKernelContextFields.contains("biy"), _->hashDefine("HAT_BIY", _ -> paren(_ -> id("get_group_id").paren(_ -> intConstOne()))))
115                 .when(kernelCallGraph.accessedKernelContextFields.contains("biz"), _->hashDefine("HAT_BIZ", _ -> paren(_ -> id("get_group_id").paren(_ -> intConstTwo()))))
116                 .when(kernelCallGraph.accessedKernelContextFields.contains("bsx"), _->hashDefine("HAT_BSX", _ -> paren(_ -> id("get_num_groups").paren(_ -> intConstZero()))))
117                 .when(kernelCallGraph.accessedKernelContextFields.contains("bsy"), _->hashDefine("HAT_BSY", _ -> paren(_ -> id("get_num_groups").paren(_ -> intConstOne()))))
118                 .when(kernelCallGraph.accessedKernelContextFields.contains("bsz"), _->hashDefine("HAT_BSZ", _ -> paren(_ -> id("get_num_groups").paren(_ -> intConstTwo()))))
119                 .when(!kernelCallGraph.accessedFP16Classes.isEmpty(), _->maxMacro("MAX_HAT"))
120                 .when(!kernelCallGraph.accessedFP16Classes.isEmpty(), _->minMacro("MIN_HAT"))
121                 .when(kernelCallGraph.isUsesBarrier(), _ ->hashDefine("HAT_BARRIER", _ -> id("barrier").oparen().id("CLK_LOCAL_MEM_FENCE").cparen()))
122                 /*.when(callgraphState.usesFp16,_->*/.hashDefine("BFLOAT16", _ -> keyword("ushort"))//)
123                 /*.when(callgraphState.usesFp16,_->*/.typedefSingleValueStruct("F16",  "half")//)
124                 /*.when(callgraphState.usesFp16,_->*/.typedefSingleValueStruct("BF16",  "BFLOAT16")//)
125                 /*.when(callgraphState.usesFp16,_->*/.unionBfloat16()//)
126                 /*.when(callgraphState.usesFp16,_->*/.build_builtin_bfloat16ToFloat("bf16")//)
127                 /*.when(callgraphState.usesFp16,_->*/.build_builtin_float2bfloat16("f")/*)*/;
128     }
129 
130     @Override
131     public OpenCLJExtractedHATKernelBuilder atomicInc( Op.Result instanceResult, String name) {
132         return id("atomic_inc").paren(_ -> ampersand().recurse( instanceResult.op()).rarrow().id(name));
133     }
134 
135     protected OpenCLJExtractedHATKernelBuilder vstore(int dims) {
136         return id("vstore" + dims);
137     }
138 
139     protected OpenCLJExtractedHATKernelBuilder vload(int dims) {
140         return id("vload" + dims);
141     }
142 
143 
144     @Override
145     public OpenCLJExtractedHATKernelBuilder generateVectorLoad(Value source, Value index, IfaceValue.Vector.Shape vectorShape, boolean deviceAllocated) {
146         vload(vectorShape.lanes()).paren(_ -> {
147             intConstZero().comma().sp().ampersand();
148             recurseResultOrThrow(source);
149             either(deviceAllocated, CodeBuilder::dot, CodeBuilder::rarrow);
150             id("array").sbrace(_ -> recurseResultOrThrow(index));
151         });
152         return self();
153     }
154 
155     @Override
156     public OpenCLJExtractedHATKernelBuilder hatVectorStoreOp(Value dest, Value index, IfaceValue.Vector.Shape vectorShape, boolean deviceAllocated, String name, Op op) {
157         return vstore(vectorShape.lanes()).paren(_-> {
158             // if the value to be stored is an operation, recurse on the operation
159             varName(name);
160             csp().intConstZero().csp().ampersand().recurseResultOrThrow(dest);
161             either(deviceAllocated, CodeBuilder::dot, CodeBuilder::rarrow);
162             id("array").sbrace(_ ->recurseResultOrThrow(index));
163         });
164     }
165 
166     @Override
167     public OpenCLJExtractedHATKernelBuilder hatBinaryVectorOp(OpHelper.Invoke binOp) {
168         return paren(_-> {
169             recurseResultOrThrow(binOp.op().operands().get(0));
170             sp().id(BinaryOpEnum.of(binOp.op()).symbol()).sp();
171             recurseResultOrThrow(binOp.op().operands().get(1));
172         });
173     }
174 
175     @Override
176     public OpenCLJExtractedHATKernelBuilder hatSelectStoreOp(OpHelper.Invoke invoke, HATPhaseUtils.InvokeVar invokeVar) {
177         if (invoke.op().operands().getFirst().declaringElement() instanceof JavaOp.ArrayAccessOp.ArrayLoadOp vLoadOp) {
178             recurse(vLoadOp);
179         } else {
180             id(invokeVar.name());
181         }
182         dot().id(HATPhaseUtils.mapLane(invokeVar.laneIdx())).assign();
183         String resolvedName = invokeVar.resolveName();
184         return either (resolvedName != null,
185                 _-> varName(resolvedName),
186                 _-> recurseResultOrThrow(invoke.op().operands().get(1))
187         );
188     }
189 
190     @Override
191     public OpenCLJExtractedHATKernelBuilder hatF16ConvOp(JavaOp.InvokeOp invokeOp, Class<?> reduceFloatType) {
192         return paren(_-> f16OrBF16(reduceFloatType)).brace(_->
193                 either (BF16.class.isAssignableFrom(reduceFloatType),
194                         _-> builtin_float2bfloat16().paren(_-> recurseResultOrThrow(invokeOp.operands().getFirst())),
195                         _-> recurseResultOrThrow(invokeOp.operands().getFirst())
196                 ));
197     }
198 
199     @Override
200     public OpenCLJExtractedHATKernelBuilder genVectorIdentifier(IfaceValue.Vector.Shape vectorShape) {
201         return paren(_-> id(vectorShape.codeType().toString() + vectorShape.lanes()));
202     }
203 
204     @Override
205     public OpenCLJExtractedHATKernelBuilder hatF16ToFloatConvOp(OpHelper.Invoke invoke, Class<?> reducedFloatType, boolean wasFloat, boolean isF16Local) {
206         if (F16.class.isAssignableFrom(reducedFloatType)) {// half -> float
207             paren(_->f32Type());
208         } else if (BF16.class.isAssignableFrom(reducedFloatType)) {// bfloat16 -> float
209             builtin_bfloat16ToFloat();
210         }
211         parenWhen(BF16.class.isAssignableFrom(reducedFloatType),_-> {
212             recurseResultOrThrow(invoke.op().operands().getFirst());
213             if (!isF16Local) {
214                 rarrow();
215             } else if (!wasFloat) {
216                 dot();
217             } else{
218                 throw new RuntimeException("Can we get here");
219             }
220             id("value");
221         });
222         return self();
223     }
224 
225 
226     @Override
227     protected String mapMathIntrinsic(String hatMathIntrinsicName) {
228         return MATH_FUNCTIONS.getOrDefault(hatMathIntrinsicName, hatMathIntrinsicName);
229     }
230 
231     @Override
232     protected OpenCLJExtractedHATKernelBuilder varOpForNarrowType(CoreOp.VarOp varOp) {
233         // obtain the category:
234         Value first = varOp.operands().getFirst();
235         Class<?> narrowCategory;
236         if (first.declaringElement() instanceof JavaOp.InvokeOp invokeOp) {
237             // Find the category - This is the generic case, when ALL custom ops are removed
238             Stream<OpHelper.Invoke> stream = OpHelper.Invoke.stream(kernelCallGraph.lookup(), invokeOp);
239             Optional<OpHelper.Invoke> invoke = stream.findFirst();
240             narrowCategory = HATPhaseUtils.reduceFloatType(invoke);
241             if (narrowCategory == null && isMathLib(invoke)) {
242                 narrowCategory = HATPhaseUtils.reduceFloatTypeFromReturnType(invoke);
243             }
244         } else {
245             throw new IllegalStateException("Expected an invoke, but found: " + first.declaringElement().getClass());
246         }
247         if (narrowCategory == null) {
248             throw new IllegalStateException("Narrow type can't be null: ");
249         }
250         f16OrBF16(narrowCategory).sp().assign(
251                 _ -> id(varOp.varName()),
252                 _ -> recurse(OpHelper.asResultOrThrow(varOp.operands().getFirst()).op()));
253         return self();
254     }
255 
256     @Override
257     protected OpenCLJExtractedHATKernelBuilder varOpForVectors(CoreOp.VarOp varOp) {
258         // build VectorType
259         VarType resultType = varOp.resultType();
260         if (!(resultType.valueType() instanceof PrimitiveType)) {
261             IfaceValue.Vector.Shape vectorShape = null;
262             if (resultType.valueType() instanceof ClassType classType) {
263                 vectorShape = getVectorShape(kernelCallGraph.lookup(), classType);
264             } else if (resultType.valueType() instanceof VarType varType) {
265                 vectorShape = getVectorShape(kernelCallGraph.lookup(), varType.valueType());
266             }
267             if (vectorShape == null) {
268                 // guarantee we don't have a null shape. Otherwise. we can't generate the correct code
269                 throw new IllegalStateException("Could not find vector shape");
270             }
271             // Emit
272             type(vectorShape.codeType().toString() + vectorShape.lanes());
273             sp().varName(varOp).sp().equals().sp();
274             recurseResultOrThrow(varOp.operands().getFirst());
275         }
276         return self();
277     }
278 
279     @Override
280     protected OpenCLJExtractedHATKernelBuilder varOpInit(CoreOp.VarOp varOp) {
281         suffix_t((ClassType) varOp.varValueType()).sp()
282                 .assign(_ -> id(varOp.varName()),
283                         _ -> recurse(OpHelper.asResultOrThrow(varOp.operands().getFirst()).op()));
284         return self();
285     }
286 
287     @Override
288     protected OpenCLJExtractedHATKernelBuilder varOpLocalMemory(CoreOp.VarOp varOp) {
289         HAT_LOCAL_MEM().sp();
290         return varOpPrivateMemory(varOp);
291     }
292 
293     @Override
294     protected OpenCLJExtractedHATKernelBuilder varOpPrivateMemory(CoreOp.VarOp varOp) {
295         VarType resultType = varOp.resultType();
296         if (resultType.valueType() instanceof VarType varType) {
297             suffix_t((ClassType) varType.valueType());
298         } else if (resultType.valueType() instanceof ClassType classType) {
299             suffix_t(classType);
300         }
301         return sp().varName(varOp);
302     }
303 }