1 /*
  2  * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package hat.phases;
 26 
 27 import hat.dialect.*;
 28 import jdk.incubator.code.CodeType;
 29 import optkl.IfaceValue;
 30 import optkl.OpHelper;
 31 import optkl.Trxfmr;
 32 import jdk.incubator.code.Op;
 33 import jdk.incubator.code.Value;
 34 import jdk.incubator.code.dialect.core.CoreOp;
 35 import jdk.incubator.code.dialect.core.CoreType;
 36 import jdk.incubator.code.dialect.java.*;
 37 import optkl.util.ops.VarLikeOp;
 38 
 39 import java.lang.invoke.MethodHandles;
 40 import java.util.*;
 41 
 42 import static hat.phases.HATPhaseUtils.findOpInResultFromFirstOperandsOrNull;
 43 import static optkl.IfaceValue.Vector.getVectorShape;
 44 import static optkl.OpHelper.*;
 45 import static optkl.OpHelper.Invoke.invoke;
 46 
 47 public record HATArrayViewPhase() implements HATPhase {
 48     static public boolean isVectorOp(MethodHandles.Lookup lookup, Op op) {
 49         if (!op.operands().isEmpty()) {
 50             CodeType type = switch(op) {
 51                 case JavaOp.ArrayAccessOp.ArrayLoadOp load -> load.resultType();
 52                 case JavaOp.ArrayAccessOp.ArrayStoreOp store -> store.operands().getLast().type();
 53                 default -> OpHelper.firstOperandOrThrow(op).type();
 54             };
 55             if (type instanceof ArrayType at) {
 56                 type = at.componentType();
 57             }
 58             if (type instanceof ClassType ct) {
 59                 try {
 60                     return IfaceValue.Vector.class.isAssignableFrom((Class<?>) ct.resolve(lookup));
 61                 } catch (ReflectiveOperationException e) {
 62                     throw new RuntimeException(e);
 63                 }
 64             }
 65         }
 66         return false;
 67     }
 68 
 69     static public boolean isVectorBinaryOp(MethodHandles.Lookup lookup, OpHelper.Invoke invoke) {
 70         return isVectorOp(lookup, invoke.op()) && invoke.nameMatchesRegex("(add|sub|mul|div)");
 71     }
 72 
 73 
 74     static HATVectorOp.HATVectorBinaryOp buildVectorBinaryOp(String varName, String opType, IfaceValue.Vector.Shape vectorShape, List<Value> outputOperands) {
 75         return switch (opType) {
 76             case "add" -> new HATVectorOp.HATVectorBinaryOp.HATVectorAddOp(varName,  vectorShape, outputOperands);
 77             case "sub" -> new HATVectorOp.HATVectorBinaryOp.HATVectorSubOp(varName,  vectorShape, outputOperands);
 78             case "mul" -> new HATVectorOp.HATVectorBinaryOp.HATVectorMulOp(varName,  vectorShape, outputOperands);
 79             case "div" -> new HATVectorOp.HATVectorBinaryOp.HATVectorDivOp(varName,  vectorShape, outputOperands);
 80             default -> throw new IllegalStateException("Unexpected value: " + opType);
 81         };
 82     }
 83     static public boolean isBufferArray(MethodHandles.Lookup lookup, Op op) {
 84         JavaOp.InvokeOp iop = (JavaOp.InvokeOp) findOpInResultFromFirstOperandsOrNull(op, JavaOp.InvokeOp.class);
 85         return iop != null && iop.invokeReference().name().toLowerCase().contains("arrayview"); // we need a better way
 86     }
 87 
 88     static public boolean isBufferInitialize(MethodHandles.Lookup lookup, Op op) {
 89         // first check if the return is an array type
 90         if (op instanceof CoreOp.VarOp vop && vop.varValueType() instanceof ArrayType
 91                 || op instanceof JavaOp.ArrayAccessOp
 92                 || op.resultType() instanceof ArrayType) return isBufferArray(lookup, op);
 93         return false;
 94     }
 95     static public boolean isLocalSharedOrPrivate(Op op) {
 96         JavaOp.InvokeOp iop = (JavaOp.InvokeOp) findOpInResultFromFirstOperandsOrNull(op, JavaOp.InvokeOp.class);
 97         return iop != null
 98                 && (iop.invokeReference().name().toLowerCase().contains("shared")
 99                 || iop.invokeReference().name().toLowerCase().contains("local")
100                 || iop.invokeReference().name().toLowerCase().contains("private")
101         );
102     }
103 
104     static public HATVectorOp buildArrayViewVector(Op op, String name, CodeType resultType, IfaceValue.Vector.Shape vectorShape, List<Value> operands) {
105         if (isLocalSharedOrPrivate(op)) {
106             if (op instanceof JavaOp.ArrayAccessOp.ArrayLoadOp) {
107                 return new HATVectorOp.HATVectorLoadOp.HATSharedVectorLoadOp(name, resultType, vectorShape, operands);
108             }
109             return new HATVectorOp.HATVectorStoreView.HATSharedVectorStoreView(name, resultType, vectorShape, operands);
110         } else {
111             if (op instanceof JavaOp.ArrayAccessOp.ArrayLoadOp) {
112                 return new HATVectorOp.HATVectorLoadOp.HATPrivateVectorLoadOp(name, resultType, vectorShape, operands);
113             }
114             return new HATVectorOp.HATVectorStoreView.HATPrivateVectorStoreView(name, resultType, vectorShape, operands);
115         }
116     }
117 
118 
119 
120     static HATArrayViewPhase.ArrayAccessInfo arrayAccessInfo(Value value, Map<Op.Result, Op.Result> replaced) {
121         return expressionGraph(value).getInfo(replaced);
122     }
123 
124     static HATArrayViewPhase.Node<Value> expressionGraph(Value value) {
125         return expressionGraph(new HashMap<>(), value);
126     }
127 
128     static HATArrayViewPhase.Node<Value> expressionGraph(Map<Value, HATArrayViewPhase.Node<Value>> visited, Value value) {
129         // If value has already been visited return its node
130         if (visited.containsKey(value)) {
131             return visited.get(value);
132         }
133 
134         // Find the expression graphs for each operand
135         List<HATArrayViewPhase.Node<Value>> edges = new ArrayList<>();
136 
137         // looks like
138         for (Value operand : value.dependsOn()) {
139             if (operand instanceof Op.Result res &&
140                     res.op() instanceof JavaOp.InvokeOp iop
141                     && iop.invokeReference().name().toLowerCase().contains("arrayview")){ // We need to find a better way
142                 continue;
143             }
144             edges.add(expressionGraph(operand));
145         }
146         HATArrayViewPhase.Node<Value> node = new HATArrayViewPhase.Node<>(value, edges);
147         visited.put(value, node);
148         return node;
149     }
150     @Override
151     public CoreOp.FuncOp transform(MethodHandles.Lookup lookup,CoreOp.FuncOp funcOp) {
152         if (Invoke.stream(lookup, funcOp).noneMatch(
153                 invoke -> isBufferArray(lookup, invoke.op())
154         )) return funcOp;
155 
156         funcOp = applyArrayView(lookup,funcOp);
157 
158         if (funcOp.elements().filter(e -> e instanceof CoreOp.VarOp).anyMatch(
159                 e -> isVectorOp(lookup, ((CoreOp.VarOp) e))
160         )) funcOp = applyVectorView(lookup,funcOp);
161         return funcOp;
162     }
163 
164     public CoreOp.FuncOp applyVectorView(MethodHandles.Lookup lookup,CoreOp.FuncOp funcOp) {
165         return Trxfmr.of(lookup,funcOp).transform((blockBuilder, op) -> {
166             var context = blockBuilder.context();
167             switch (op) {
168                 case JavaOp.InvokeOp $ when invoke(lookup, $) instanceof Invoke invoke -> {
169                     if (isVectorBinaryOp(invoke.lookup(), invoke)){
170                         var hatVectorBinaryOp = buildVectorBinaryOp(
171                                 invoke.varOpFromFirstUseOrThrow().varName(),
172                                 invoke.name(),// so mul, sub etc
173                                 getVectorShape(lookup,invoke.returnType()),
174                                 blockBuilder.context().getValues(invoke.op().operands())
175                         );
176                         context.mapValue(invoke.returnResult(), blockBuilder.op(copyLocation(invoke.op(),hatVectorBinaryOp)));
177                         return blockBuilder;
178                     }
179                 }
180                 case CoreOp.VarOp varOp -> {
181                     if (isVectorOp(lookup,varOp)) {
182                         var hatVectorVarOp = new HATVectorOp.HATVectorVarOp(
183                                 varOp.varName(),
184                                 varOp.resultType(),
185                                 getVectorShape(lookup,varOp.resultType().valueType()),
186                                 context.getValues(OpHelper.firstOperandAsListOrEmpty(varOp))
187                         );
188                         context.mapValue(varOp.result(), blockBuilder.op(copyLocation(varOp,hatVectorVarOp)));
189                         return blockBuilder;
190                     }
191                 }
192                 case JavaOp.ArrayAccessOp.ArrayLoadOp arrayLoadOp -> {
193                     if (isVectorOp(lookup,arrayLoadOp)) {
194                         Op.Result buffer = resultFromFirstOperandOrNull(arrayLoadOp);
195                         String name = hatPtrName(opFromFirstOperandOrThrow(buffer.op()));
196                         var resultType = CoreType.varType(arrayLoadOp.resultType());
197                         var vectorShape = getVectorShape(lookup,arrayLoadOp.resultType());
198                         List<Value> operands = context.getValues(List.of(buffer, arrayLoadOp.operands().getLast()));
199                         HATVectorOp vLoadOp = buildArrayViewVector(arrayLoadOp, name, resultType, vectorShape, operands);
200                         context.mapValue(arrayLoadOp.result(), blockBuilder.op(copyLocation(arrayLoadOp,vLoadOp)));
201                     }
202                     return blockBuilder;
203                 }
204                 case JavaOp.ArrayAccessOp.ArrayStoreOp arrayStoreOp -> {
205                     if (isVectorOp(lookup,arrayStoreOp)) {
206                         Op.Result buffer = resultFromFirstOperandOrThrow(arrayStoreOp);
207                         Op varOp = opFromFirstOperandOrNull(((Op.Result) arrayStoreOp.operands().getLast()).op());
208                         String name = hatPtrName(varOp);
209                         var resultType = (varOp).resultType();
210                         var vectorShape = getVectorShape(lookup,arrayStoreOp.operands().getLast().type());
211                         List<Value> operands = context.getValues(List.of(buffer, arrayStoreOp.operands().getLast(), arrayStoreOp.operands().get(1)));
212                         HATVectorOp vStoreOp = buildArrayViewVector(arrayStoreOp, name, resultType, vectorShape, operands);
213                         context.mapValue(arrayStoreOp.result(), blockBuilder.op(copyLocation(arrayStoreOp,vStoreOp)));
214                     }
215                     return blockBuilder;
216                 }
217                 default -> {
218                 }
219             }
220             blockBuilder.op(op);
221             return blockBuilder;
222         }).funcOp();
223     }
224 
225     public CoreOp.FuncOp applyArrayView(MethodHandles.Lookup lookup,CoreOp.FuncOp funcOp) {
226         Map<Op.Result, Op.Result> replaced = new HashMap<>(); // maps a result to the result it should be replaced by
227         Map<Op, CoreOp.VarAccessOp.VarLoadOp> bufferVarLoads = new HashMap<>();
228 
229         return Trxfmr.of(lookup,funcOp).transform((blockBuilder, op) -> {
230             var context = blockBuilder.context();
231             switch (op) {
232                 case JavaOp.InvokeOp $ when invoke(lookup, $) instanceof Invoke invoke -> {
233                     if (isBufferArray(lookup, invoke.op())) { // ensures we can use iop as key for replaced vvv
234                         Op.Result result = invoke.resultFromFirstOperandOrNull();
235                         replaced.put(invoke.returnResult(), result);
236                         // map buffer VarOp to its corresponding VarLoadOp
237                         bufferVarLoads.put((opFromFirstOperandOrNull(result.op())), (CoreOp.VarAccessOp.VarLoadOp) result.op());
238                         return blockBuilder;
239                     }
240                 }
241                 case CoreOp.VarOp varOp -> {
242                     if (isBufferInitialize(lookup, varOp)) {
243                         // makes sure we don't process a new int[] for example
244                         Op bufferLoad = replaced.get(resultFromFirstOperandOrThrow(varOp)).op(); // gets VarLoadOp associated w/ og buffer
245                         replaced.put(varOp.result(), resultFromFirstOperandOrNull(bufferLoad)); // gets VarOp associated w/ og buffer
246                         return blockBuilder;
247                     }
248                 }
249                 case CoreOp.VarAccessOp.VarLoadOp varLoadOp -> {
250                     if ((isBufferInitialize(lookup, varLoadOp))) {
251                         Op.Result r = resultFromFirstOperandOrThrow(varLoadOp);
252                         Op.Result replacement;
253                         if (r.op() instanceof CoreOp.VarOp) { // if this is the VarLoadOp after the .arrayView() InvokeOp
254                             replacement = (isLocalSharedOrPrivate(varLoadOp)) ?
255                                     resultFromFirstOperandOrNull(opFromFirstOperandOrThrow(r.op())) :
256                                     bufferVarLoads.get(replaced.get(r).op()).result();
257                         } else { // if this is a VarLoadOp loading the buffer
258                             CoreOp.VarAccessOp.VarLoadOp newVarLoad = CoreOp.VarAccessOp.varLoad(blockBuilder.context().getValue(replaced.get(r)));
259                             replacement = blockBuilder.op(copyLocation(varLoadOp,newVarLoad));
260                             context.mapValue(varLoadOp.result(), replacement);
261                         }
262                         replaced.put(varLoadOp.result(), replacement);
263                         return blockBuilder;
264                     }
265                 }
266                 case JavaOp.ArrayAccessOp.ArrayLoadOp arrayLoadOp -> {
267                     if (isBufferArray(lookup, arrayLoadOp)) {
268                         Op replacementOp=null;
269                         if (isVectorOp(lookup,arrayLoadOp)) {
270                             replacementOp = JavaOp.arrayLoadOp(
271                                     context.getValue(replaced.get((Op.Result) arrayLoadOp.operands().getFirst())),
272                                     context.getValue(arrayLoadOp.operands().getLast()),
273                                     arrayLoadOp.resultType()
274                             );
275                         } else if (((ArrayType) firstOperandOrThrow(op).type()).dimensions() == 1) {
276                             var arrayAccessInfo = arrayAccessInfo(op.result(), replaced);
277                             var operands = arrayAccessInfo.bufferAndIndicesAsValues();
278                             replacementOp = new HATPtrOp.HATPtrLoadOp(
279                                     arrayAccessInfo.bufferName(),
280                                     arrayLoadOp.resultType(),
281                                     (Class<?>) classTypeToTypeOrThrow(lookup, (ClassType) arrayAccessInfo.buffer().type()),
282                                     context.getValues(operands)
283                             );
284                         } else { // we only use the last array load
285                             return blockBuilder;
286                         }
287                         context.mapValue(arrayLoadOp.result(), blockBuilder.op(copyLocation(arrayLoadOp,replacementOp)));
288                         return blockBuilder;
289                     }
290                 }
291                 case JavaOp.ArrayAccessOp.ArrayStoreOp arrayStoreOp -> {
292                     if (isBufferArray(lookup, arrayStoreOp)) {
293                         Op replacementOp;
294                         if (isVectorOp(lookup, arrayStoreOp)) {
295                             replacementOp = JavaOp.arrayStoreOp(
296                                     context.getValue(replaced.get((Op.Result) arrayStoreOp.operands().getFirst())),
297                                     context.getValue(arrayStoreOp.operands().get(1)),
298                                     context.getValue(arrayStoreOp.operands().getLast())
299                             );
300                         } else if (((ArrayType) firstOperandOrThrow(op).type()).dimensions() == 1) { // we only use the last array load
301                             var arrayAccessInfo = arrayAccessInfo(op.result(), replaced);
302                             var operands = arrayAccessInfo.bufferAndIndicesAsValues();
303                             operands.add(arrayStoreOp.operands().getLast());
304                             replacementOp = new HATPtrOp.HATPtrStoreOp(
305                                     arrayAccessInfo.bufferName(),
306                                     arrayStoreOp.resultType(),
307                                     (Class<?>) classTypeToTypeOrThrow(lookup, (ClassType) arrayAccessInfo.buffer().type()),
308                                     context.getValues(operands)
309                             );
310                         } else {
311                             return blockBuilder;
312                         }
313                         context.mapValue(arrayStoreOp.result(), blockBuilder.op(copyLocation(arrayStoreOp, replacementOp)));
314                         return blockBuilder;
315                     }
316                 }
317                 case JavaOp.ArrayLengthOp arrayLengthOp  when
318                         isBufferArray(lookup, arrayLengthOp) && resultFromFirstOperandOrThrow(arrayLengthOp) != null ->{
319                     var arrayAccessInfo = arrayAccessInfo(op.result(), replaced);
320                     var hatPtrLengthOp = new HATPtrOp.HATPtrLengthOp(
321                             arrayAccessInfo.bufferName(),
322                             arrayLengthOp.resultType(),
323                             (Class<?>) OpHelper.classTypeToTypeOrThrow(lookup, (ClassType) arrayAccessInfo.buffer().type()),
324                             context.getValues(List.of(arrayAccessInfo.buffer()))
325                     );
326                     context.mapValue(arrayLengthOp.result(), blockBuilder.op(copyLocation(arrayLengthOp,hatPtrLengthOp)));
327                     return blockBuilder;
328                 }
329                 default -> {
330                 }
331             }
332             blockBuilder.op(op);
333             return blockBuilder;
334         }).funcOp();
335     }
336 
337     record ArrayAccessInfo(Op.Result buffer, String bufferName, List<Op.Result> indices) {
338         public List<Value> bufferAndIndicesAsValues() {
339             List<Value> operands = new ArrayList<>(List.of(buffer));
340             operands.addAll(indices);
341             return operands;
342         }
343     }
344 
345     record Node<T>(T value, List<Node<T>> edges) {
346         ArrayAccessInfo getInfo(Map<Op.Result, Op.Result> replaced) {
347             List<Node<T>> nodeList = new ArrayList<>(List.of(this));
348             Set<Node<T>> handled = new HashSet<>();
349             Op.Result buffer = null;
350             List<Op.Result> indices = new ArrayList<>();
351             while (!nodeList.isEmpty()) {
352                 Node<T> node = nodeList.removeFirst();
353                 handled.add(node);
354                 if (node.value instanceof Op.Result res &&
355                         (res.op() instanceof JavaOp.ArrayAccessOp || res.op() instanceof JavaOp.ArrayLengthOp)) {
356                     buffer = res;
357                     // idx location differs between ArrayAccessOp and ArrayLengthOp
358                     indices.addFirst(res.op() instanceof JavaOp.ArrayAccessOp
359                             ? resultFromOperandN(res.op(), 1)
360                             : resultFromOperandN(res.op(), 0)
361                     );
362                 }
363                 if (!node.edges().isEmpty()) {
364                     Node<T> next = node.edges().getFirst(); // we only traverse through the index-related ops
365                     if (!handled.contains(next)) {
366                         nodeList.add(next);
367                     }
368                 }
369             }
370             if (buffer != null) {
371                 buffer = replaced.get(resultFromFirstOperandOrNull(buffer.op()));
372                 String bufferName = hatPtrName(opFromFirstOperandOrNull(buffer.op()));
373                 return new ArrayAccessInfo(buffer, bufferName, indices);
374             } else {
375                 return null;
376             }
377         }
378     }
379 
380     public static String hatPtrName(Op op) {
381         return switch (op) {
382             case CoreOp.VarOp varOp -> varOp.varName();
383             case VarLikeOp varLikeOp -> varLikeOp.varName();
384             case null, default -> "";
385         };
386     }
387 }