1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package hat.backend.codebuilders;
 26 
 27 
 28 import hat.buffer.Buffer;
 29 import hat.callgraph.KernelCallGraph;
 30 import hat.callgraph.KernelEntrypoint;
 31 import hat.optools.FuncOpWrapper;
 32 import hat.optools.InvokeOpWrapper;
 33 import hat.optools.StructuralOpWrapper;
 34 import hat.util.StreamCounter;
 35 
 36 import java.lang.foreign.GroupLayout;
 37 import jdk.incubator.code.type.ClassType;
 38 import jdk.incubator.code.type.JavaType;
 39 
 40 import java.lang.invoke.MethodHandles;
 41 import java.util.function.Consumer;
 42 
 43 public abstract class C99HATKernelBuilder<T extends C99HATKernelBuilder<T>> extends HATCodeBuilderWithContext<T> {
 44     public C99HATKernelBuilder() {
 45 
 46     }
 47 
 48     public T types() {
 49         return this
 50                 .charTypeDefs("s8_t", "byte", "boolean")
 51                 .unsignedCharTypeDefs("u8_t")
 52                 .shortTypeDefs("s16_t")
 53                 .unsignedShortTypeDefs("u16_t")
 54                 .unsignedIntTypeDefs("u32_t")
 55                 .intTypeDefs("s32_t")
 56                 .floatTypeDefs("f32_t")
 57                 .longTypeDefs("s64_t")
 58                 .unsignedLongTypeDefs("u64_t")
 59                 .typedefStructOrUnion(true, "KernelContext", _ -> {
 60                     intDeclaration("x").semicolonNl();
 61                     intDeclaration("maxX").semicolon();
 62                 });
 63 
 64     }
 65 
 66     T typedefStructOrUnion(boolean isStruct, String name, Consumer<T> consumer) {
 67         return
 68                 typedefKeyword().space().structOrUnion(isStruct).space()
 69                         .either(isStruct, _ -> suffix_s(name), _ -> suffix_u(name)).braceNlIndented(consumer)
 70                         .suffix_t(name).semicolon().nl();
 71 
 72     }
 73 
 74 
 75     public final T scope() {
 76 
 77         identifier("KernelContext_t").space().identifier("mine").semicolon().nl();
 78         identifier("KernelContext_t").asterisk().space().identifier("kc").equals().ampersand().identifier("mine").semicolon().nl();
 79         identifier("kc").rarrow().identifier("x").equals().globalId().semicolon().nl();
 80         identifier("kc").rarrow().identifier("maxX").equals().identifier("global_kc").rarrow().identifier("maxX").semicolon().nl();
 81         return self();
 82 
 83     }
 84 
 85     public abstract T globalPtrPrefix();
 86 
 87     @Override
 88     public T type(CodeBuilderContext buildContext,JavaType javaType) {
 89         if (InvokeOpWrapper.isIfaceUsingLookup(buildContext.lookup(),javaType) && javaType instanceof ClassType classType) {
 90             globalPtrPrefix().space();
 91             String name = classType.toClassName();
 92             int dotIdx = name.lastIndexOf('.');
 93             int dollarIdx = name.lastIndexOf('$');
 94             int idx = Math.max(dotIdx, dollarIdx);
 95             if (idx > 0) {
 96                 name = name.substring(idx + 1);
 97             }
 98             suffix_t(name).asterisk();
 99         } else {
100             typeName(javaType.toString());
101         }
102 
103         return self();
104     }
105 
106     public T kernelMethod(KernelCallGraph.KernelReachableResolvedMethodCall kernelReachableResolvedMethodCall) {
107         CodeBuilderContext buildContext = new CodeBuilderContext(kernelReachableResolvedMethodCall.funcOpWrapper());
108         buildContext.scope(buildContext.funcOpWrapper, () -> {
109             nl();
110             functionDeclaration(buildContext,buildContext.funcOpWrapper.getReturnType(), buildContext.funcOpWrapper.functionName());
111 
112             var list = buildContext.funcOpWrapper.paramTable.list();
113             parenNlIndented(_ ->
114                     commaSeparated(list, (info) -> type(buildContext,info.javaType).space().varName(info.varOp))
115             );
116 
117             braceNlIndented(_ -> {
118                 //scope();
119                 StreamCounter.of(buildContext.funcOpWrapper.wrappedRootOpStream(), (c, root) ->
120                         nlIf(c.isNotFirst()).recurse(buildContext, root).semicolonIf(!(root instanceof StructuralOpWrapper<?>))
121                 );
122             });
123         });
124         return self();
125     }
126 
127     public T kernelEntrypoint(KernelEntrypoint kernelEntrypoint, Object[] args) {
128 
129         nl();
130         CodeBuilderContext buildContext = new CodeBuilderContext(kernelEntrypoint.funcOpWrapper());
131         //  System.out.print(kernelReachableResolvedMethodCall.funcOpWrapper().toText());
132         buildContext.scope(buildContext.funcOpWrapper, () -> {
133 
134             kernelDeclaration(buildContext.funcOpWrapper.functionName());
135             // We skip the first arg which was KernelContext.
136             var list = buildContext.funcOpWrapper.paramTable.list();
137             for (int arg = 1; arg < args.length; arg++) {
138                 if (args[arg] instanceof Buffer buffer) {
139                     FuncOpWrapper.ParamTable.Info info = list.get(arg);
140                     info.setLayout((GroupLayout) Buffer.getLayout(buffer));
141                     info.setClass(args[arg].getClass());
142                 }
143             }
144             parenNlIndented(_ -> {
145                         globalPtrPrefix().space().suffix_t("KernelContext").space().asterisk().identifier("global_kc");
146                         list.stream().skip(1).forEach(info ->
147                                 comma().space().type(buildContext,info.javaType).space().varName(info.varOp)
148                         );
149                     }
150             );
151 
152             braceNlIndented(_ -> {
153                 scope();
154                 StreamCounter.of(buildContext.funcOpWrapper.wrappedRootOpStream(), (c, root) ->
155                         nlIf(c.isNotFirst()).recurse(buildContext, root).semicolonIf(!(root instanceof StructuralOpWrapper<?>))
156                 );
157             });
158         });
159         return self();
160     }
161 
162 
163     public abstract T defines();
164 
165     public abstract T pragmas();
166 
167     public abstract T kernelDeclaration(String name);
168 
169     public abstract T functionDeclaration(CodeBuilderContext codeBuilderContext,JavaType javaType, String name);
170 
171     public abstract T globalId();
172 
173     public abstract T globalSize();
174 
175 }