1 /*
  2  * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package hat.phases;
 26 
 27 import hat.Accelerator;
 28 import hat.dialect.HATLocalVarOp;
 29 import hat.dialect.HATPrivateVarOp;
 30 import hat.dialect.HATVectorAddOp;
 31 import hat.dialect.HATVectorDivOp;
 32 import hat.dialect.HATVectorLoadOp;
 33 import hat.dialect.HATVectorMulOp;
 34 import hat.dialect.HATVectorSubOp;
 35 import hat.dialect.HATVectorVarLoadOp;
 36 import hat.dialect.HATVectorVarOp;
 37 import hat.dialect.HATVectorViewOp;
 38 import hat.dialect.HATVectorBinaryOp;
 39 import hat.optools.OpTk;
 40 import jdk.incubator.code.CodeElement;
 41 import jdk.incubator.code.CopyContext;
 42 import jdk.incubator.code.Op;
 43 import jdk.incubator.code.TypeElement;
 44 import jdk.incubator.code.Value;
 45 import jdk.incubator.code.dialect.core.CoreOp;
 46 import jdk.incubator.code.dialect.java.JavaOp;
 47 
 48 import java.util.HashMap;
 49 import java.util.List;
 50 import java.util.Map;
 51 import java.util.Objects;
 52 import java.util.Set;
 53 import java.util.stream.Collectors;
 54 import java.util.stream.Stream;
 55 
 56 public abstract class HATDialectifyVectorOpPhase implements HATDialect{
 57 
 58     protected final Accelerator accelerator;
 59     @Override  public Accelerator accelerator(){
 60         return this.accelerator;
 61     }
 62     private final OpView vectorOperation;
 63 
 64     public HATDialectifyVectorOpPhase(Accelerator accelerator, OpView vectorOperation) {
 65        this.accelerator = accelerator;
 66         this.vectorOperation = vectorOperation;
 67     }
 68 
 69 
 70 
 71     private HATVectorBinaryOp.OpType getBinaryOpType(JavaOp.InvokeOp invokeOp) {
 72         return switch (invokeOp.invokeDescriptor().name()) {
 73             case "add" -> HATVectorBinaryOp.OpType.ADD;
 74             case "sub" -> HATVectorBinaryOp.OpType.SUB;
 75             case "mul" -> HATVectorBinaryOp.OpType.MUL;
 76             case "div" -> HATVectorBinaryOp.OpType.DIV;
 77             default -> throw new RuntimeException("Unknown binary op " + invokeOp.invokeDescriptor().name());
 78         };
 79     }
 80 
 81     public enum OpView {
 82         FLOAT4_LOAD("float4View"),
 83         ADD("add"),
 84         SUB("sub"),
 85         MUL("mul"),
 86         DIV("div");
 87         final String methodName;
 88         OpView(String methodName) {
 89             this.methodName = methodName;
 90         }
 91     }
 92 
 93     private boolean isVectorOperation(JavaOp.InvokeOp invokeOp) {
 94         TypeElement typeElement = invokeOp.resultType();
 95         boolean isHatVectorType = typeElement.toString().startsWith("hat.buffer.Float");
 96         return isHatVectorType
 97                 && OpTk.isIfaceBufferMethod(accelerator.lookup, invokeOp)
 98                 && isMethod(invokeOp, vectorOperation.methodName);
 99     }
100 
101     private String findNameVector(CoreOp.VarAccessOp.VarLoadOp varLoadOp) {
102         return findNameVector(varLoadOp.operands().get(0));
103     }
104 
105     private String findNameVector(Value v) {
106         if (v instanceof Op.Result r && r.op() instanceof CoreOp.VarAccessOp.VarLoadOp varLoadOp) {
107             return findNameVector(varLoadOp);
108         } else {
109             // Leaf of tree -
110             if (v instanceof CoreOp.Result r && r.op() instanceof HATVectorViewOp hatVectorViewOp) {
111                 return hatVectorViewOp.varName();
112             }
113             return null;
114         }
115     }
116 
117     private boolean findIsSharedOrPrivate(CoreOp.VarAccessOp.VarLoadOp varLoadOp) {
118         return findIsSharedOrPrivate(varLoadOp.operands().get(0));
119     }
120 
121     private boolean findIsSharedOrPrivate(Value v) {
122         if (v instanceof Op.Result r && r.op() instanceof CoreOp.VarAccessOp.VarLoadOp varLoadOp) {
123             return findIsSharedOrPrivate(varLoadOp);
124         } else {
125             // Leaf of tree -
126             if (v instanceof CoreOp.Result r && (r.op() instanceof HATLocalVarOp || r.op() instanceof HATPrivateVarOp)) {
127                 return true;
128             }
129             return false;
130         }
131     }
132 
133     private HATVectorBinaryOp buildVectorBinaryOp(HATVectorBinaryOp.OpType opType, String varName, TypeElement resultType, List<Value> outputOperands) {
134         return switch (opType) {
135             case ADD -> new HATVectorAddOp(varName, resultType, outputOperands);
136             case SUB -> new HATVectorSubOp(varName, resultType, outputOperands);
137             case MUL -> new HATVectorMulOp(varName, resultType, outputOperands);
138             case DIV -> new HATVectorDivOp(varName, resultType, outputOperands);
139         };
140     }
141 
142     private CoreOp.FuncOp dialectifyVectorLoad(CoreOp.FuncOp funcOp) {
143         var here = OpTk.CallSite.of(this.getClass(), "dialectifyVectorLoad" );
144         before(here,funcOp);
145             Stream<CodeElement<?, ?>> float4NodesInvolved = funcOp.elements()
146                 .mapMulti((codeElement, consumer) -> {
147                     if (codeElement instanceof CoreOp.VarOp varOp) {
148                         List<Value> inputOperandsVarOp = varOp.operands();
149                         for (Value inputOperand : inputOperandsVarOp) {
150                             if (inputOperand instanceof Op.Result result) {
151                                 if (result.op() instanceof JavaOp.InvokeOp invokeOp) {
152                                     if (isVectorOperation(invokeOp)) {
153                                         consumer.accept(invokeOp);
154                                         consumer.accept(varOp);
155                                     }
156                                 }
157                             }
158                         }
159                     }
160                 });
161 
162         Set<CodeElement<?, ?>> nodesInvolved = float4NodesInvolved.collect(Collectors.toSet());
163 
164 
165         funcOp = OpTk.transform(here, funcOp,(blockBuilder, op) -> {
166             CopyContext context = blockBuilder.context();
167             if (!nodesInvolved.contains(op)) {
168                 blockBuilder.op(op);
169             } else if (op instanceof JavaOp.InvokeOp invokeOp) {
170                 // Don't insert the invoke node
171                 Op.Result result = invokeOp.result();
172                 List<Op.Result> collect = result.uses().stream().toList();
173                 boolean isShared = findIsSharedOrPrivate(invokeOp.operands().getFirst());
174                 for (Op.Result r : collect) {
175                     if (r.op() instanceof CoreOp.VarOp varOp) {
176                         List<Value> inputOperandsVarOp = invokeOp.operands();
177                         List<Value> outputOperandsVarOp = context.getValues(inputOperandsVarOp);
178                         HATVectorViewOp memoryViewOp = new HATVectorLoadOp(varOp.varName(), varOp.resultType(), invokeOp.resultType(), 4, isShared, outputOperandsVarOp);
179                         Op.Result hatLocalResult = blockBuilder.op(memoryViewOp);
180                         memoryViewOp.setLocation(varOp.location());
181                         context.mapValue(invokeOp.result(), hatLocalResult);
182                     }
183                 }
184             } else if (op instanceof CoreOp.VarOp varOp) {
185                 // pass value
186                 //context.mapValue(varOp.result(), context.getValue(varOp.operands().getFirst()));
187                 List<Value> inputOperandsVarOp = varOp.operands();
188                 List<Value> outputOperandsVarOp = context.getValues(inputOperandsVarOp);
189                 HATVectorViewOp memoryViewOp = new HATVectorVarOp(varOp.varName(), varOp.resultType(), 4, outputOperandsVarOp);
190                 Op.Result hatLocalResult = blockBuilder.op(memoryViewOp);
191                 memoryViewOp.setLocation(varOp.location());
192                 context.mapValue(varOp.result(), hatLocalResult);
193             } else if (op instanceof CoreOp.VarAccessOp.VarLoadOp varLoadOp) {
194                 // pass value
195                 context.mapValue(varLoadOp.result(), context.getValue(varLoadOp.operands().getFirst()));
196             }
197             return blockBuilder;
198         });
199         after(here, funcOp);
200         return funcOp;
201     }
202 
203     private CoreOp.FuncOp dialectifyVectorBinaryOps(CoreOp.FuncOp funcOp) {
204         var here = OpTk.CallSite.of(this.getClass(), "dialectifyVectorBinaryOps");
205         before(here, funcOp);
206         Map<JavaOp.InvokeOp, HATVectorBinaryOp.OpType> binaryOperation = new HashMap<>();
207 
208         Stream<CodeElement<?, ?>> float4NodesInvolved = funcOp.elements()
209                 .mapMulti((codeElement, consumer) -> {
210                     if (codeElement instanceof CoreOp.VarOp varOp) {
211                         List<Value> inputOperandsVarOp = varOp.operands();
212                         for (Value inputOperand : inputOperandsVarOp) {
213                             if (inputOperand instanceof Op.Result result) {
214                                 if (result.op() instanceof JavaOp.InvokeOp invokeOp) {
215                                     if (isVectorOperation(invokeOp)) {
216                                         HATVectorBinaryOp.OpType binaryOpType = getBinaryOpType(invokeOp);
217                                         binaryOperation.put(invokeOp, binaryOpType);
218                                         consumer.accept(invokeOp);
219                                         consumer.accept(varOp);
220                                     }
221                                 }
222                             }
223                         }
224                     }
225                 });
226 
227         Set<CodeElement<?, ?>> nodesInvolved = float4NodesInvolved.collect(Collectors.toSet());
228 
229         funcOp = OpTk.transform(here, funcOp, nodesInvolved::contains, (blockBuilder, op) -> {
230             CopyContext context = blockBuilder.context();
231            // if (!nodesInvolved.contains(op)) {
232              //   blockBuilder.op(op);
233             //} else
234                 if (op instanceof JavaOp.InvokeOp invokeOp) {
235                 Op.Result result = invokeOp.result();
236                 List<Value> inputOperands = invokeOp.operands();
237                 List<Value> outputOperands = context.getValues(inputOperands);
238                 List<Op.Result> collect = result.uses().stream().toList();
239                 for (Op.Result r : collect) {
240                     if (r.op() instanceof CoreOp.VarOp varOp) {
241                         HATVectorBinaryOp.OpType binaryOpType = binaryOperation.get(invokeOp);
242                         HATVectorViewOp memoryViewOp = buildVectorBinaryOp(binaryOpType, varOp.varName(), invokeOp.resultType(), outputOperands);
243                         Op.Result hatVectorOpResult = blockBuilder.op(memoryViewOp);
244                         memoryViewOp.setLocation(varOp.location());
245                         context.mapValue(invokeOp.result(), hatVectorOpResult);
246                         break;
247                     }
248                 }
249             } else if (op instanceof CoreOp.VarOp varOp) {
250                 List<Value> inputOperandsVarOp = varOp.operands();
251                 List<Value> outputOperandsVarOp = context.getValues(inputOperandsVarOp);
252                 HATVectorViewOp memoryViewOp = new HATVectorVarOp(varOp.varName(), varOp.resultType(), 4, outputOperandsVarOp);
253                 Op.Result hatVectorResult = blockBuilder.op(memoryViewOp);
254                 memoryViewOp.setLocation(varOp.location());
255                 context.mapValue(varOp.result(), hatVectorResult);
256             }
257             return blockBuilder;
258         });
259        after(here,funcOp);
260         return funcOp;
261     }
262 
263     private CoreOp.FuncOp dialectifyVectorBinaryWithContatenationOps(CoreOp.FuncOp funcOp) {
264         var here = OpTk.CallSite.of(this.getClass(), "dialectifyBinaryWithConcatenation");
265         before(here, funcOp);
266         Map<JavaOp.InvokeOp, HATVectorBinaryOp.OpType> binaryOperation = new HashMap<>();
267         Stream<CodeElement<?, ?>> float4NodesInvolved = funcOp.elements()
268                 .mapMulti((codeElement, consumer) -> {
269                     if (codeElement instanceof JavaOp.InvokeOp invokeOp) {
270                         if (isVectorOperation(invokeOp)) {
271                             List<Value> inputOperandsInvoke = invokeOp.operands();
272                             for (Value inputOperand : inputOperandsInvoke) {
273                                 if (inputOperand instanceof Op.Result r && r.op() instanceof CoreOp.VarAccessOp.VarLoadOp varLoadOp) {
274                                     HATVectorBinaryOp.OpType binaryOpType = getBinaryOpType(invokeOp);
275                                     binaryOperation.put(invokeOp, binaryOpType);
276                                     consumer.accept(varLoadOp);
277                                     consumer.accept(invokeOp);
278                                 }
279                             }
280                         }
281                     } else if (codeElement instanceof HATVectorBinaryOp hatVectorBinaryOp) {
282                         List<Value> inputOperandsInvoke = hatVectorBinaryOp.operands();
283                         for (Value inputOperand : inputOperandsInvoke) {
284                             if (inputOperand instanceof Op.Result r && r.op() instanceof CoreOp.VarAccessOp.VarLoadOp varLoadOp) {
285                                 consumer.accept(varLoadOp);
286                             }
287                         }
288                     }
289                 });
290 
291         Set<CodeElement<?, ?>> nodesInvolved = float4NodesInvolved.collect(Collectors.toSet());
292         if (nodesInvolved.isEmpty()) {
293             return funcOp;
294         }
295          funcOp = OpTk.transform(here, funcOp, (blockBuilder, op) -> {
296             CopyContext context = blockBuilder.context();
297             if (!nodesInvolved.contains(op)) {
298                 blockBuilder.op(op);
299             } else if (op instanceof JavaOp.InvokeOp invokeOp) {
300                 List<Value> inputOperands = invokeOp.operands();
301                 List<Value> outputOperands = context.getValues(inputOperands);
302                 HATVectorViewOp memoryViewOp = buildVectorBinaryOp(binaryOperation.get(invokeOp), "null", invokeOp.resultType(), outputOperands);
303                 Op.Result hatVectorOpResult = blockBuilder.op(memoryViewOp);
304                 memoryViewOp.setLocation(invokeOp.location());
305                 context.mapValue(invokeOp.result(), hatVectorOpResult);
306             } else if (op instanceof CoreOp.VarAccessOp.VarLoadOp varLoadOp) {
307                 List<Value> inputOperandsVarLoad = varLoadOp.operands();
308                 List<Value> outputOperandsVarLoad = context.getValues(inputOperandsVarLoad);
309                 String varLoadName = findNameVector(varLoadOp);
310                 HATVectorViewOp memoryViewOp = new HATVectorVarLoadOp(varLoadName, varLoadOp.resultType(), outputOperandsVarLoad);
311                 Op.Result hatVectorResult = blockBuilder.op(memoryViewOp);
312                 memoryViewOp.setLocation(varLoadOp.location());
313                 context.mapValue(varLoadOp.result(), hatVectorResult);
314             }
315             return blockBuilder;
316         });
317       after(here,funcOp);
318         return funcOp;
319     }
320 
321     @Override
322     public CoreOp.FuncOp apply(CoreOp.FuncOp funcOp) {
323         if (Objects.requireNonNull(vectorOperation) == OpView.FLOAT4_LOAD) {
324             funcOp = dialectifyVectorLoad(funcOp);
325         } else {
326             // Find binary operations
327             funcOp = dialectifyVectorBinaryOps(funcOp);
328             funcOp = dialectifyVectorBinaryWithContatenationOps(funcOp);
329         }
330         return funcOp;
331     }
332 
333     public static class AddPhase extends HATDialectifyVectorOpPhase{
334 
335         public AddPhase(Accelerator accelerator) {
336            super(accelerator, OpView.ADD);
337         }
338     }
339 
340     public static class DivPhase extends HATDialectifyVectorOpPhase{
341 
342         public DivPhase(Accelerator accelerator) {
343            super(accelerator, OpView.DIV);
344         }
345     }
346 
347     public static class Float4LoadPhase extends HATDialectifyVectorOpPhase{
348 
349         public Float4LoadPhase(Accelerator accelerator) {
350            super(accelerator, OpView.FLOAT4_LOAD);
351         }
352     }
353 
354     public static class MulPhase extends HATDialectifyVectorOpPhase{
355 
356         public MulPhase(Accelerator accelerator) {
357            super(accelerator, OpView.MUL);
358         }
359     }
360 
361     public static class SubPhase extends HATDialectifyVectorOpPhase{
362 
363         public SubPhase(Accelerator accelerator) {
364            super(accelerator, OpView.SUB);
365         }
366     }
367 }