1 /*
  2  * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package hat.phases;
 26 
 27 import hat.Accelerator;
 28 import hat.dialect.HATVectorSelectLoadOp;
 29 import hat.dialect.HATVectorSelectStoreOp;
 30 import hat.dialect.HATVectorViewOp;
 31 import hat.optools.OpTk;
 32 import jdk.incubator.code.CodeElement;
 33 import jdk.incubator.code.CopyContext;
 34 import jdk.incubator.code.Op;
 35 import jdk.incubator.code.Value;
 36 import jdk.incubator.code.dialect.core.CoreOp;
 37 import jdk.incubator.code.dialect.java.JavaOp;
 38 import jdk.incubator.code.dialect.java.JavaType;
 39 
 40 import java.util.List;
 41 import java.util.Set;
 42 import java.util.stream.Collectors;
 43 import java.util.stream.Stream;
 44 
 45 public class HATDialectifyVectorSelectPhase implements HATDialect {
 46 
 47     protected final Accelerator accelerator;
 48     @Override  public Accelerator accelerator(){
 49         return this.accelerator;
 50     }
 51     public HATDialectifyVectorSelectPhase(Accelerator accelerator) {
 52         this.accelerator = accelerator;
 53     }
 54 
 55     private boolean isVectorLane(JavaOp.InvokeOp invokeOp) {
 56         return isMethod(invokeOp, "x")
 57                 || isMethod(invokeOp, "y")
 58                 || isMethod(invokeOp, "z")
 59                 || isMethod(invokeOp, "w");
 60     }
 61 
 62     int getLane(String fieldName) {
 63         return switch (fieldName) {
 64             case "x" -> 0;
 65             case "y" -> 1;
 66             case "z" -> 2;
 67             case "w" -> 3;
 68             default -> -1;
 69         };
 70     }
 71 
 72     private boolean isVectorOperation(JavaOp.InvokeOp invokeOp) {
 73         String invokeClass = invokeOp.invokeDescriptor().refType().toString();
 74         boolean isHatVectorType = invokeClass.startsWith("hat.buffer.Float");
 75         return isHatVectorType
 76                 && OpTk.isIfaceBufferMethod(accelerator.lookup, invokeOp)
 77                 && (isVectorLane(invokeOp));
 78     }
 79 
 80     private String findNameVector(CoreOp.VarAccessOp.VarLoadOp varLoadOp) {
 81         return findNameVector(varLoadOp.operands().get(0));
 82     }
 83 
 84     private String findNameVector(Value v) {
 85         if (v instanceof Op.Result r && r.op() instanceof CoreOp.VarAccessOp.VarLoadOp varLoadOp) {
 86             return findNameVector(varLoadOp);
 87         } else {
 88             if (v instanceof CoreOp.Result r && r.op() instanceof HATVectorViewOp vectorViewOp) {
 89                 return vectorViewOp.varName();
 90             }
 91             return null;
 92         }
 93     }
 94 
 95     private CoreOp.VarOp findVarOp(CoreOp.VarAccessOp.VarLoadOp varLoadOp) {
 96         return findVarOp(varLoadOp.operands().get(0));
 97     }
 98 
 99     private CoreOp.VarOp findVarOp(Value v) {
100         if (v instanceof Op.Result r && r.op() instanceof CoreOp.VarAccessOp.VarLoadOp varLoadOp) {
101             return findVarOp(varLoadOp);
102         } else {
103             if (v instanceof CoreOp.Result r && r.op() instanceof CoreOp.VarOp varOp) {
104                 return varOp;
105             }
106             return null;
107         }
108     }
109 
110 
111     // Code Model Pattern:
112     //  %16 : java.type:"hat.buffer.Float4" = var.load %15 @loc="63:28";
113     //  %17 : java.type:"float" = invoke %16 @loc="63:28" @java.ref:"hat.buffer.Float4::x():float";
114 
115     private CoreOp.FuncOp vloadSelectPhase(CoreOp.FuncOp funcOp) {
116         var here = OpTk.CallSite.of(this.getClass(), "vloadSelectPhase");
117         before(here, funcOp);
118         Stream<CodeElement<?, ?>> float4NodesInvolved = funcOp.elements()
119                 .mapMulti((codeElement, consumer) -> {
120                     if (codeElement instanceof JavaOp.InvokeOp invokeOp) {
121                         if (isVectorOperation(invokeOp) && invokeOp.resultType() != JavaType.VOID) {
122                             List<Value> inputOperandsInvoke = invokeOp.operands();
123                             Value inputOperand = inputOperandsInvoke.getFirst();
124                             if (inputOperand instanceof Op.Result r && r.op() instanceof CoreOp.VarAccessOp.VarLoadOp varLoadOp) {
125                                 consumer.accept(invokeOp);
126                                 consumer.accept(varLoadOp);
127                             }
128                         }
129                     }
130                 });
131 
132         Set<CodeElement<?, ?>> nodesInvolved = float4NodesInvolved.collect(Collectors.toSet());
133 
134            funcOp = OpTk.transform(here, funcOp,(blockBuilder, op) -> {
135             CopyContext context = blockBuilder.context();
136             if (!nodesInvolved.contains(op)) {
137                 blockBuilder.op(op);
138             } else if (op instanceof JavaOp.InvokeOp invokeOp) {
139                 List<Value> inputInvokeOp = invokeOp.operands();
140                 for (Value v : inputInvokeOp) {
141                     if (v instanceof Op.Result r && r.op() instanceof CoreOp.VarAccessOp.VarLoadOp varLoadOp) {
142                         List<Value> outputOperandsInvokeOp = context.getValues(inputInvokeOp);
143                         int lane = getLane(invokeOp.invokeDescriptor().name());
144                         HATVectorViewOp vSelectOp;
145                         String name = findNameVector(varLoadOp);
146                         if (invokeOp.resultType() != JavaType.VOID) {
147                             vSelectOp = new HATVectorSelectLoadOp(name, invokeOp.resultType(), lane, outputOperandsInvokeOp);
148                         } else {
149                             throw new RuntimeException("VSelect Load Op must return a value!");
150                         }
151                         Op.Result hatSelectResult = blockBuilder.op(vSelectOp);
152                         vSelectOp.setLocation(invokeOp.location());
153                         context.mapValue(invokeOp.result(), hatSelectResult);
154                     }
155                 }
156             } else if (op instanceof CoreOp.VarAccessOp.VarLoadOp varLoadOp) {
157                 // Pass the value
158                 context.mapValue(varLoadOp.result(), context.getValue(varLoadOp.operands().getFirst()));
159             }
160             return blockBuilder;
161         });
162 
163        after(here,funcOp);
164         return funcOp;
165     }
166 
167     // Pattern from the code mode:
168     // %20 : java.type:"hat.buffer.Float4" = var.load %15 @loc="64:13";
169     // %21 : java.type:"float" = var.load %19 @loc="64:18";
170     // invoke %20 %21 @loc="64:13" @java.ref:"hat.buffer.Float4::x(float):void";
171     private CoreOp.FuncOp vstoreSelectPhase(CoreOp.FuncOp funcOp) {
172         var here = OpTk.CallSite.of(this.getClass(),"vstoreSelectPhase");
173          before(here, funcOp);
174           //TODO is this side table safe?
175         Stream<CodeElement<?, ?>> float4NodesInvolved = OpTk.elements(here,funcOp)
176                 .mapMulti((codeElement, consumer) -> {
177                     if (codeElement instanceof JavaOp.InvokeOp invokeOp) {
178                         if (isVectorOperation(invokeOp)) {
179                             List<Value> inputOperandsInvoke = invokeOp.operands();
180                             Value inputOperand = inputOperandsInvoke.getFirst();
181                             if (inputOperand instanceof Op.Result r && r.op() instanceof CoreOp.VarAccessOp.VarLoadOp varLoadOp) {
182                                 consumer.accept(invokeOp);
183                                 consumer.accept(varLoadOp);
184                             }
185                         }
186                     }
187                 });
188 
189         Set<CodeElement<?, ?>> nodesInvolved = float4NodesInvolved.collect(Collectors.toSet());
190         funcOp = OpTk.transform(here, funcOp, (blockBuilder, op) -> {
191             CopyContext context = blockBuilder.context();
192             if (!nodesInvolved.contains(op)) {
193                 blockBuilder.op(op);
194             } else if (op instanceof JavaOp.InvokeOp invokeOp) {
195                 List<Value> inputInvokeOp = invokeOp.operands();
196                 Value v = inputInvokeOp.getFirst();
197 
198                 if (v instanceof Op.Result r && r.op() instanceof CoreOp.VarAccessOp.VarLoadOp varLoadOp) {
199                     List<Value> outputOperandsInvokeOp = context.getValues(inputInvokeOp);
200                     int lane = getLane(invokeOp.invokeDescriptor().name());
201                     HATVectorViewOp vSelectOp;
202                     String name = findNameVector(varLoadOp);
203                     if (invokeOp.resultType() == JavaType.VOID) {
204                         // The operand 1 in the store is the address (lane)
205                         // The operand 1 in the store is the storeValue
206                         CoreOp.VarOp resultOp = findVarOp(outputOperandsInvokeOp.get(1));
207                         vSelectOp = new HATVectorSelectStoreOp(name, invokeOp.resultType(), lane, resultOp, outputOperandsInvokeOp);
208                     } else {
209                         throw new RuntimeException("VSelect Store Op must return a value!");
210                     }
211                     Op.Result resultVStore = blockBuilder.op(vSelectOp);
212                     vSelectOp.setLocation(invokeOp.location());
213                     context.mapValue(invokeOp.result(), resultVStore);
214                 }
215 
216             } else if (op instanceof CoreOp.VarAccessOp.VarLoadOp varLoadOp) {
217                 // Pass the value
218                 context.mapValue(varLoadOp.result(), context.getValue(varLoadOp.operands().getFirst()));
219             }
220             return blockBuilder;
221         });
222 
223         after(here, funcOp);
224         return funcOp;
225     }
226 
227     @Override
228     public CoreOp.FuncOp apply(CoreOp.FuncOp funcOp) {
229         funcOp = vloadSelectPhase(funcOp);
230         funcOp = vstoreSelectPhase(funcOp);
231         return funcOp;
232     }
233 
234 }