1 /*
  2  * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package hat.phases;
 26 
 27 import hat.Accelerator;
 28 import hat.NDRange;
 29 import hat.dialect.HATPhaseUtils;
 30 import hat.dialect.HATVectorSelectLoadOp;
 31 import hat.dialect.HATVectorSelectStoreOp;
 32 import hat.dialect.HATVectorOp;
 33 import hat.optools.OpTk;
 34 import hat.types._V;
 35 import jdk.incubator.code.CodeElement;
 36 import jdk.incubator.code.CopyContext;
 37 import jdk.incubator.code.Op;
 38 import jdk.incubator.code.Value;
 39 import jdk.incubator.code.dialect.core.CoreOp;
 40 import jdk.incubator.code.dialect.java.JavaOp;
 41 import jdk.incubator.code.dialect.java.JavaType;
 42 
 43 import java.util.List;
 44 import java.util.Set;
 45 import java.util.stream.Collectors;
 46 import java.util.stream.Stream;
 47 
 48 public class HATDialectifyVectorSelectPhase implements HATDialect {
 49 
 50     protected final Accelerator accelerator;
 51     @Override  public Accelerator accelerator(){
 52         return this.accelerator;
 53     }
 54     public HATDialectifyVectorSelectPhase(Accelerator accelerator) {
 55         this.accelerator = accelerator;
 56     }
 57 
 58     private boolean isVectorLane(JavaOp.InvokeOp invokeOp) {
 59         return isMethod(invokeOp, "x")
 60                 || isMethod(invokeOp, "y")
 61                 || isMethod(invokeOp, "z")
 62                 || isMethod(invokeOp, "w");
 63     }
 64 
 65     int getLane(String fieldName) {
 66         return switch (fieldName) {
 67             case "x" -> 0;
 68             case "y" -> 1;
 69             case "z" -> 2;
 70             case "w" -> 3;
 71             default -> -1;
 72         };
 73     }
 74 
 75     private boolean isVectorOperation(JavaOp.InvokeOp invokeOp) {
 76         String typeElement = invokeOp.invokeDescriptor().refType().toString();
 77         Set<Class<?>> interfaces;
 78         try {
 79             Class<?> aClass = Class.forName(typeElement);
 80             interfaces = HATPhaseUtils.inspectAllInterfaces(aClass);
 81         } catch (ClassNotFoundException _) {
 82             return false;
 83         }
 84         return interfaces.contains(_V.class) && isVectorLane(invokeOp);
 85     }
 86 
 87     private String findNameVector(CoreOp.VarAccessOp.VarLoadOp varLoadOp) {
 88         return findNameVector(varLoadOp.operands().get(0));
 89     }
 90 
 91     private String findNameVector(Value v) {
 92         if (v instanceof Op.Result r && r.op() instanceof CoreOp.VarAccessOp.VarLoadOp varLoadOp) {
 93             return findNameVector(varLoadOp);
 94         } else {
 95             if (v instanceof CoreOp.Result r && r.op() instanceof HATVectorOp vectorViewOp) {
 96                 return vectorViewOp.varName();
 97             }
 98             return null;
 99         }
100     }
101 
102     private CoreOp.VarOp findVarOp(CoreOp.VarAccessOp.VarLoadOp varLoadOp) {
103         return findVarOp(varLoadOp.operands().get(0));
104     }
105 
106     private CoreOp.VarOp findVarOp(Value v) {
107         if (v instanceof Op.Result r && r.op() instanceof CoreOp.VarAccessOp.VarLoadOp varLoadOp) {
108             return findVarOp(varLoadOp);
109         } else {
110             if (v instanceof CoreOp.Result r && r.op() instanceof CoreOp.VarOp varOp) {
111                 return varOp;
112             }
113             return null;
114         }
115     }
116 
117     // Code Model Pattern:
118     //  %16 : java.type:"hat.buffer.Float4" = var.load %15 @loc="63:28";
119     //  %17 : java.type:"float" = invoke %16 @loc="63:28" @java.ref:"hat.buffer.Float4::x():float";
120     private CoreOp.FuncOp vloadSelectPhase(CoreOp.FuncOp funcOp) {
121         var here = OpTk.CallSite.of(this.getClass(), "vloadSelectPhase");
122         before(here, funcOp);
123         Stream<CodeElement<?, ?>> vectorSelectOps = funcOp.elements()
124                 .mapMulti((codeElement, consumer) -> {
125                     if (codeElement instanceof JavaOp.InvokeOp invokeOp) {
126                         if (isVectorOperation(invokeOp) && (invokeOp.resultType() != JavaType.VOID)) {
127                             Value inputOperand = invokeOp.operands().getFirst();
128                             if (inputOperand instanceof Op.Result r && r.op() instanceof CoreOp.VarAccessOp.VarLoadOp varLoadOp) {
129                                 consumer.accept(invokeOp);
130                                 consumer.accept(varLoadOp);
131                             }
132                         }
133                     }
134                 });
135 
136         Set<CodeElement<?, ?>> nodesInvolved = vectorSelectOps.collect(Collectors.toSet());
137         funcOp = OpTk.transform(here, funcOp, (blockBuilder, op) -> {
138             CopyContext context = blockBuilder.context();
139             if (!nodesInvolved.contains(op)) {
140                 blockBuilder.op(op);
141             } else if (op instanceof JavaOp.InvokeOp invokeOp) {
142                 List<Value> inputInvokeOp = invokeOp.operands();
143                 for (Value v : inputInvokeOp) {
144                     if (v instanceof Op.Result r && r.op() instanceof CoreOp.VarAccessOp.VarLoadOp varLoadOp) {
145                         List<Value> outputOperandsInvokeOp = context.getValues(inputInvokeOp);
146                         int lane = getLane(invokeOp.invokeDescriptor().name());
147                         HATVectorOp vSelectOp;
148                         String name = findNameVector(varLoadOp);
149                         if (invokeOp.resultType() != JavaType.VOID) {
150                             vSelectOp = new HATVectorSelectLoadOp(name, invokeOp.resultType(), lane, outputOperandsInvokeOp);
151                         } else {
152                             throw new RuntimeException("VSelect Load Op must return a value!");
153                         }
154                         Op.Result hatSelectResult = blockBuilder.op(vSelectOp);
155                         vSelectOp.setLocation(invokeOp.location());
156                         context.mapValue(invokeOp.result(), hatSelectResult);
157                     }
158                 }
159             } else if (op instanceof CoreOp.VarAccessOp.VarLoadOp varLoadOp) {
160                 // Pass the value
161                 context.mapValue(varLoadOp.result(), context.getValue(varLoadOp.operands().getFirst()));
162             }
163             return blockBuilder;
164         });
165 
166         after(here, funcOp);
167         return funcOp;
168     }
169 
170     // Pattern from the code mode:
171     // %20 : java.type:"hat.buffer.Float4" = var.load %15 @loc="64:13";
172     // %21 : java.type:"float" = var.load %19 @loc="64:18";
173     // invoke %20 %21 @loc="64:13" @java.ref:"hat.buffer.Float4::x(float):void";
174     private CoreOp.FuncOp vstoreSelectPhase(CoreOp.FuncOp funcOp) {
175         var here = OpTk.CallSite.of(this.getClass(),"vstoreSelectPhase");
176          before(here, funcOp);
177           //TODO is this side table safe?
178         Stream<CodeElement<?, ?>> float4NodesInvolved = OpTk.elements(here,funcOp)
179                 .mapMulti((codeElement, consumer) -> {
180                     if (codeElement instanceof JavaOp.InvokeOp invokeOp) {
181                         if (isVectorOperation(invokeOp)) {
182                             List<Value> inputOperandsInvoke = invokeOp.operands();
183                             Value inputOperand = inputOperandsInvoke.getFirst();
184                             if (inputOperand instanceof Op.Result r && r.op() instanceof CoreOp.VarAccessOp.VarLoadOp varLoadOp) {
185                                 consumer.accept(invokeOp);
186                                 consumer.accept(varLoadOp);
187                             }
188                         }
189                     }
190                 });
191 
192         Set<CodeElement<?, ?>> nodesInvolved = float4NodesInvolved.collect(Collectors.toSet());
193         funcOp = OpTk.transform(here, funcOp, (blockBuilder, op) -> {
194             CopyContext context = blockBuilder.context();
195             if (!nodesInvolved.contains(op)) {
196                 blockBuilder.op(op);
197             } else if (op instanceof JavaOp.InvokeOp invokeOp) {
198                 List<Value> inputInvokeOp = invokeOp.operands();
199                 Value v = inputInvokeOp.getFirst();
200 
201                 if (v instanceof Op.Result r && r.op() instanceof CoreOp.VarAccessOp.VarLoadOp varLoadOp) {
202                     List<Value> outputOperandsInvokeOp = context.getValues(inputInvokeOp);
203                     int lane = getLane(invokeOp.invokeDescriptor().name());
204                     HATVectorOp vSelectOp;
205                     String name = findNameVector(varLoadOp);
206                     if (invokeOp.resultType() == JavaType.VOID) {
207                         // The operand 1 in the store is the address (lane)
208                         // The operand 1 in the store is the storeValue
209                         CoreOp.VarOp resultOp = findVarOp(outputOperandsInvokeOp.get(1));
210                         vSelectOp = new HATVectorSelectStoreOp(name, invokeOp.resultType(), lane, resultOp, outputOperandsInvokeOp);
211                     } else {
212                         throw new RuntimeException("VSelect Store Op must return a value!");
213                     }
214                     Op.Result resultVStore = blockBuilder.op(vSelectOp);
215                     vSelectOp.setLocation(invokeOp.location());
216                     context.mapValue(invokeOp.result(), resultVStore);
217                 }
218 
219             } else if (op instanceof CoreOp.VarAccessOp.VarLoadOp varLoadOp) {
220                 // Pass the value
221                 context.mapValue(varLoadOp.result(), context.getValue(varLoadOp.operands().getFirst()));
222             }
223             return blockBuilder;
224         });
225 
226         after(here, funcOp);
227         return funcOp;
228     }
229 
230     @Override
231     public CoreOp.FuncOp apply(CoreOp.FuncOp funcOp) {
232         funcOp = vloadSelectPhase(funcOp);
233         funcOp = vstoreSelectPhase(funcOp);
234         return funcOp;
235     }
236 
237 }