1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package hat.backend.jextracted;
 26 
 27 import hat.codebuilders.C99HATKernelBuilder;
 28 import hat.dialect.HATF16Op;
 29 import hat.dialect.HATVectorOp;
 30 import optkl.codebuilders.CodeBuilder;
 31 import optkl.codebuilders.ScopedCodeBuilderContext;
 32 import jdk.incubator.code.Op;
 33 import jdk.incubator.code.Value;
 34 
 35 public class OpenCLJExtractedHATKernelBuilder extends C99HATKernelBuilder<OpenCLJExtractedHATKernelBuilder> {
 36 
 37     protected OpenCLJExtractedHATKernelBuilder(ScopedCodeBuilderContext scopedCodeBuilderContext) {
 38         super(scopedCodeBuilderContext);
 39     }
 40 
 41     @Override
 42     public OpenCLJExtractedHATKernelBuilder defines() {
 43         return self()
 44                 .hashDefine("HAT_OPENCL")
 45                 .hashIfndef("NULL", _ -> hashDefine("NULL", "0"))
 46                 .pragma("OPENCL", "EXTENSION", "cl_khr_global_int32_base_atomics", ":", "enable")
 47                 .pragma("OPENCL", "EXTENSION", "cl_khr_local_int32_base_atomics", ":", "enable")
 48                 .pragma("OPENCL", "EXTENSION", "cl_khr_fp16", ":", "enable")                      // Enable Half type
 49                 .hashDefine("HAT_FUNC", _ -> keyword(""))
 50                 .hashDefine("HAT_KERNEL", _ -> keyword("__kernel"))
 51                 .hashDefine("HAT_GLOBAL_MEM", _ -> keyword("__global"))
 52                 .hashDefine("HAT_LOCAL_MEM", _ -> keyword("__local"))
 53                 .hashDefine("HAT_GIX", _ -> paren(_ -> identifier("get_global_id").paren(_ -> intConstZero())))
 54                 .hashDefine("HAT_GIY", _ -> paren(_ -> identifier("get_global_id").paren(_ -> intConstOne())))
 55                 .hashDefine("HAT_GIZ", _ -> paren(_ -> identifier("get_global_id").paren(_ -> intConstTwo())))
 56                 .hashDefine("HAT_LIX", _ -> paren(_ -> identifier("get_local_id").paren(_ -> intConstZero())))
 57                 .hashDefine("HAT_LIY", _ -> paren(_ -> identifier("get_local_id").paren(_ -> intConstOne())))
 58                 .hashDefine("HAT_LIZ", _ -> paren(_ -> identifier("get_local_id").paren(_ -> intConstTwo())))
 59                 .hashDefine("HAT_GSX", _ -> paren(_ -> identifier("get_global_size").paren(_ -> intConstZero())))
 60                 .hashDefine("HAT_GSY", _ -> paren(_ -> identifier("get_global_size").paren(_ -> intConstOne())))
 61                 .hashDefine("HAT_GSZ", _ -> paren(_ -> identifier("get_global_size").paren(_ -> intConstTwo())))
 62                 .hashDefine("HAT_LSX", _ -> paren(_ -> identifier("get_local_size").paren(_ -> intConstZero())))
 63                 .hashDefine("HAT_LSY", _ -> paren(_ -> identifier("get_local_size").paren(_ -> intConstOne())))
 64                 .hashDefine("HAT_LSZ", _ -> paren(_ -> identifier("get_local_size").paren(_ -> intConstTwo())))
 65                 .hashDefine("HAT_BIX", _ -> paren(_ -> identifier("get_group_id").paren(_ -> intConstZero())))
 66                 .hashDefine("HAT_BIY", _ -> paren(_ -> identifier("get_group_id").paren(_ -> intConstOne())))
 67                 .hashDefine("HAT_BIZ", _ -> paren(_ -> identifier("get_group_id").paren(_ -> intConstTwo())))
 68                 .hashDefine("HAT_BARRIER", _ -> identifier("barrier").oparen().identifier("CLK_LOCAL_MEM_FENCE").cparen());
 69         //         )
 70         // );
 71     }
 72 
 73     @Override
 74     public OpenCLJExtractedHATKernelBuilder atomicInc( Op.Result instanceResult, String name) {
 75         return identifier("atomic_inc").paren(_ -> ampersand().recurse( instanceResult.op()).rarrow().identifier(name));
 76     }
 77 
 78     @Override
 79     public OpenCLJExtractedHATKernelBuilder hatVectorStoreOp( HATVectorOp.HATVectorStoreView hatVectorStoreView) {
 80         Value dest = hatVectorStoreView.operands().get(0);
 81         Value index = hatVectorStoreView.operands().get(2);
 82 
 83         identifier("vstore" + hatVectorStoreView.vectorShape().lanes())
 84                 .oparen()
 85                 .varName(hatVectorStoreView)
 86                 .comma()
 87                 .space()
 88                 .intConstZero()
 89                 .comma()
 90                 .space()
 91                 .ampersand();
 92 
 93         if (dest instanceof Op.Result r) {
 94             recurse( r.op());
 95         }
 96         either(hatVectorStoreView instanceof HATVectorOp.Shared, CodeBuilder::dot, CodeBuilder::rarrow);
 97         identifier("array").osbrace();
 98 
 99         if (index instanceof Op.Result r) {
100             recurse( r.op());
101         }
102 
103         csbrace().cparen();
104         return self();
105     }
106 
107     @Override
108     public OpenCLJExtractedHATKernelBuilder hatBinaryVectorOp( HATVectorOp.HATVectorBinaryOp hatVectorBinaryOp) {
109 
110         oparen();
111         Value op1 = hatVectorBinaryOp.operands().get(0);
112         Value op2 = hatVectorBinaryOp.operands().get(1);
113 
114         if (op1 instanceof Op.Result r) {
115             recurse( r.op());
116         }
117         space().identifier(hatVectorBinaryOp.operationType().symbol()).space();
118 
119         if (op2 instanceof Op.Result r) {
120             recurse( r.op());
121         }
122         cparen();
123         return self();
124     }
125 
126     @Override
127     public OpenCLJExtractedHATKernelBuilder hatVectorLoadOp( HATVectorOp.HATVectorLoadOp hatVectorLoadOp) {
128         Value source = hatVectorLoadOp.operands().get(0);
129         Value index = hatVectorLoadOp.operands().get(1);
130 
131         identifier("vload" + hatVectorLoadOp.vectorShape().lanes())
132                 .oparen()
133                 .intConstZero()
134                 .comma()
135                 .space()
136                 .ampersand();
137 
138         if (source instanceof Op.Result r) {
139             recurse(r.op());
140         }
141 
142         either(hatVectorLoadOp instanceof HATVectorOp.Shared, CodeBuilder::dot, CodeBuilder::rarrow);
143         identifier("array").osbrace();
144         if (index instanceof Op.Result r) {
145             recurse( r.op());
146         }
147         csbrace().cparen();
148         return self();
149     }
150 
151     @Override
152     public OpenCLJExtractedHATKernelBuilder hatSelectLoadOp( HATVectorOp.HATVectorSelectLoadOp hatVSelectLoadOp) {
153         identifier(hatVSelectLoadOp.varName())
154                 .dot()
155                 .identifier(hatVSelectLoadOp.mapLane());
156         return self();
157     }
158 
159     @Override
160     public OpenCLJExtractedHATKernelBuilder hatSelectStoreOp( HATVectorOp.HATVectorSelectStoreOp hatVSelectStoreOp) {
161         identifier(hatVSelectStoreOp.varName())
162                 .dot()
163                 .identifier(hatVSelectStoreOp.mapLane())
164                 .space().equals().space();
165         if (hatVSelectStoreOp.resolvedName() != null) {
166             // We have detected a direct resolved result (resolved name)
167             varName(hatVSelectStoreOp.resolvedName());
168         } else if (hatVSelectStoreOp.operands().get(1) instanceof Op.Result r) {
169                 recurse( r.op());
170 
171         }
172         return self();
173     }
174 
175     @Override
176     public OpenCLJExtractedHATKernelBuilder hatF16ConvOp( HATF16Op.HATF16ConvOp hatF16ConvOp) {
177         paren(_->typeName("half"));
178         if (hatF16ConvOp.operands().getFirst() instanceof Op.Result r) {
179             recurse( r.op());
180         }
181         return self();
182     }
183 
184     @Override
185     public OpenCLJExtractedHATKernelBuilder hatVectorVarOp( HATVectorOp.HATVectorVarOp hatVectorVarOp) {
186         typeName(hatVectorVarOp.buildType())
187                 .space()
188                 .varName(hatVectorVarOp)
189                 .space().equals().space();
190 
191         Value operand = hatVectorVarOp.operands().getFirst();
192         if (operand instanceof Op.Result r) {
193             recurse( r.op());
194         }
195         return self();
196     }
197 
198     @Override
199     public OpenCLJExtractedHATKernelBuilder genVectorIdentifier( HATVectorOp.HATVectorOfOp hatVectorOfOp) {
200         return paren(_->identifier(hatVectorOfOp.buildType()));
201     }
202 
203     @Override
204     public OpenCLJExtractedHATKernelBuilder hatF16ToFloatConvOp( HATF16Op.HATF16ToFloatConvOp hatF16ToFloatConvOp) {
205         return paren(_-> f16Type()).identifier(hatF16ToFloatConvOp.varName());
206     }
207 
208 }