1 /*
2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25 package hat.backend.jextracted;
26
27 import hat.codebuilders.C99HATKernelBuilder;
28 import hat.dialect.HATF16Op;
29 import hat.dialect.HATVectorOp;
30 import optkl.codebuilders.CodeBuilder;
31 import optkl.codebuilders.ScopedCodeBuilderContext;
32 import jdk.incubator.code.Op;
33 import jdk.incubator.code.Value;
34
35 public class OpenCLJExtractedHATKernelBuilder extends C99HATKernelBuilder<OpenCLJExtractedHATKernelBuilder> {
36
37 protected OpenCLJExtractedHATKernelBuilder(ScopedCodeBuilderContext scopedCodeBuilderContext) {
38 super(scopedCodeBuilderContext);
39 }
40
41 @Override
42 public OpenCLJExtractedHATKernelBuilder defines() {
43 return self()
44 .hashDefine("HAT_OPENCL")
45 .hashIfndef("NULL", _ -> hashDefine("NULL", "0"))
46 .pragma("OPENCL", "EXTENSION", "cl_khr_global_int32_base_atomics", ":", "enable")
47 .pragma("OPENCL", "EXTENSION", "cl_khr_local_int32_base_atomics", ":", "enable")
48 .pragma("OPENCL", "EXTENSION", "cl_khr_fp16", ":", "enable") // Enable Half type
49 .hashDefine("HAT_FUNC", _ -> keyword(""))
50 .hashDefine("HAT_KERNEL", _ -> keyword("__kernel"))
51 .hashDefine("HAT_GLOBAL_MEM", _ -> keyword("__global"))
52 .hashDefine("HAT_LOCAL_MEM", _ -> keyword("__local"))
53 .hashDefine("HAT_GIX", _ -> paren(_ -> identifier("get_global_id").paren(_ -> intConstZero())))
54 .hashDefine("HAT_GIY", _ -> paren(_ -> identifier("get_global_id").paren(_ -> intConstOne())))
55 .hashDefine("HAT_GIZ", _ -> paren(_ -> identifier("get_global_id").paren(_ -> intConstTwo())))
56 .hashDefine("HAT_LIX", _ -> paren(_ -> identifier("get_local_id").paren(_ -> intConstZero())))
57 .hashDefine("HAT_LIY", _ -> paren(_ -> identifier("get_local_id").paren(_ -> intConstOne())))
58 .hashDefine("HAT_LIZ", _ -> paren(_ -> identifier("get_local_id").paren(_ -> intConstTwo())))
59 .hashDefine("HAT_GSX", _ -> paren(_ -> identifier("get_global_size").paren(_ -> intConstZero())))
60 .hashDefine("HAT_GSY", _ -> paren(_ -> identifier("get_global_size").paren(_ -> intConstOne())))
61 .hashDefine("HAT_GSZ", _ -> paren(_ -> identifier("get_global_size").paren(_ -> intConstTwo())))
62 .hashDefine("HAT_LSX", _ -> paren(_ -> identifier("get_local_size").paren(_ -> intConstZero())))
63 .hashDefine("HAT_LSY", _ -> paren(_ -> identifier("get_local_size").paren(_ -> intConstOne())))
64 .hashDefine("HAT_LSZ", _ -> paren(_ -> identifier("get_local_size").paren(_ -> intConstTwo())))
65 .hashDefine("HAT_BIX", _ -> paren(_ -> identifier("get_group_id").paren(_ -> intConstZero())))
66 .hashDefine("HAT_BIY", _ -> paren(_ -> identifier("get_group_id").paren(_ -> intConstOne())))
67 .hashDefine("HAT_BIZ", _ -> paren(_ -> identifier("get_group_id").paren(_ -> intConstTwo())))
68 .hashDefine("HAT_BARRIER", _ -> identifier("barrier").oparen().identifier("CLK_LOCAL_MEM_FENCE").cparen());
69 // )
70 // );
71 }
72
73 @Override
74 public OpenCLJExtractedHATKernelBuilder atomicInc( Op.Result instanceResult, String name) {
75 return identifier("atomic_inc").paren(_ -> ampersand().recurse( instanceResult.op()).rarrow().identifier(name));
76 }
77
78 @Override
79 public OpenCLJExtractedHATKernelBuilder hatVectorStoreOp( HATVectorOp.HATVectorStoreView hatVectorStoreView) {
80 Value dest = hatVectorStoreView.operands().get(0);
81 Value index = hatVectorStoreView.operands().get(2);
82
83 identifier("vstore" + hatVectorStoreView.vectorShape().lanes())
84 .oparen()
85 .varName(hatVectorStoreView)
86 .comma()
87 .space()
88 .intConstZero()
89 .comma()
90 .space()
91 .ampersand();
92
93 if (dest instanceof Op.Result r) {
94 recurse( r.op());
95 }
96 either(hatVectorStoreView instanceof HATVectorOp.Shared, CodeBuilder::dot, CodeBuilder::rarrow);
97 identifier("array").osbrace();
98
99 if (index instanceof Op.Result r) {
100 recurse( r.op());
101 }
102
103 csbrace().cparen();
104 return self();
105 }
106
107 @Override
108 public OpenCLJExtractedHATKernelBuilder hatBinaryVectorOp( HATVectorOp.HATVectorBinaryOp hatVectorBinaryOp) {
109
110 oparen();
111 Value op1 = hatVectorBinaryOp.operands().get(0);
112 Value op2 = hatVectorBinaryOp.operands().get(1);
113
114 if (op1 instanceof Op.Result r) {
115 recurse( r.op());
116 }
117 space().identifier(hatVectorBinaryOp.operationType().symbol()).space();
118
119 if (op2 instanceof Op.Result r) {
120 recurse( r.op());
121 }
122 cparen();
123 return self();
124 }
125
126 @Override
127 public OpenCLJExtractedHATKernelBuilder hatVectorLoadOp( HATVectorOp.HATVectorLoadOp hatVectorLoadOp) {
128 Value source = hatVectorLoadOp.operands().get(0);
129 Value index = hatVectorLoadOp.operands().get(1);
130
131 identifier("vload" + hatVectorLoadOp.vectorShape().lanes())
132 .oparen()
133 .intConstZero()
134 .comma()
135 .space()
136 .ampersand();
137
138 if (source instanceof Op.Result r) {
139 recurse(r.op());
140 }
141
142 either(hatVectorLoadOp instanceof HATVectorOp.Shared, CodeBuilder::dot, CodeBuilder::rarrow);
143 identifier("array").osbrace();
144 if (index instanceof Op.Result r) {
145 recurse( r.op());
146 }
147 csbrace().cparen();
148 return self();
149 }
150
151 @Override
152 public OpenCLJExtractedHATKernelBuilder hatSelectLoadOp( HATVectorOp.HATVectorSelectLoadOp hatVSelectLoadOp) {
153 identifier(hatVSelectLoadOp.varName())
154 .dot()
155 .identifier(hatVSelectLoadOp.mapLane());
156 return self();
157 }
158
159 @Override
160 public OpenCLJExtractedHATKernelBuilder hatSelectStoreOp( HATVectorOp.HATVectorSelectStoreOp hatVSelectStoreOp) {
161 identifier(hatVSelectStoreOp.varName())
162 .dot()
163 .identifier(hatVSelectStoreOp.mapLane())
164 .space().equals().space();
165 if (hatVSelectStoreOp.resolvedName() != null) {
166 // We have detected a direct resolved result (resolved name)
167 varName(hatVSelectStoreOp.resolvedName());
168 } else if (hatVSelectStoreOp.operands().get(1) instanceof Op.Result r) {
169 recurse( r.op());
170
171 }
172 return self();
173 }
174
175 @Override
176 public OpenCLJExtractedHATKernelBuilder hatF16ConvOp( HATF16Op.HATF16ConvOp hatF16ConvOp) {
177 paren(_->typeName("half"));
178 if (hatF16ConvOp.operands().getFirst() instanceof Op.Result r) {
179 recurse( r.op());
180 }
181 return self();
182 }
183
184 @Override
185 public OpenCLJExtractedHATKernelBuilder hatVectorVarOp( HATVectorOp.HATVectorVarOp hatVectorVarOp) {
186 typeName(hatVectorVarOp.buildType())
187 .space()
188 .varName(hatVectorVarOp)
189 .space().equals().space();
190
191 Value operand = hatVectorVarOp.operands().getFirst();
192 if (operand instanceof Op.Result r) {
193 recurse( r.op());
194 }
195 return self();
196 }
197
198 @Override
199 public OpenCLJExtractedHATKernelBuilder genVectorIdentifier( HATVectorOp.HATVectorOfOp hatVectorOfOp) {
200 return paren(_->identifier(hatVectorOfOp.buildType()));
201 }
202
203 @Override
204 public OpenCLJExtractedHATKernelBuilder hatF16ToFloatConvOp( HATF16Op.HATF16ToFloatConvOp hatF16ToFloatConvOp) {
205 return paren(_-> f16Type()).identifier(hatF16ToFloatConvOp.varName());
206 }
207
208 }