1 /*
  2  * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package hat.phases;
 26 
 27 import hat.dialect.*;
 28 import jdk.incubator.code.CodeType;
 29 import optkl.IfaceValue;
 30 import optkl.OpHelper;
 31 import optkl.Trxfmr;
 32 import jdk.incubator.code.Op;
 33 import jdk.incubator.code.Value;
 34 import jdk.incubator.code.dialect.core.CoreOp;
 35 import jdk.incubator.code.dialect.core.CoreType;
 36 import jdk.incubator.code.dialect.java.*;
 37 import optkl.util.ops.VarLikeOp;
 38 
 39 import java.lang.invoke.MethodHandles;
 40 import java.util.*;
 41 
 42 import static hat.phases.HATPhaseUtils.findOpInResultFromFirstOperandsOrNull;
 43 import static optkl.IfaceValue.Vector.getVectorShape;
 44 import static optkl.OpHelper.*;
 45 import static optkl.OpHelper.Invoke.invoke;
 46 
 47 public record HATArrayViewPhase() implements HATPhase {
 48     public static boolean isVectorOp(MethodHandles.Lookup lookup, Op op) {
 49         if (!op.operands().isEmpty()) {
 50             CodeType type = switch(op) {
 51                 case JavaOp.ArrayAccessOp.ArrayLoadOp load -> load.resultType();
 52                 case JavaOp.ArrayAccessOp.ArrayStoreOp store -> store.operands().getLast().type();
 53                 default -> OpHelper.firstOperandOrThrow(op).type();
 54             };
 55             if (type instanceof ArrayType at) {
 56                 type = at.componentType();
 57             }
 58             if (type instanceof ClassType ct) {
 59                 try {
 60                     return IfaceValue.Vector.class.isAssignableFrom((Class<?>) ct.resolve(lookup));
 61                 } catch (ReflectiveOperationException e) {
 62                     throw new RuntimeException(e);
 63                 }
 64             }
 65         }
 66         return false;
 67     }
 68 
 69     public static boolean isVectorBinaryOp(MethodHandles.Lookup lookup, OpHelper.Invoke invoke) {
 70         return isVectorOp(lookup, invoke.op()) && invoke.nameMatchesRegex("(add|sub|mul|div)");
 71     }
 72 
 73 
 74     static HATVectorOp.HATVectorBinaryOp buildVectorBinaryOp(String varName, String opType, IfaceValue.Vector.Shape vectorShape, List<Value> outputOperands) {
 75         return switch (opType) {
 76             case "add" -> new HATVectorOp.HATVectorBinaryOp.HATVectorAddOp(varName,  vectorShape, outputOperands);
 77             case "sub" -> new HATVectorOp.HATVectorBinaryOp.HATVectorSubOp(varName,  vectorShape, outputOperands);
 78             case "mul" -> new HATVectorOp.HATVectorBinaryOp.HATVectorMulOp(varName,  vectorShape, outputOperands);
 79             case "div" -> new HATVectorOp.HATVectorBinaryOp.HATVectorDivOp(varName,  vectorShape, outputOperands);
 80             default -> throw new IllegalStateException("Unexpected value: " + opType);
 81         };
 82     }
 83 
 84     public static boolean isBufferArray(MethodHandles.Lookup lookup, Op op) {
 85         JavaOp.InvokeOp iop = (JavaOp.InvokeOp) findOpInResultFromFirstOperandsOrNull(op, JavaOp.InvokeOp.class);
 86         return iop != null && iop.invokeReference().name().toLowerCase().contains("arrayview"); // we need a better way
 87     }
 88 
 89     public static boolean isBufferInitialize(MethodHandles.Lookup lookup, Op op) {
 90         // first check if the return is an array type
 91         if (op instanceof CoreOp.VarOp vop && vop.varValueType() instanceof ArrayType
 92                 || op instanceof JavaOp.ArrayAccessOp
 93                 || op.resultType() instanceof ArrayType) return isBufferArray(lookup, op);
 94         return false;
 95     }
 96 
 97     public static boolean isLocalSharedOrPrivate(Op op) {
 98         JavaOp.InvokeOp iop = (JavaOp.InvokeOp) findOpInResultFromFirstOperandsOrNull(op, JavaOp.InvokeOp.class);
 99         return iop != null
100                 && (iop.invokeReference().name().toLowerCase().contains("shared")
101                 || iop.invokeReference().name().toLowerCase().contains("local")
102                 || iop.invokeReference().name().toLowerCase().contains("private")
103         );
104     }
105 
106     public static HATVectorOp buildArrayViewVector(Op op, String name, CodeType resultType, IfaceValue.Vector.Shape vectorShape, List<Value> operands) {
107         if (isLocalSharedOrPrivate(op)) {
108             if (op instanceof JavaOp.ArrayAccessOp.ArrayLoadOp) {
109                 return new HATVectorOp.HATVectorLoadOp.HATSharedVectorLoadOp(name, resultType, vectorShape, operands);
110             }
111             return new HATVectorOp.HATVectorStoreView.HATSharedVectorStoreView(name, resultType, vectorShape, operands);
112         } else {
113             if (op instanceof JavaOp.ArrayAccessOp.ArrayLoadOp) {
114                 return new HATVectorOp.HATVectorLoadOp.HATPrivateVectorLoadOp(name, resultType, vectorShape, operands);
115             }
116             return new HATVectorOp.HATVectorStoreView.HATPrivateVectorStoreView(name, resultType, vectorShape, operands);
117         }
118     }
119 
120 
121 
122     static HATArrayViewPhase.ArrayAccessInfo arrayAccessInfo(Value value, Map<Op.Result, Op.Result> replaced) {
123         return expressionGraph(value).getInfo(replaced);
124     }
125 
126     static HATArrayViewPhase.Node<Value> expressionGraph(Value value) {
127         return expressionGraph(new HashMap<>(), value);
128     }
129 
130     static HATArrayViewPhase.Node<Value> expressionGraph(Map<Value, HATArrayViewPhase.Node<Value>> visited, Value value) {
131         // If value has already been visited return its node
132         if (visited.containsKey(value)) {
133             return visited.get(value);
134         }
135 
136         // Find the expression graphs for each operand
137         List<HATArrayViewPhase.Node<Value>> edges = new ArrayList<>();
138 
139         // looks like
140         for (Value operand : value.dependsOn()) {
141             if (operand instanceof Op.Result res &&
142                     res.op() instanceof JavaOp.InvokeOp iop
143                     && iop.invokeReference().name().toLowerCase().contains("arrayview")){ // We need to find a better way
144                 continue;
145             }
146             edges.add(expressionGraph(operand));
147         }
148         HATArrayViewPhase.Node<Value> node = new HATArrayViewPhase.Node<>(value, edges);
149         visited.put(value, node);
150         return node;
151     }
152     @Override
153     public CoreOp.FuncOp transform(MethodHandles.Lookup lookup,CoreOp.FuncOp funcOp) {
154         if (Invoke.stream(lookup, funcOp).noneMatch(
155                 invoke -> isBufferArray(lookup, invoke.op())
156         )) return funcOp;
157 
158         funcOp = applyArrayView(lookup,funcOp);
159 
160         if (funcOp.elements().filter(e -> e instanceof CoreOp.VarOp).anyMatch(
161                 e -> isVectorOp(lookup, ((CoreOp.VarOp) e))
162         )) funcOp = applyVectorView(lookup,funcOp);
163         return funcOp;
164     }
165 
166     public CoreOp.FuncOp applyVectorView(MethodHandles.Lookup lookup,CoreOp.FuncOp funcOp) {
167         return Trxfmr.of(lookup,funcOp).transform((blockBuilder, op) -> {
168             var context = blockBuilder.context();
169             switch (op) {
170                 case JavaOp.InvokeOp iOp when invoke(lookup, iOp) instanceof Invoke invoke -> {
171                     if (isVectorBinaryOp(invoke.lookup(), invoke)){
172                         var hatVectorBinaryOp = buildVectorBinaryOp(
173                                 invoke.varOpFromFirstUseOrThrow().varName(),
174                                 invoke.name(),// so mul, sub etc
175                                 getVectorShape(lookup,invoke.returnType()),
176                                 blockBuilder.context().getValues(invoke.op().operands())
177                         );
178                         context.mapValue(invoke.returnResult(), blockBuilder.op(copyLocation(invoke.op(),hatVectorBinaryOp)));
179                         return blockBuilder;
180                     }
181                 }
182                 case CoreOp.VarOp varOp -> {
183                     if (isVectorOp(lookup,varOp)) {
184                         var hatVectorVarOp = new HATVectorOp.HATVectorVarOp(
185                                 varOp.varName(),
186                                 varOp.resultType(),
187                                 getVectorShape(lookup,varOp.resultType().valueType()),
188                                 context.getValues(OpHelper.firstOperandAsListOrEmpty(varOp))
189                         );
190                         context.mapValue(varOp.result(), blockBuilder.op(copyLocation(varOp,hatVectorVarOp)));
191                         return blockBuilder;
192                     }
193                 }
194                 case JavaOp.ArrayAccessOp.ArrayLoadOp arrayLoadOp -> {
195                     if (isVectorOp(lookup,arrayLoadOp)) {
196                         Op.Result buffer = resultFromFirstOperandOrNull(arrayLoadOp);
197                         String name = hatPtrName(opFromFirstOperandOrThrow(buffer.op()));
198                         var resultType = CoreType.varType(arrayLoadOp.resultType());
199                         var vectorShape = getVectorShape(lookup,arrayLoadOp.resultType());
200                         List<Value> operands = context.getValues(List.of(buffer, arrayLoadOp.operands().getLast()));
201                         HATVectorOp vLoadOp = buildArrayViewVector(arrayLoadOp, name, resultType, vectorShape, operands);
202                         context.mapValue(arrayLoadOp.result(), blockBuilder.op(copyLocation(arrayLoadOp,vLoadOp)));
203                     }
204                     return blockBuilder;
205                 }
206                 case JavaOp.ArrayAccessOp.ArrayStoreOp arrayStoreOp -> {
207                     if (isVectorOp(lookup,arrayStoreOp)) {
208                         Op.Result buffer = resultFromFirstOperandOrThrow(arrayStoreOp);
209                         Op varOp = opFromFirstOperandOrNull(((Op.Result) arrayStoreOp.operands().getLast()).op());
210                         String name = hatPtrName(varOp);
211                         var resultType = (varOp).resultType();
212                         var vectorShape = getVectorShape(lookup,arrayStoreOp.operands().getLast().type());
213                         List<Value> operands = context.getValues(List.of(buffer, arrayStoreOp.operands().getLast(), arrayStoreOp.operands().get(1)));
214                         HATVectorOp vStoreOp = buildArrayViewVector(arrayStoreOp, name, resultType, vectorShape, operands);
215                         context.mapValue(arrayStoreOp.result(), blockBuilder.op(copyLocation(arrayStoreOp,vStoreOp)));
216                     }
217                     return blockBuilder;
218                 }
219                 default -> {
220                 }
221             }
222             blockBuilder.op(op);
223             return blockBuilder;
224         }).funcOp();
225     }
226 
    /**
     * Rewrites accesses that go through an array view into HAT pointer ops.
     *
     * <p>The transform keeps two maps that are filled while the ops are visited, so
     * the outcome depends on the order in which ops are encountered:
     * {@code replaced} records, for an original result, the result that stands in
     * for it, and {@code bufferVarLoads} records the {@code VarLoadOp} that feeds
     * each array-view invoke.
     *
     * <p>Ops accepted by {@link #isBufferArray} / {@link #isBufferInitialize} are
     * either dropped (only bookkeeping is done) or replaced; every other op is
     * copied unchanged.
     *
     * @param lookup lookup used for type resolution
     * @param funcOp the function to rewrite
     * @return the rewritten function
     */
    public CoreOp.FuncOp applyArrayView(MethodHandles.Lookup lookup,CoreOp.FuncOp funcOp) {
        Map<Op.Result, Op.Result> replaced = new HashMap<>(); // maps a result to the result it should be replaced by
        // keyed by the op behind the first operand of the VarLoadOp that feeds an array-view invoke
        Map<Op, CoreOp.VarAccessOp.VarLoadOp> bufferVarLoads = new HashMap<>();

        return Trxfmr.of(lookup,funcOp).transform((blockBuilder, op) -> {
            var context = blockBuilder.context();
            switch (op) {
                // The array-view invoke itself is not emitted; its result is redirected to its first operand.
                case JavaOp.InvokeOp invokeOp when invoke(lookup, invokeOp) instanceof Invoke invoke -> {
                    if (isBufferArray(lookup, invoke.op())) { // ensures we can use iop as key for replaced vvv
                        Op.Result result = invoke.resultFromFirstOperandOrNull();
                        replaced.put(invoke.returnResult(), result);
                        // map buffer VarOp to its corresponding VarLoadOp
                        // (the cast assumes the first operand of the invoke is produced by a VarLoadOp)
                        bufferVarLoads.put((opFromFirstOperandOrNull(result.op())), (CoreOp.VarAccessOp.VarLoadOp) result.op());
                        return blockBuilder;
                    }
                }
                // The variable holding the array view is not emitted either.
                case CoreOp.VarOp varOp -> {
                    if (isBufferInitialize(lookup, varOp)) {
                        // makes sure we don't process a new int[] for example
                        Op bufferLoad = replaced.get(resultFromFirstOperandOrThrow(varOp)).op(); // gets VarLoadOp associated w/ og buffer
                        replaced.put(varOp.result(), resultFromFirstOperandOrNull(bufferLoad)); // gets VarOp associated w/ og buffer
                        return blockBuilder;
                    }
                }
                case CoreOp.VarAccessOp.VarLoadOp varLoadOp -> {
                    if ((isBufferInitialize(lookup, varLoadOp))) {
                        Op.Result r = resultFromFirstOperandOrThrow(varLoadOp);
                        Op.Result replacement;
                        if (r.op() instanceof CoreOp.VarOp) { // if this is the VarLoadOp after the .arrayView() InvokeOp
                            // nothing is emitted on this path; the replacement is only recorded in `replaced`
                            replacement = (isLocalSharedOrPrivate(varLoadOp)) ?
                                    resultFromFirstOperandOrNull(opFromFirstOperandOrThrow(r.op())) :
                                    bufferVarLoads.get(replaced.get(r).op()).result();
                        } else { // if this is a VarLoadOp loading the buffer
                            // emit a fresh load of the replacement value and map the old result onto it
                            CoreOp.VarAccessOp.VarLoadOp newVarLoad = CoreOp.VarAccessOp.varLoad(blockBuilder.context().getValue(replaced.get(r)));
                            replacement = blockBuilder.op(copyLocation(varLoadOp,newVarLoad));
                            context.mapValue(varLoadOp.result(), replacement);
                        }
                        replaced.put(varLoadOp.result(), replacement);
                        return blockBuilder;
                    }
                }
                case JavaOp.ArrayAccessOp.ArrayLoadOp arrayLoadOp -> {
                    if (isBufferArray(lookup, arrayLoadOp)) {
                        Op replacementOp=null;
                        if (isVectorOp(lookup,arrayLoadOp)) {
                            // vector loads stay array loads, re-targeted at the replacement of the array operand
                            replacementOp = JavaOp.arrayLoadOp(
                                    context.getValue(replaced.get((Op.Result) arrayLoadOp.operands().getFirst())),
                                    context.getValue(arrayLoadOp.operands().getLast()),
                                    arrayLoadOp.resultType()
                            );
                        } else if (((ArrayType) firstOperandOrThrow(op).type()).dimensions() == 1) {
                            // load from a one-dimensional array: becomes a pointer load carrying buffer and indices
                            var arrayAccessInfo = arrayAccessInfo(op.result(), replaced);
                            var operands = arrayAccessInfo.bufferAndIndicesAsValues();
                            replacementOp = new HATPtrOp.HATPtrLoadOp(
                                    arrayAccessInfo.bufferName(),
                                    arrayLoadOp.resultType(),
                                    (Class<?>) classTypeToTypeOrThrow(lookup, (ClassType) arrayAccessInfo.buffer().type()),
                                    context.getValues(operands)
                            );
                        } else { // we only use the last array load
                            // loads from arrays with more than one dimension are dropped
                            return blockBuilder;
                        }
                        context.mapValue(arrayLoadOp.result(), blockBuilder.op(copyLocation(arrayLoadOp,replacementOp)));
                        return blockBuilder;
                    }
                }
                case JavaOp.ArrayAccessOp.ArrayStoreOp arrayStoreOp -> {
                    if (isBufferArray(lookup, arrayStoreOp)) {
                        Op replacementOp;
                        if (isVectorOp(lookup, arrayStoreOp)) {
                            // vector stores stay array stores, re-targeted at the replacement of the array operand
                            replacementOp = JavaOp.arrayStoreOp(
                                    context.getValue(replaced.get((Op.Result) arrayStoreOp.operands().getFirst())),
                                    context.getValue(arrayStoreOp.operands().get(1)),
                                    context.getValue(arrayStoreOp.operands().getLast())
                            );
                        } else if (((ArrayType) firstOperandOrThrow(op).type()).dimensions() == 1) {
                            // store into a one-dimensional array: becomes a pointer store carrying
                            // buffer, indices and, last, the stored value
                            var arrayAccessInfo = arrayAccessInfo(op.result(), replaced);
                            var operands = arrayAccessInfo.bufferAndIndicesAsValues();
                            operands.add(arrayStoreOp.operands().getLast());
                            replacementOp = new HATPtrOp.HATPtrStoreOp(
                                    arrayAccessInfo.bufferName(),
                                    arrayStoreOp.resultType(),
                                    (Class<?>) classTypeToTypeOrThrow(lookup, (ClassType) arrayAccessInfo.buffer().type()),
                                    context.getValues(operands)
                            );
                        } else {
                            // stores into arrays with more than one dimension are dropped
                            return blockBuilder;
                        }
                        context.mapValue(arrayStoreOp.result(), blockBuilder.op(copyLocation(arrayStoreOp, replacementOp)));
                        return blockBuilder;
                    }
                }
                // array.length on an array view becomes a pointer length op on the buffer
                case JavaOp.ArrayLengthOp arrayLengthOp  when
                        isBufferArray(lookup, arrayLengthOp) && resultFromFirstOperandOrThrow(arrayLengthOp) != null ->{
                    var arrayAccessInfo = arrayAccessInfo(op.result(), replaced);
                    var hatPtrLengthOp = new HATPtrOp.HATPtrLengthOp(
                            arrayAccessInfo.bufferName(),
                            arrayLengthOp.resultType(),
                            (Class<?>) OpHelper.classTypeToTypeOrThrow(lookup, (ClassType) arrayAccessInfo.buffer().type()),
                            context.getValues(List.of(arrayAccessInfo.buffer()))
                    );
                    context.mapValue(arrayLengthOp.result(), blockBuilder.op(copyLocation(arrayLengthOp,hatPtrLengthOp)));
                    return blockBuilder;
                }
                default -> {
                }
            }
            // ops that were not handled above are copied unchanged
            blockBuilder.op(op);
            return blockBuilder;
        }).funcOp();
    }
