1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package hat.backend.codebuilders;
 26 
 27 
 28 import hat.buffer.Buffer;
 29 import hat.callgraph.KernelCallGraph;
 30 import hat.callgraph.KernelEntrypoint;
 31 import hat.optools.FuncOpWrapper;
 32 import hat.optools.InvokeOpWrapper;
 33 import hat.optools.StructuralOpWrapper;
 34 import hat.util.StreamCounter;
 35 
 36 import java.lang.foreign.GroupLayout;
 37 import jdk.incubator.code.type.ClassType;
 38 import jdk.incubator.code.type.JavaType;
 39 import java.util.function.Consumer;
 40 
 41 public abstract class C99HATKernelBuilder<T extends C99HATKernelBuilder<T>> extends HATCodeBuilderWithContext<T> {
 42     public C99HATKernelBuilder() {
 43 
 44     }
 45 
 46     public T types() {
 47         return this
 48                 .charTypeDefs("s8_t", "byte", "boolean")
 49                 .unsignedCharTypeDefs("u8_t")
 50                 .shortTypeDefs("s16_t")
 51                 .unsignedShortTypeDefs("u16_t")
 52                 .unsignedIntTypeDefs("u32_t")
 53                 .intTypeDefs("s32_t")
 54                 .floatTypeDefs("f32_t")
 55                 .longTypeDefs("s64_t")
 56                 .unsignedLongTypeDefs("u64_t")
 57                 .typedefStructOrUnion(true, "KernelContext", _ -> {
 58                     intDeclaration("x").semicolonNl();
 59                     intDeclaration("maxX").semicolon();
 60                 });
 61 
 62     }
 63 
 64     T typedefStructOrUnion(boolean isStruct, String name, Consumer<T> consumer) {
 65         return
 66                 typedefKeyword().space().structOrUnion(isStruct).space()
 67                         .either(isStruct, _ -> suffix_s(name), _ -> suffix_u(name)).braceNlIndented(consumer)
 68                         .suffix_t(name).semicolon().nl();
 69 
 70     }
 71 
 72 
 73     public final T scope() {
 74 
 75         identifier("KernelContext_t").space().identifier("mine").semicolon().nl();
 76         identifier("KernelContext_t").asterisk().space().identifier("kc").equals().ampersand().identifier("mine").semicolon().nl();
 77         identifier("kc").rarrow().identifier("x").equals().globalId().semicolon().nl();
 78         identifier("kc").rarrow().identifier("maxX").equals().identifier("global_kc").rarrow().identifier("maxX").semicolon().nl();
 79         return self();
 80 
 81     }
 82 
 83     public abstract T globalPtrPrefix();
 84 
 85     @Override
 86     public T type(JavaType javaType) {
 87         if (InvokeOpWrapper.isIface(javaType) && javaType instanceof ClassType classType) {
 88             globalPtrPrefix().space();
 89             String name = classType.toClassName();
 90             int dotIdx = name.lastIndexOf('.');
 91             int dollarIdx = name.lastIndexOf('$');
 92             int idx = Math.max(dotIdx, dollarIdx);
 93             if (idx > 0) {
 94                 name = name.substring(idx + 1);
 95             }
 96             suffix_t(name).asterisk();
 97         } else {
 98             typeName(javaType.toString());
 99         }
100 
101         return self();
102     }
103 
104     public T kernelMethod(KernelCallGraph.KernelReachableResolvedMethodCall kernelReachableResolvedMethodCall) {
105         CodeBuilderContext buildContext = new CodeBuilderContext(kernelReachableResolvedMethodCall.funcOpWrapper());
106         buildContext.scope(buildContext.funcOpWrapper, () -> {
107             nl();
108             functionDeclaration(buildContext.funcOpWrapper.getReturnType(), buildContext.funcOpWrapper.functionName());
109 
110             var list = buildContext.funcOpWrapper.paramTable.list();
111             parenNlIndented(_ ->
112                     commaSeparated(list, (info) -> type(info.javaType).space().varName(info.varOp))
113             );
114 
115             braceNlIndented(_ -> {
116                 //scope();
117                 StreamCounter.of(buildContext.funcOpWrapper.wrappedRootOpStream(), (c, root) ->
118                         nlIf(c.isNotFirst()).recurse(buildContext, root).semicolonIf(!(root instanceof StructuralOpWrapper<?>))
119                 );
120             });
121         });
122         return self();
123     }
124 
125     public T kernelEntrypoint(KernelEntrypoint kernelEntrypoint, Object[] args) {
126 
127         nl();
128         CodeBuilderContext buildContext = new CodeBuilderContext(kernelEntrypoint.funcOpWrapper());
129         //  System.out.print(kernelReachableResolvedMethodCall.funcOpWrapper().toText());
130         buildContext.scope(buildContext.funcOpWrapper, () -> {
131 
132             kernelDeclaration(buildContext.funcOpWrapper.functionName());
133             // We skip the first arg which was KernelContext.
134             var list = buildContext.funcOpWrapper.paramTable.list();
135             for (int arg = 1; arg < args.length; arg++) {
136                 if (args[arg] instanceof Buffer buffer) {
137                     FuncOpWrapper.ParamTable.Info info = list.get(arg);
138                     info.setLayout((GroupLayout) Buffer.getLayout(buffer));
139                     info.setClass(args[arg].getClass());
140                 }
141             }
142             parenNlIndented(_ -> {
143                         globalPtrPrefix().space().suffix_t("KernelContext").space().asterisk().identifier("global_kc");
144                         list.stream().skip(1).forEach(info ->
145                                 comma().space().type(info.javaType).space().varName(info.varOp)
146                         );
147                     }
148             );
149 
150             braceNlIndented(_ -> {
151                 scope();
152                 StreamCounter.of(buildContext.funcOpWrapper.wrappedRootOpStream(), (c, root) ->
153                         nlIf(c.isNotFirst()).recurse(buildContext, root).semicolonIf(!(root instanceof StructuralOpWrapper<?>))
154                 );
155             });
156         });
157         return self();
158     }
159 
160 
161     public abstract T defines();
162 
163     public abstract T pragmas();
164 
165     public abstract T kernelDeclaration(String name);
166 
167     public abstract T functionDeclaration(JavaType javaType, String name);
168 
169     public abstract T globalId();
170 
171     public abstract T globalSize();
172 
173 }