1 /*
  2  * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package hat.phases;
 26 
import hat.dialect.HATPtrOp;
import jdk.incubator.code.CodeType;
import jdk.incubator.code.Op;
import jdk.incubator.code.Value;
import jdk.incubator.code.dialect.core.CoreOp;
import jdk.incubator.code.dialect.java.ArrayType;
import jdk.incubator.code.dialect.java.ClassType;
import jdk.incubator.code.dialect.java.JavaOp;
import optkl.IfaceValue;
import optkl.OpHelper;
import optkl.Trxfmr;
import optkl.VarTable;
import optkl.util.ops.VarLikeOp;

import java.lang.invoke.MethodHandles;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;

import static hat.phases.HATPhaseUtils.findOpInResultFromFirstOperandsOrNull;
import static optkl.OpHelper.Invoke;
import static optkl.OpHelper.Invoke.invoke;
import static optkl.OpHelper.classTypeToTypeOrThrow;
import static optkl.OpHelper.copyLocation;
import static optkl.OpHelper.firstOperandOrThrow;
import static optkl.OpHelper.opFromFirstOperandOrNull;
import static optkl.OpHelper.opFromFirstOperandOrThrow;
import static optkl.OpHelper.resultFromFirstOperandOrNull;
import static optkl.OpHelper.resultFromFirstOperandOrThrow;
import static optkl.OpHelper.resultFromOperandN;
 60 
 61 public record HATArrayViewPhase() implements HATPhase {
 62     public static boolean isVectorOp(MethodHandles.Lookup lookup, Op op) {
 63         if (!op.operands().isEmpty()) {
 64             CodeType type = switch (op) {
 65                 case JavaOp.ArrayAccessOp.ArrayLoadOp load -> load.resultType();
 66                 case JavaOp.ArrayAccessOp.ArrayStoreOp store -> store.operands().getLast().type();
 67                 default -> OpHelper.firstOperandOrThrow(op).type();
 68             };
 69             if (type instanceof ArrayType at) {
 70                 type = at.componentType();
 71             }
 72             if (type instanceof ClassType ct) {
 73                 try {
 74                     return IfaceValue.Vector.class.isAssignableFrom((Class<?>) ct.resolve(lookup));
 75                 } catch (ReflectiveOperationException e) {
 76                     throw new IllegalStateException(e);
 77                 }
 78             }
 79         }
 80         return false;
 81     }
 82 
 83     public static boolean isVectorBinaryOp(MethodHandles.Lookup lookup, OpHelper.Invoke invoke) {
 84         return isVectorOp(lookup, invoke.op()) && invoke.nameMatchesRegex("(add|sub|mul|div)");
 85     }
 86 
 87     public static boolean isBufferArray(Op op) {
 88         JavaOp.InvokeOp iop = (JavaOp.InvokeOp) findOpInResultFromFirstOperandsOrNull(op, JavaOp.InvokeOp.class);
 89         return iop != null && iop.invokeReference().name().toLowerCase().contains("arrayview"); // we need a better way
 90     }
 91 
 92     public static boolean isBufferInitialize(Op op) {
 93         // first check if the return is an array type
 94         if (op instanceof CoreOp.VarOp vop && vop.varValueType() instanceof ArrayType
 95                 || op instanceof JavaOp.ArrayAccessOp
 96                 || op.resultType() instanceof ArrayType) return isBufferArray(op);
 97         return false;
 98     }
 99 
100     public static boolean isLocalSharedOrPrivate(Op op) {
101         JavaOp.InvokeOp iop = (JavaOp.InvokeOp) findOpInResultFromFirstOperandsOrNull(op, JavaOp.InvokeOp.class);
102         return iop != null
103                 && (iop.invokeReference().name().toLowerCase().contains("shared")
104                 || iop.invokeReference().name().toLowerCase().contains("local")
105                 || iop.invokeReference().name().toLowerCase().contains("private")
106         );
107     }
108 
109     static HATArrayViewPhase.ArrayAccessInfo arrayAccessInfo(Value value, Map<Op.Result, Op.Result> replaced) {
110         return expressionGraph(value).getInfo(replaced);
111     }
112 
113     static HATArrayViewPhase.Node<Value> expressionGraph(Value value) {
114         return expressionGraph(new HashMap<>(), value);
115     }
116 
117     static HATArrayViewPhase.Node<Value> expressionGraph(Map<Value, HATArrayViewPhase.Node<Value>> visited, Value value) {
118         // If value has already been visited return its node
119         if (visited.containsKey(value)) {
120             return visited.get(value);
121         }
122 
123         // Find the expression graphs for each operand
124         List<HATArrayViewPhase.Node<Value>> edges = new ArrayList<>();
125 
126         // looks like
127         for (Value operand : value.dependsOn()) {
128             if (operand instanceof Op.Result res &&
129                     res.op() instanceof JavaOp.InvokeOp iop
130                     && iop.invokeReference().name().toLowerCase().contains("arrayview")) { // We need to find a better way
131                 continue;
132             }
133             edges.add(expressionGraph(operand));
134         }
135         HATArrayViewPhase.Node<Value> node = new HATArrayViewPhase.Node<>(value, edges);
136         visited.put(value, node);
137         return node;
138     }
139 
140     @Override
141     public CoreOp.FuncOp transform(MethodHandles.Lookup lookup, CoreOp.FuncOp funcOp, VarTable varTable) {
142         if (Invoke.stream(lookup, funcOp).noneMatch(
143                 invoke -> isBufferArray(invoke.op())
144         )) return funcOp;
145 
146         funcOp = applyArrayView(lookup, funcOp);
147 
148         if (funcOp.elements().filter(e -> e instanceof CoreOp.VarOp).anyMatch(
149                 e -> isVectorOp(lookup, ((CoreOp.VarOp) e))
150         )) funcOp = applyVectorView(lookup, funcOp, varTable);
151         return funcOp;
152     }
153 
154     public CoreOp.FuncOp applyVectorView(MethodHandles.Lookup lookup, CoreOp.FuncOp funcOp, VarTable varTable) {
155         return Trxfmr.of(lookup, funcOp).transform((blockBuilder, op) -> {
156             switch (op) {
157                 case JavaOp.InvokeOp iOp when invoke(lookup, iOp) instanceof Invoke invoke && isVectorBinaryOp(invoke.lookup(), invoke) ->
158                         blockBuilder.add(op);
159                 case CoreOp.VarOp varOp when isVectorOp(lookup, varOp) -> {
160                     Op.Result op1 = blockBuilder.add(varOp);
161                     String functionName = funcOp.funcName();
162                     varTable.addIfNeededOrThrow(functionName, op1.op(), VarTable.HATOpAttribute.VECTOR);
163                     return blockBuilder;
164                 }
165                 case JavaOp.ArrayAccessOp.ArrayLoadOp arrayLoadOp -> {
166                     if (isVectorOp(lookup, arrayLoadOp)) {
167                         blockBuilder.add(op);
168                     }
169                     return blockBuilder;
170                 }
171                 case JavaOp.ArrayAccessOp.ArrayStoreOp arrayStoreOp -> {
172                     if (isVectorOp(lookup, arrayStoreOp)) {
173                         blockBuilder.add(op);
174                     }
175                     return blockBuilder;
176                 }
177                 default -> {
178                 }
179             }
180             blockBuilder.add(op);
181             return blockBuilder;
182         }).funcOp();
183     }
184 
185     public CoreOp.FuncOp applyArrayView(MethodHandles.Lookup lookup, CoreOp.FuncOp funcOp) {
186         Map<Op.Result, Op.Result> replaced = new HashMap<>(); // maps a result to the result it should be replaced by
187         Map<Op, CoreOp.VarAccessOp.VarLoadOp> bufferVarLoads = new HashMap<>();
188 
189         return Trxfmr.of(lookup, funcOp).transform((blockBuilder, op) -> {
190             var context = blockBuilder.context();
191             switch (op) {
192                 case JavaOp.InvokeOp invokeOp when invoke(lookup, invokeOp) instanceof Invoke invoke && isBufferArray(invoke.op()) -> {
193                     Op.Result result = invoke.resultFromFirstOperandOrNull();
194                     replaced.put(invoke.returnResult(), result);
195                     // map buffer VarOp to its corresponding VarLoadOp
196                     bufferVarLoads.put((opFromFirstOperandOrNull(result.op())), (CoreOp.VarAccessOp.VarLoadOp) result.op());
197                     return blockBuilder;
198                 }
199                 case CoreOp.VarOp varOp when isBufferInitialize(varOp) -> {
200                     Op bufferLoad = replaced.get(resultFromFirstOperandOrThrow(varOp)).op(); // gets VarLoadOp associated w/ og buffer
201                     replaced.put(varOp.result(), resultFromFirstOperandOrNull(bufferLoad)); // gets VarOp associated w/ og buffer
202                     return blockBuilder;
203                 }
204                 case CoreOp.VarAccessOp.VarLoadOp varLoadOp when (isBufferInitialize(varLoadOp)) -> {
205                     Op.Result r = resultFromFirstOperandOrThrow(varLoadOp);
206                     Op.Result replacement;
207                     if (r.op() instanceof CoreOp.VarOp) { // if this is the VarLoadOp after the .arrayView() InvokeOp
208                         replacement = (isLocalSharedOrPrivate(varLoadOp)) ?
209                                 resultFromFirstOperandOrNull(opFromFirstOperandOrThrow(r.op())) :
210                                 bufferVarLoads.get(replaced.get(r).op()).result();
211                     } else { // if this is a VarLoadOp loading the buffer
212                         CoreOp.VarAccessOp.VarLoadOp newVarLoad = CoreOp.VarAccessOp.varLoad(blockBuilder.context().getValue(replaced.get(r)));
213                         replacement = blockBuilder.add(copyLocation(varLoadOp, newVarLoad));
214                         context.mapValue(varLoadOp.result(), replacement);
215                     }
216                     replaced.put(varLoadOp.result(), replacement);
217                     return blockBuilder;
218                 }
219                 case JavaOp.ArrayAccessOp.ArrayLoadOp arrayLoadOp when isBufferArray(arrayLoadOp) -> {
220                     Op replacementOp;
221                     if (isVectorOp(lookup, arrayLoadOp)) {
222                         replacementOp = JavaOp.arrayLoadOp(
223                                 context.getValue(replaced.get((Op.Result) arrayLoadOp.operands().getFirst())),
224                                 context.getValue(arrayLoadOp.operands().getLast()),
225                                 arrayLoadOp.resultType()
226                         );
227                     } else if (((ArrayType) firstOperandOrThrow(op).type()).dimensions() == 1) {
228                         var arrayAccessInfo = arrayAccessInfo(op.result(), replaced);
229                         var operands = arrayAccessInfo.bufferAndIndicesAsValues();
230                         replacementOp = new HATPtrOp.HATPtrLoadOp(
231                                 arrayAccessInfo.bufferName(),
232                                 arrayLoadOp.resultType(),
233                                 (Class<?>) classTypeToTypeOrThrow(lookup, (ClassType) arrayAccessInfo.buffer().type()),
234                                 context.getValues(operands)
235                         );
236                     } else { // we only use the last array load
237                         return blockBuilder;
238                     }
239                     context.mapValue(arrayLoadOp.result(), blockBuilder.add(copyLocation(arrayLoadOp, replacementOp)));
240                     return blockBuilder;
241                 }
242                 case JavaOp.ArrayAccessOp.ArrayStoreOp arrayStoreOp when isBufferArray(arrayStoreOp) -> {
243                     Op replacementOp;
244                     if (isVectorOp(lookup, arrayStoreOp)) {
245                         replacementOp = JavaOp.arrayStoreOp(
246                                 context.getValue(replaced.get((Op.Result) arrayStoreOp.operands().getFirst())),
247                                 context.getValue(arrayStoreOp.operands().get(1)),
248                                 context.getValue(arrayStoreOp.operands().getLast())
249                         );
250                     } else if (((ArrayType) firstOperandOrThrow(op).type()).dimensions() == 1) { // we only use the last array load
251                         var arrayAccessInfo = arrayAccessInfo(op.result(), replaced);
252                         var operands = arrayAccessInfo.bufferAndIndicesAsValues();
253                         operands.add(arrayStoreOp.operands().getLast());
254                         replacementOp = new HATPtrOp.HATPtrStoreOp(
255                                 arrayAccessInfo.bufferName(),
256                                 arrayStoreOp.resultType(),
257                                 (Class<?>) classTypeToTypeOrThrow(lookup, (ClassType) arrayAccessInfo.buffer().type()),
258                                 context.getValues(operands)
259                         );
260                     } else {
261                         return blockBuilder;
262                     }
263                     context.mapValue(arrayStoreOp.result(), blockBuilder.add(copyLocation(arrayStoreOp, replacementOp)));
264                     return blockBuilder;
265                 }
266                 case JavaOp.ArrayLengthOp arrayLengthOp when
267                         isBufferArray(arrayLengthOp) && resultFromFirstOperandOrThrow(arrayLengthOp) != null -> {
268                     var arrayAccessInfo = arrayAccessInfo(op.result(), replaced);
269                     var hatPtrLengthOp = new HATPtrOp.HATPtrLengthOp(
270                             arrayAccessInfo.bufferName(),
271                             arrayLengthOp.resultType(),
272                             (Class<?>) OpHelper.classTypeToTypeOrThrow(lookup, (ClassType) arrayAccessInfo.buffer().type()),
273                             context.getValues(List.of(arrayAccessInfo.buffer()))
274                     );
275                     context.mapValue(arrayLengthOp.result(), blockBuilder.add(copyLocation(arrayLengthOp, hatPtrLengthOp)));
276                     return blockBuilder;
277                 }
278                 default -> {
279                 }
280             }
281             blockBuilder.add(op);
282             return blockBuilder;
283         }).funcOp();
284     }
285 
286     record ArrayAccessInfo(Op.Result buffer, String bufferName, List<Op.Result> indices) {
287         public List<Value> bufferAndIndicesAsValues() {
288             List<Value> operands = new ArrayList<>(List.of(buffer));
289             operands.addAll(indices);
290             return operands;
291         }
292     }
293 
294     record Node<T>(T value, List<Node<T>> edges) {
295         ArrayAccessInfo getInfo(Map<Op.Result, Op.Result> replaced) {
296             List<Node<T>> nodeList = new ArrayList<>(List.of(this));
297             Set<Node<T>> handled = new HashSet<>();
298             Op.Result buffer = null;
299             List<Op.Result> indices = new ArrayList<>();
300             while (!nodeList.isEmpty()) {
301                 Node<T> node = nodeList.removeFirst();
302                 handled.add(node);
303                 if (node.value instanceof Op.Result res &&
304                         (res.op() instanceof JavaOp.ArrayAccessOp || res.op() instanceof JavaOp.ArrayLengthOp)) {
305                     buffer = res;
306                     // idx location differs between ArrayAccessOp and ArrayLengthOp
307                     indices.addFirst(res.op() instanceof JavaOp.ArrayAccessOp
308                             ? resultFromOperandN(res.op(), 1)
309                             : resultFromOperandN(res.op(), 0)
310                     );
311                 }
312                 if (!node.edges().isEmpty()) {
313                     Node<T> next = node.edges().getFirst(); // we only traverse through the index-related ops
314                     if (!handled.contains(next)) {
315                         nodeList.add(next);
316                     }
317                 }
318             }
319             if (buffer != null) {
320                 buffer = replaced.get(resultFromFirstOperandOrNull(buffer.op()));
321                 String bufferName = hatPtrName(opFromFirstOperandOrNull(buffer.op()));
322                 return new ArrayAccessInfo(buffer, bufferName, indices);
323             } else {
324                 return null;
325             }
326         }
327     }
328 
329     public static String hatPtrName(Op op) {
330         return switch (op) {
331             case CoreOp.VarOp varOp -> varOp.varName();
332             case VarLikeOp varLikeOp -> varLikeOp.varName();
333             case null, default -> "";
334         };
335     }
336 }