338 
339     record ArrayAccessInfo(Op.Result buffer, String bufferName, List<Op.Result> indices) {
340         public List<Value> bufferAndIndicesAsValues() {
341             List<Value> operands = new ArrayList<>(List.of(buffer));
342             operands.addAll(indices);
343             return operands;
344         }
345     }
346 
347     record Node<T>(T value, List<Node<T>> edges) {
348         ArrayAccessInfo getInfo(Map<Op.Result, Op.Result> replaced) {
349             List<Node<T>> nodeList = new ArrayList<>(List.of(this));
350             Set<Node<T>> handled = new HashSet<>();
351             Op.Result buffer = null;
352             List<Op.Result> indices = new ArrayList<>();
353             while (!nodeList.isEmpty()) {
354                 Node<T> node = nodeList.removeFirst();
355                 handled.add(node);
356                 if (node.value instanceof Op.Result res &&
357                         (res.op() instanceof JavaOp.ArrayAccessOp || res.op() instanceof JavaOp.ArrayLengthOp)) {
358                     buffer = res;
359                     // idx location differs between ArrayAccessOp and ArrayLengthOp
360                     indices.addFirst(res.op() instanceof JavaOp.ArrayAccessOp
361                             ? resultFromOperandN(res.op(), 1)
362                             : resultFromOperandN(res.op(), 0)
363                     );
364                 }
365                 if (!node.edges().isEmpty()) {
366                     Node<T> next = node.edges().getFirst(); // we only traverse through the index-related ops
367                     if (!handled.contains(next)) {
368                         nodeList.add(next);
369                     }
370                 }
371             }
372             if (buffer != null) {
373                 buffer = replaced.get(resultFromFirstOperandOrNull(buffer.op()));
374                 String bufferName = hatPtrName(opFromFirstOperandOrNull(buffer.op()));
375                 return new ArrayAccessInfo(buffer, bufferName, indices);
376             } else {
377                 return null;
378             }
379         }
380     }
381 
382     public static String hatPtrName(Op op) {
383         return switch (op) {
384             case CoreOp.VarOp varOp -> varOp.varName();
385             case VarLikeOp varLikeOp -> varLikeOp.varName();
386             case null, default -> "";
387         };
388     }
389 }