1 /*
  2  * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package hat.phases;
 26 
 27 import hat.dialect.HATPtrOp;
 28 import hat.dialect.HATVectorOp;
 29 import jdk.incubator.code.CodeType;
 30 import optkl.IfaceValue;
 31 import optkl.OpHelper;
 32 import optkl.Trxfmr;
 33 import jdk.incubator.code.Op;
 34 import jdk.incubator.code.Value;
 35 import jdk.incubator.code.dialect.core.CoreOp;
 36 import jdk.incubator.code.dialect.core.CoreType;
 37 import jdk.incubator.code.dialect.java.*;
 38 import optkl.VarTable;
 39 import optkl.codebuilders.BabylonOpDispatcher;
 40 import optkl.util.ops.VarLikeOp;
 41 
 42 import java.lang.invoke.MethodHandles;
 43 import java.util.*;
 44 
 45 import static hat.phases.HATPhaseUtils.findOpInResultFromFirstOperandsOrNull;
 46 import static optkl.IfaceValue.Vector.getVectorShape;
 47 import static optkl.OpHelper.Invoke;
 48 import static optkl.OpHelper.Invoke.invoke;
 49 import static optkl.OpHelper.classTypeToTypeOrThrow;
 50 import static optkl.OpHelper.copyLocation;
 51 import static optkl.OpHelper.firstOperandOrThrow;
 52 import static optkl.OpHelper.opFromFirstOperandOrNull;
 53 import static optkl.OpHelper.opFromFirstOperandOrThrow;
 54 import static optkl.OpHelper.resultFromFirstOperandOrNull;
 55 import static optkl.OpHelper.resultFromFirstOperandOrThrow;
 56 import static optkl.OpHelper.resultFromOperandN;
 57 
 58 public record HATArrayViewPhase() implements HATPhase {
 59     public static boolean isVectorOp(MethodHandles.Lookup lookup, Op op) {
 60         if (!op.operands().isEmpty()) {
 61             CodeType type = switch(op) {
 62                 case JavaOp.ArrayAccessOp.ArrayLoadOp load -> load.resultType();
 63                 case JavaOp.ArrayAccessOp.ArrayStoreOp store -> store.operands().getLast().type();
 64                 default -> OpHelper.firstOperandOrThrow(op).type();
 65             };
 66             if (type instanceof ArrayType at) {
 67                 type = at.componentType();
 68             }
 69             if (type instanceof ClassType ct) {
 70                 try {
 71                     return IfaceValue.Vector.class.isAssignableFrom((Class<?>) ct.resolve(lookup));
 72                 } catch (ReflectiveOperationException e) {
 73                     throw new RuntimeException(e);
 74                 }
 75             }
 76         }
 77         return false;
 78     }
 79 
 80     public static boolean isVectorBinaryOp(MethodHandles.Lookup lookup, OpHelper.Invoke invoke) {
 81         return isVectorOp(lookup, invoke.op()) && invoke.nameMatchesRegex("(add|sub|mul|div)");
 82     }
 83 
 84 
 85     static HATVectorOp.HATVectorBinaryOp buildVectorBinaryOp(String varName, CodeType codeType, String opType, IfaceValue.Vector.Shape vectorShape, List<Value> outputOperands) {
 86         return switch (opType) {
 87             case "add" -> new HATVectorOp.HATVectorBinaryOp.HATVectorAddOp(varName, codeType, vectorShape, outputOperands);
 88             case "sub" -> new HATVectorOp.HATVectorBinaryOp.HATVectorSubOp(varName, codeType, vectorShape, outputOperands);
 89             case "mul" -> new HATVectorOp.HATVectorBinaryOp.HATVectorMulOp(varName, codeType, vectorShape, outputOperands);
 90             case "div" -> new HATVectorOp.HATVectorBinaryOp.HATVectorDivOp(varName, codeType, vectorShape, outputOperands);
 91             default -> throw new IllegalStateException("Unexpected value: " + opType);
 92         };
 93     }
 94 
 95     public static boolean isBufferArray(MethodHandles.Lookup lookup, Op op) {
 96         JavaOp.InvokeOp iop = (JavaOp.InvokeOp) findOpInResultFromFirstOperandsOrNull(op, JavaOp.InvokeOp.class);
 97         return iop != null && iop.invokeReference().name().toLowerCase().contains("arrayview"); // we need a better way
 98     }
 99 
100     public static boolean isBufferInitialize(MethodHandles.Lookup lookup, Op op) {
101         // first check if the return is an array type
102         if (op instanceof CoreOp.VarOp vop && vop.varValueType() instanceof ArrayType
103                 || op instanceof JavaOp.ArrayAccessOp
104                 || op.resultType() instanceof ArrayType) return isBufferArray(lookup, op);
105         return false;
106     }
107 
108     public static boolean isLocalSharedOrPrivate(Op op) {
109         JavaOp.InvokeOp iop = (JavaOp.InvokeOp) findOpInResultFromFirstOperandsOrNull(op, JavaOp.InvokeOp.class);
110         return iop != null
111                 && (iop.invokeReference().name().toLowerCase().contains("shared")
112                 || iop.invokeReference().name().toLowerCase().contains("local")
113                 || iop.invokeReference().name().toLowerCase().contains("private")
114         );
115     }
116 
117     public static HATVectorOp buildArrayViewVector(Op op, String name, CodeType resultType, IfaceValue.Vector.Shape vectorShape, List<Value> operands) {
118         if (isLocalSharedOrPrivate(op)) {
119             if (op instanceof JavaOp.ArrayAccessOp.ArrayLoadOp) {
120                 return new HATVectorOp.HATVectorLoadOp.HATSharedVectorLoadOp(name, resultType, vectorShape, operands);
121             }
122             return new HATVectorOp.HATVectorStoreView.HATSharedVectorStoreView(name, resultType, vectorShape, operands);
123         } else {
124             if (op instanceof JavaOp.ArrayAccessOp.ArrayLoadOp) {
125                 return new HATVectorOp.HATVectorLoadOp.HATPrivateVectorLoadOp(name, resultType, vectorShape, operands);
126             }
127             return new HATVectorOp.HATVectorStoreView.HATPrivateVectorStoreView(name, resultType, vectorShape, operands);
128         }
129     }
130 
131 
132 
133     static HATArrayViewPhase.ArrayAccessInfo arrayAccessInfo(Value value, Map<Op.Result, Op.Result> replaced) {
134         return expressionGraph(value).getInfo(replaced);
135     }
136 
137     static HATArrayViewPhase.Node<Value> expressionGraph(Value value) {
138         return expressionGraph(new HashMap<>(), value);
139     }
140 
141     static HATArrayViewPhase.Node<Value> expressionGraph(Map<Value, HATArrayViewPhase.Node<Value>> visited, Value value) {
142         // If value has already been visited return its node
143         if (visited.containsKey(value)) {
144             return visited.get(value);
145         }
146 
147         // Find the expression graphs for each operand
148         List<HATArrayViewPhase.Node<Value>> edges = new ArrayList<>();
149 
150         // looks like
151         for (Value operand : value.dependsOn()) {
152             if (operand instanceof Op.Result res &&
153                     res.op() instanceof JavaOp.InvokeOp iop
154                     && iop.invokeReference().name().toLowerCase().contains("arrayview")){ // We need to find a better way
155                 continue;
156             }
157             edges.add(expressionGraph(operand));
158         }
159         HATArrayViewPhase.Node<Value> node = new HATArrayViewPhase.Node<>(value, edges);
160         visited.put(value, node);
161         return node;
162     }
163     @Override
164     public CoreOp.FuncOp transform(MethodHandles.Lookup lookup,CoreOp.FuncOp funcOp, VarTable varTable) {
165         if (Invoke.stream(lookup, funcOp).noneMatch(
166                 invoke -> isBufferArray(lookup, invoke.op())
167         )) return funcOp;
168 
169         funcOp = applyArrayView(lookup,funcOp);
170 
171         if (funcOp.elements().filter(e -> e instanceof CoreOp.VarOp).anyMatch(
172                 e -> isVectorOp(lookup, ((CoreOp.VarOp) e))
173         )) funcOp = applyVectorView(lookup,funcOp, varTable);
174         return funcOp;
175     }
176 
177     public CoreOp.FuncOp applyVectorView(MethodHandles.Lookup lookup,CoreOp.FuncOp funcOp, VarTable varTable) {
178         return Trxfmr.of(lookup,funcOp).transform((blockBuilder, op) -> {
179             var context = blockBuilder.context();
180             switch (op) {
181                 case JavaOp.InvokeOp iOp when invoke(lookup, iOp) instanceof Invoke invoke -> {
182                     if (isVectorBinaryOp(invoke.lookup(), invoke)){
183                         var hatVectorBinaryOp = buildVectorBinaryOp(
184                                 invoke.varOpFromFirstUseOrThrow().varName(),
185                                 invoke.returnType(),
186                                 invoke.name(),// so mul, sub etc
187                                 getVectorShape(lookup,invoke.returnType()),
188                                 blockBuilder.context().getValues(invoke.op().operands())
189                         );
190                         context.mapValue(invoke.returnResult(), blockBuilder.add(copyLocation(invoke.op(),hatVectorBinaryOp)));
191                         return blockBuilder;
192                     }
193                 }
194                 case CoreOp.VarOp varOp -> {
195                     if (isVectorOp(lookup,varOp)) {
196                         Op.Result op1 = blockBuilder.add(varOp);
197                         String functionName = funcOp.funcName();
198                         varTable.addIfNeededOrThrow(functionName, op1.op(), VarTable.HATOpAttribute.VECTOR);
199                         return blockBuilder;
200                     }
201                 }
202                 case JavaOp.ArrayAccessOp.ArrayLoadOp arrayLoadOp -> {
203                     if (isVectorOp(lookup,arrayLoadOp)) {
204                         Op.Result buffer = resultFromFirstOperandOrNull(arrayLoadOp);
205                         String name = hatPtrName(opFromFirstOperandOrThrow(buffer.op()));
206                         var resultType = CoreType.varType(arrayLoadOp.resultType());
207                         var vectorShape = getVectorShape(lookup,arrayLoadOp.resultType());
208                         List<Value> operands = context.getValues(List.of(buffer, arrayLoadOp.operands().getLast()));
209                         HATVectorOp vLoadOp = buildArrayViewVector(arrayLoadOp, name, resultType, vectorShape, operands);
210                         context.mapValue(arrayLoadOp.result(), blockBuilder.add(copyLocation(arrayLoadOp,vLoadOp)));
211                     }
212                     return blockBuilder;
213                 }
214                 case JavaOp.ArrayAccessOp.ArrayStoreOp arrayStoreOp -> {
215                     if (isVectorOp(lookup,arrayStoreOp)) {
216                         Op.Result buffer = resultFromFirstOperandOrThrow(arrayStoreOp);
217                         Op varOp = opFromFirstOperandOrNull(((Op.Result) arrayStoreOp.operands().getLast()).op());
218                         String name = hatPtrName(varOp);
219                         var resultType = (varOp).resultType();
220                         var vectorShape = getVectorShape(lookup,arrayStoreOp.operands().getLast().type());
221                         List<Value> operands = context.getValues(List.of(buffer, arrayStoreOp.operands().getLast(), arrayStoreOp.operands().get(1)));
222                         HATVectorOp vStoreOp = buildArrayViewVector(arrayStoreOp, name, resultType, vectorShape, operands);
223                         context.mapValue(arrayStoreOp.result(), blockBuilder.add(copyLocation(arrayStoreOp,vStoreOp)));
224                     }
225                     return blockBuilder;
226                 }
227                 default -> {
228                 }
229             }
230             blockBuilder.add(op);
231             return blockBuilder;
232         }).funcOp();
233     }
234 
235     public CoreOp.FuncOp applyArrayView(MethodHandles.Lookup lookup,CoreOp.FuncOp funcOp) {
236         Map<Op.Result, Op.Result> replaced = new HashMap<>(); // maps a result to the result it should be replaced by
237         Map<Op, CoreOp.VarAccessOp.VarLoadOp> bufferVarLoads = new HashMap<>();
238 
239         return Trxfmr.of(lookup,funcOp).transform((blockBuilder, op) -> {
240             var context = blockBuilder.context();
241             switch (op) {
242                 case JavaOp.InvokeOp invokeOp when invoke(lookup, invokeOp) instanceof Invoke invoke -> {
243                     if (isBufferArray(lookup, invoke.op())) { // ensures we can use iop as key for replaced vvv
244                         Op.Result result = invoke.resultFromFirstOperandOrNull();
245                         replaced.put(invoke.returnResult(), result);
246                         // map buffer VarOp to its corresponding VarLoadOp
247                         bufferVarLoads.put((opFromFirstOperandOrNull(result.op())), (CoreOp.VarAccessOp.VarLoadOp) result.op());
248                         return blockBuilder;
249                     }
250                 }
251                 case CoreOp.VarOp varOp -> {
252                     if (isBufferInitialize(lookup, varOp)) {
253                         // makes sure we don't process a new int[] for example
254                         Op bufferLoad = replaced.get(resultFromFirstOperandOrThrow(varOp)).op(); // gets VarLoadOp associated w/ og buffer
255                         replaced.put(varOp.result(), resultFromFirstOperandOrNull(bufferLoad)); // gets VarOp associated w/ og buffer
256                         return blockBuilder;
257                     }
258                 }
259                 case CoreOp.VarAccessOp.VarLoadOp varLoadOp -> {
260                     if ((isBufferInitialize(lookup, varLoadOp))) {
261                         Op.Result r = resultFromFirstOperandOrThrow(varLoadOp);
262                         Op.Result replacement;
263                         if (r.op() instanceof CoreOp.VarOp) { // if this is the VarLoadOp after the .arrayView() InvokeOp
264                             replacement = (isLocalSharedOrPrivate(varLoadOp)) ?
265                                     resultFromFirstOperandOrNull(opFromFirstOperandOrThrow(r.op())) :
266                                     bufferVarLoads.get(replaced.get(r).op()).result();
267                         } else { // if this is a VarLoadOp loading the buffer
268                             CoreOp.VarAccessOp.VarLoadOp newVarLoad = CoreOp.VarAccessOp.varLoad(blockBuilder.context().getValue(replaced.get(r)));
269                             replacement = blockBuilder.add(copyLocation(varLoadOp,newVarLoad));
270                             context.mapValue(varLoadOp.result(), replacement);
271                         }
272                         replaced.put(varLoadOp.result(), replacement);
273                         return blockBuilder;
274                     }
275                 }
276                 case JavaOp.ArrayAccessOp.ArrayLoadOp arrayLoadOp -> {
277                     if (isBufferArray(lookup, arrayLoadOp)) {
278                         Op replacementOp=null;
279                         if (isVectorOp(lookup,arrayLoadOp)) {
280                             replacementOp = JavaOp.arrayLoadOp(
281                                     context.getValue(replaced.get((Op.Result) arrayLoadOp.operands().getFirst())),
282                                     context.getValue(arrayLoadOp.operands().getLast()),
283                                     arrayLoadOp.resultType()
284                             );
285                         } else if (((ArrayType) firstOperandOrThrow(op).type()).dimensions() == 1) {
286                             var arrayAccessInfo = arrayAccessInfo(op.result(), replaced);
287                             var operands = arrayAccessInfo.bufferAndIndicesAsValues();
288                             replacementOp = new HATPtrOp.HATPtrLoadOp(
289                                     arrayAccessInfo.bufferName(),
290                                     arrayLoadOp.resultType(),
291                                     (Class<?>) classTypeToTypeOrThrow(lookup, (ClassType) arrayAccessInfo.buffer().type()),
292                                     context.getValues(operands)
293                             );
294                         } else { // we only use the last array load
295                             return blockBuilder;
296                         }
297                         context.mapValue(arrayLoadOp.result(), blockBuilder.add(copyLocation(arrayLoadOp,replacementOp)));
298                         return blockBuilder;
299                     }
300                 }
301                 case JavaOp.ArrayAccessOp.ArrayStoreOp arrayStoreOp -> {
302                     if (isBufferArray(lookup, arrayStoreOp)) {
303                         Op replacementOp;
304                         if (isVectorOp(lookup, arrayStoreOp)) {
305                             replacementOp = JavaOp.arrayStoreOp(
306                                     context.getValue(replaced.get((Op.Result) arrayStoreOp.operands().getFirst())),
307                                     context.getValue(arrayStoreOp.operands().get(1)),
308                                     context.getValue(arrayStoreOp.operands().getLast())
309                             );
310                         } else if (((ArrayType) firstOperandOrThrow(op).type()).dimensions() == 1) { // we only use the last array load
311                             var arrayAccessInfo = arrayAccessInfo(op.result(), replaced);
312                             var operands = arrayAccessInfo.bufferAndIndicesAsValues();
313                             operands.add(arrayStoreOp.operands().getLast());
314                             replacementOp = new HATPtrOp.HATPtrStoreOp(
315                                     arrayAccessInfo.bufferName(),
316                                     arrayStoreOp.resultType(),
317                                     (Class<?>) classTypeToTypeOrThrow(lookup, (ClassType) arrayAccessInfo.buffer().type()),
318                                     context.getValues(operands)
319                             );
320                         } else {
321                             return blockBuilder;
322                         }
323                         context.mapValue(arrayStoreOp.result(), blockBuilder.add(copyLocation(arrayStoreOp, replacementOp)));
324                         return blockBuilder;
325                     }
326                 }
327                 case JavaOp.ArrayLengthOp arrayLengthOp  when
328                         isBufferArray(lookup, arrayLengthOp) && resultFromFirstOperandOrThrow(arrayLengthOp) != null ->{
329                     var arrayAccessInfo = arrayAccessInfo(op.result(), replaced);
330                     var hatPtrLengthOp = new HATPtrOp.HATPtrLengthOp(
331                             arrayAccessInfo.bufferName(),
332                             arrayLengthOp.resultType(),
333                             (Class<?>) OpHelper.classTypeToTypeOrThrow(lookup, (ClassType) arrayAccessInfo.buffer().type()),
334                             context.getValues(List.of(arrayAccessInfo.buffer()))
335                     );
336                     context.mapValue(arrayLengthOp.result(), blockBuilder.add(copyLocation(arrayLengthOp,hatPtrLengthOp)));
337                     return blockBuilder;
338                 }
339                 default -> {
340                 }
341             }
342             blockBuilder.add(op);
343             return blockBuilder;
344         }).funcOp();
345     }
346 
347     record ArrayAccessInfo(Op.Result buffer, String bufferName, List<Op.Result> indices) {
348         public List<Value> bufferAndIndicesAsValues() {
349             List<Value> operands = new ArrayList<>(List.of(buffer));
350             operands.addAll(indices);
351             return operands;
352         }
353     }
354 
355     record Node<T>(T value, List<Node<T>> edges) {
356         ArrayAccessInfo getInfo(Map<Op.Result, Op.Result> replaced) {
357             List<Node<T>> nodeList = new ArrayList<>(List.of(this));
358             Set<Node<T>> handled = new HashSet<>();
359             Op.Result buffer = null;
360             List<Op.Result> indices = new ArrayList<>();
361             while (!nodeList.isEmpty()) {
362                 Node<T> node = nodeList.removeFirst();
363                 handled.add(node);
364                 if (node.value instanceof Op.Result res &&
365                         (res.op() instanceof JavaOp.ArrayAccessOp || res.op() instanceof JavaOp.ArrayLengthOp)) {
366                     buffer = res;
367                     // idx location differs between ArrayAccessOp and ArrayLengthOp
368                     indices.addFirst(res.op() instanceof JavaOp.ArrayAccessOp
369                             ? resultFromOperandN(res.op(), 1)
370                             : resultFromOperandN(res.op(), 0)
371                     );
372                 }
373                 if (!node.edges().isEmpty()) {
374                     Node<T> next = node.edges().getFirst(); // we only traverse through the index-related ops
375                     if (!handled.contains(next)) {
376                         nodeList.add(next);
377                     }
378                 }
379             }
380             if (buffer != null) {
381                 buffer = replaced.get(resultFromFirstOperandOrNull(buffer.op()));
382                 String bufferName = hatPtrName(opFromFirstOperandOrNull(buffer.op()));
383                 return new ArrayAccessInfo(buffer, bufferName, indices);
384             } else {
385                 return null;
386             }
387         }
388     }
389 
390     public static String hatPtrName(Op op) {
391         return switch (op) {
392             case CoreOp.VarOp varOp -> varOp.varName();
393             case VarLikeOp varLikeOp -> varLikeOp.varName();
394             case null, default -> "";
395         };
396     }
397 }