1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 
 26 package hat.backend;
 27 
 28 import hat.ComputeContext;
 29 import hat.NDRange;
 30 import hat.backend.c99codebuilders.C99HatKernelBuilder;
 31 import hat.buffer.ArgArray;
 32 import hat.buffer.Buffer;
 33 import hat.buffer.KernelContext;
 34 import hat.callgraph.KernelCallGraph;
 35 import hat.ifacemapper.BoundSchema;
 36 import hat.ifacemapper.Schema;
 37 
 38 import java.util.Arrays;
 39 import java.util.HashMap;
 40 import java.util.LinkedHashSet;
 41 import java.util.Map;
 42 import java.util.Set;
 43 
 44 public abstract class C99NativeBackend extends NativeBackend {
 45     public C99NativeBackend(String libName) {
 46         super(libName);
 47     }
 48 
 49     static class CompiledKernel {
 50         public final C99NativeBackend c99NativeBackend;
 51         public final KernelCallGraph kernelCallGraph;
 52         public final String text;
 53         public final long kernelHandle;
 54         public final ArgArray argArray;
 55         public final KernelContext kernelContext;
 56 
 57         CompiledKernel(C99NativeBackend c99NativeBackend, KernelCallGraph kernelCallGraph, String text, long kernelHandle, Object[] ndRangeAndArgs) {
 58             this.c99NativeBackend = c99NativeBackend;
 59             this.kernelCallGraph = kernelCallGraph;
 60             this.text = text;
 61             this.kernelHandle = kernelHandle;
 62             this.kernelContext = KernelContext.create(kernelCallGraph.computeContext.accelerator, 0, 0);
 63             ndRangeAndArgs[0] = this.kernelContext;
 64             this.argArray = ArgArray.create(kernelCallGraph.computeContext.accelerator, kernelCallGraph.computeContext.runtimeInfo, ndRangeAndArgs);
 65         }
 66 
 67         public void dispatch(NDRange ndRange, Object[] args) {
 68             kernelContext.maxX(ndRange.kid.maxX);
 69             args[0] = this.kernelContext;
 70             ArgArray.update(argArray, kernelCallGraph.computeContext.runtimeInfo, args);
 71             c99NativeBackend.ndRange(kernelHandle, this.argArray);
 72         }
 73     }
 74 
 75     Map<KernelCallGraph, CompiledKernel> kernelCallGraphCompiledCodeMap = new HashMap<>();
 76 
 77     public <T extends C99HatKernelBuilder<T>> String createCode(KernelCallGraph kernelCallGraph, T builder, Object[] args) {
 78         builder.defines().pragmas().types();
 79         Set<Schema.IfaceType> already = new LinkedHashSet<>();
 80         Arrays.stream(args)
 81                 .filter(arg -> arg instanceof Buffer)
 82                 .map(arg -> (Buffer) arg)
 83                 .forEach(ifaceBuffer -> {
 84                     BoundSchema<?> boundSchema = Buffer.getBoundSchema(ifaceBuffer);
 85                     boundSchema.schema().rootIfaceType.visitTypes(0, t -> {
 86                         if (!already.contains(t)) {
 87                             builder.typedef(boundSchema, t);
 88                             already.add(t);
 89                         }
 90                     });
 91                 });
 92 
 93         // Sorting by rank ensures we don't need forward declarations
 94         kernelCallGraph.kernelReachableResolvedStream().sorted((lhs, rhs) -> rhs.rank - lhs.rank)
 95                 .forEach(kernelReachableResolvedMethod -> builder.nl().kernelMethod(kernelReachableResolvedMethod).nl());
 96 
 97         builder.nl().kernelEntrypoint(kernelCallGraph.entrypoint, args).nl();
 98 
 99         System.out.println("Original");
100         System.out.println(kernelCallGraph.entrypoint.funcOpWrapper().op().toText());
101         System.out.println("Lowered");
102         System.out.println(kernelCallGraph.entrypoint.funcOpWrapper().lower().op().toText());
103 
104         return builder.toString();
105     }
106 }