1 /*
  2  * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package hat.phases;
 26 
 27 import hat.callgraph.KernelCallGraph;
 28 import hat.device.DeviceType;
 29 import hat.dialect.*;
 30 import optkl.util.CallSite;
 31 import optkl.ifacemapper.MappableIface;
 32 import hat.types._V;
 33 import jdk.incubator.code.Block;
 34 import jdk.incubator.code.Op;
 35 import jdk.incubator.code.TypeElement;
 36 import jdk.incubator.code.Value;
 37 import jdk.incubator.code.dialect.core.CoreOp;
 38 import jdk.incubator.code.dialect.core.CoreType;
 39 import jdk.incubator.code.dialect.java.*;
 40 
 41 import java.util.*;
 42 
 43 import static optkl.OpTkl.classTypeToTypeOrThrow;
 44 import static optkl.OpTkl.elements;
 45 import static optkl.OpTkl.isAssignable;
 46 
 47 public record HATArrayViewPhase(KernelCallGraph kernelCallGraph) implements HATPhase {
 48 
 49     @Override
 50     public CoreOp.FuncOp apply(CoreOp.FuncOp entry) {
 51 
 52         if (!isArrayView(entry)) return entry;
 53 
 54         Map<Op.Result, Op.Result> replaced = new HashMap<>(); // maps a result to the result it should be replaced by
 55         Map<Op, CoreOp.VarAccessOp.VarLoadOp> bufferVarLoads = new HashMap<>();
 56 
 57         return entry.transform(entry.funcName(), (bb, op) -> {
 58             switch (op) {
 59                 case JavaOp.InvokeOp invokeOp -> {
 60                     if (isVectorBinaryOperation(invokeOp)) {
 61                         // catching HATVectorBinaryOps not stored in VarOps
 62                         HATVectorOp.HATVectorBinaryOp vBinaryOp = buildVectorBinaryOp(
 63                                 invokeOp.invokeDescriptor().name(),
 64                                 obtainVarNameFromInvoke(invokeOp),
 65                                 invokeOp.resultType(),
 66                                 bb.context().getValues(invokeOp.operands())
 67                         );
 68                         vBinaryOp.setLocation(invokeOp.location());
 69                         Op.Result res = bb.op(vBinaryOp);
 70                         bb.context().mapValue(invokeOp.result(), res);
 71                         replaced.put(invokeOp.result(), res);
 72                         return bb;
 73                     } else if (isBufferArray(invokeOp) &&
 74                             firstOperand(invokeOp) instanceof Op.Result r) { // ensures we can use iop as key for replaced vvv
 75                         replaced.put(invokeOp.result(), r);
 76                         // map buffer VarOp to its corresponding VarLoadOp
 77                         bufferVarLoads.put((firstOperandAsRes(r.op())).op(), (CoreOp.VarAccessOp.VarLoadOp) r.op());
 78                         return bb;
 79                     }
 80                 }
 81                 case CoreOp.VarOp varOp -> {
 82                     if (isBufferInitialize(varOp) &&
 83                             firstOperand(varOp) instanceof Op.Result r) { // makes sure we don't process a new int[] for example
 84                         Op bufferLoad = replaced.get(r).op(); // gets VarLoadOp associated w/ og buffer
 85                         replaced.put(varOp.result(), firstOperandAsRes(bufferLoad)); // gets VarOp associated w/ og buffer
 86                         return bb;
 87                     } else if (isVectorOp(varOp)) {
 88                         List<Value> operands = (varOp.operands().isEmpty()) ? List.of() : List.of(firstOperand(varOp));
 89                         HATPhaseUtils.VectorMetaData md = HATPhaseUtils.getVectorTypeInfoWithCodeReflection(varOp.resultType().valueType());
 90                         HATVectorOp.HATVectorVarOp vVarOp = new HATVectorOp.HATVectorVarOp(
 91                                 varOp.varName(),
 92                                 varOp.resultType(),
 93                                 md.vectorTypeElement(),
 94                                 md.lanes(),
 95                                 bb.context().getValues(operands)
 96                         );
 97                         vVarOp.setLocation(varOp.location());
 98                         Op.Result res = bb.op(vVarOp);
 99                         bb.context().mapValue(varOp.result(), res);
100                         return bb;
101                     }
102                 }
103                 case CoreOp.VarAccessOp.VarLoadOp varLoadOp -> {
104                     if ((isBufferInitialize(varLoadOp)) &&
105                             firstOperand(varLoadOp) instanceof Op.Result r) {
106                         if (r.op() instanceof CoreOp.VarOp) { // if this is the VarLoadOp after the .arrayView() InvokeOp
107                             Op.Result replacement = (notGlobalVarOp(varLoadOp)) ?
108                                     firstOperandAsRes((firstOperandAsRes(r.op())).op()) :
109                                     bufferVarLoads.get(replaced.get(r).op()).result();
110                             replaced.put(varLoadOp.result(), replacement);
111                         } else { // if this is a VarLoadOp loading the buffer
112                             Value loaded = getValue(bb, replaced.get(r));
113                             CoreOp.VarAccessOp.VarLoadOp newVarLoad = CoreOp.VarAccessOp.varLoad(loaded);
114                             newVarLoad.setLocation(varLoadOp.location());
115                             Op.Result res = bb.op(newVarLoad);
116                             bb.context().mapValue(varLoadOp.result(), res);
117                             replaced.put(varLoadOp.result(), res);
118                         }
119                         return bb;
120                     }
121                 }
122                 case JavaOp.ArrayAccessOp.ArrayLoadOp arrayLoadOp -> {
123                     if (isBufferArray(arrayLoadOp) &&
124                             firstOperand(arrayLoadOp) instanceof Op.Result r) {
125                         Op.Result buffer = replaced.getOrDefault(r, r);
126                         if (isVectorOp(arrayLoadOp)) {
127                             Op vop = (firstOperandAsRes(buffer.op())).op();
128                             String name = switch (vop) {
129                                 case CoreOp.VarOp varOp -> varOp.varName();
130                                 case HATMemoryVarOp.HATLocalVarOp hatLocalVarOp -> hatLocalVarOp.varName();
131                                 case HATMemoryVarOp.HATPrivateVarOp hatPrivateVarOp -> hatPrivateVarOp.varName();
132                                 default -> throw new IllegalStateException("Unexpected value: " + vop);
133                             };
134                             HATPhaseUtils.VectorMetaData md = HATPhaseUtils.getVectorTypeInfoWithCodeReflection(arrayLoadOp.resultType());
135                             HATVectorOp.HATVectorLoadOp vLoadOp = new HATVectorOp.HATVectorLoadOp(
136                                     name,
137                                     CoreType.varType(((ArrayType) firstOperand(arrayLoadOp).type()).componentType()),
138                                     md.vectorTypeElement(),
139                                     md.lanes(),
140                                     notGlobalVarOp(arrayLoadOp),
141                                     bb.context().getValues(List.of(buffer, arrayLoadOp.operands().getLast()))
142                             );
143                             vLoadOp.setLocation(arrayLoadOp.location());
144                             Op.Result res = bb.op(vLoadOp);
145                             bb.context().mapValue(arrayLoadOp.result(), res);
146                         } else if (((ArrayType) firstOperand(op).type()).dimensions() == 1) { // we only use the last array load
147                             ArrayAccessInfo info = arrayAccessInfo(op.result(), replaced);
148                             List<Value> operands = new ArrayList<>();
149                             operands.add(info.buffer);
150                             operands.addAll(info.indices);
151                             HATPtrOp.HATPtrLoadOp ptrLoadOp = new HATPtrOp.HATPtrLoadOp(
152                                     arrayLoadOp.resultType(),
153                                     (Class<?>) classTypeToTypeOrThrow(lookup(), (ClassType) info.buffer().type()),
154                                     bb.context().getValues(operands)
155                             );
156                             ptrLoadOp.setLocation(arrayLoadOp.location());
157                             Op.Result res = bb.op(ptrLoadOp);
158                             bb.context().mapValue(arrayLoadOp.result(), res);
159                         }
160                     }
161                     return bb;
162                 }
163                 case JavaOp.ArrayAccessOp.ArrayStoreOp arrayStoreOp -> {
164                     if (isBufferArray(arrayStoreOp) &&
165                             firstOperand(arrayStoreOp) instanceof Op.Result r) {
166                         Op.Result buffer = replaced.getOrDefault(r, r);
167                         if (isVectorOp(arrayStoreOp)) {
168                             Op varOp = findVarOpOrHATVarOP(((Op.Result) arrayStoreOp.operands().getLast()).op());
169                             String name = (varOp instanceof HATVectorOp.HATVectorVarOp) ? ((HATVectorOp.HATVectorVarOp) varOp).varName() : ((CoreOp.VarOp) varOp).varName();
170                             TypeElement resultType = (varOp instanceof HATVectorOp.HATVectorVarOp) ? (varOp).resultType() : ((CoreOp.VarOp) varOp).resultType();
171                             ClassType classType = ((ClassType) ((ArrayType) firstOperand(arrayStoreOp).type()).componentType());
172                             HATPhaseUtils.VectorMetaData md = HATPhaseUtils.getVectorTypeInfoWithCodeReflection(classType);
173                             HATVectorOp.HATVectorStoreView vStoreOp = new HATVectorOp.HATVectorStoreView(
174                                     name,
175                                     resultType,
176                                     md.lanes(),
177                                     md.vectorTypeElement(),
178                                     notGlobalVarOp(arrayStoreOp),
179                                     bb.context().getValues(List.of(buffer, arrayStoreOp.operands().getLast(), arrayStoreOp.operands().get(1)))
180                             );
181                             vStoreOp.setLocation(arrayStoreOp.location());
182                             Op.Result res = bb.op(vStoreOp);
183                             bb.context().mapValue(arrayStoreOp.result(), res);
184                         } else if (((ArrayType) firstOperand(op).type()).dimensions() == 1) { // we only use the last array load
185                             ArrayAccessInfo info = arrayAccessInfo(op.result(), replaced);
186                             List<Value> operands = new ArrayList<>();
187                             operands.add(info.buffer());
188                             operands.addAll(info.indices);
189                             operands.add(arrayStoreOp.operands().getLast());
190                             HATPtrOp.HATPtrStoreOp ptrLoadOp = new HATPtrOp.HATPtrStoreOp(
191                                     arrayStoreOp.resultType(),
192                                     (Class<?>) classTypeToTypeOrThrow(lookup(), (ClassType) info.buffer().type()),
193                                     bb.context().getValues(operands)
194                             );
195                             ptrLoadOp.setLocation(arrayStoreOp.location());
196                             Op.Result res = bb.op(ptrLoadOp);
197                             bb.context().mapValue(arrayStoreOp.result(), res);
198                         }
199                     }
200                     return bb;
201                 }
202                 case JavaOp.ArrayLengthOp arrayLengthOp -> {
203                     if (isBufferArray(arrayLengthOp) &&
204                             firstOperand(arrayLengthOp) instanceof Op.Result r) {
205                         ArrayAccessInfo info = arrayAccessInfo(op.result(), replaced);
206                         HATPtrOp.HATPtrLengthOp ptrLengthOp = new HATPtrOp.HATPtrLengthOp(
207                                 arrayLengthOp.resultType(),
208                                 (Class<?>) classTypeToTypeOrThrow(lookup(), (ClassType) info.buffer().type()),
209                                 bb.context().getValues(List.of(info.buffer()))
210                         );
211                         ptrLengthOp.setLocation(arrayLengthOp.location());
212                         Op.Result res = bb.op(ptrLengthOp);
213                         bb.context().mapValue(arrayLengthOp.result(), res);
214                         return bb;
215                     }
216                 }
217                 default -> {
218                 }
219             }
220             bb.op(op);
221             return bb;
222         });
223     }
224 
225     record ArrayAccessInfo(Op.Result buffer, List<Op.Result> indices) {};
226 
227     record Node<T>(T value, List<Node<T>> edges) {
228         ArrayAccessInfo getInfo(Map<Op.Result, Op.Result> replaced) {
229             List<Node<T>> wl = new ArrayList<>();
230             Set<Node<T>> seen = new HashSet<>();
231             Op.Result buffer = null;
232             List<Op.Result> indices = new ArrayList<>();
233             wl.add(this);
234             while (!wl.isEmpty()) {
235                 Node<T> cur = wl.removeFirst();
236                 seen.add(cur);
237                 if (cur.value instanceof Op.Result res) {
238                     if (res.op() instanceof JavaOp.ArrayAccessOp || res.op() instanceof JavaOp.ArrayLengthOp) {
239                         buffer = res;
240                         indices.addFirst(res.op() instanceof JavaOp.ArrayAccessOp ? ((Op.Result) res.op().operands().get(1)) : ((Op.Result) res.op().operands().get(0)));
241                     }
242                 }
243                 if (!cur.edges().isEmpty()) {
244                     Node<T> next = cur.edges().getFirst();
245                     if (!seen.contains(next)) wl.add(next);
246                 }
247             }
248             buffer = replaced.get((Op.Result) firstOperand(buffer.op()));
249             return new ArrayAccessInfo(buffer, indices);
250         }
251     }
252 
253     static ArrayAccessInfo arrayAccessInfo(Value value, Map<Op.Result, Op.Result> replaced) {
254         return expressionGraph(value).getInfo(replaced);
255     }
256 
257     static Node<Value> expressionGraph(Value value) {
258         return expressionGraph(new HashMap<>(), value);
259     }
260 
261     static Node<Value> expressionGraph(Map<Value, Node<Value>> visited, Value value) {
262         // If value has already been visited return its node
263         if (visited.containsKey(value)) {
264             return visited.get(value);
265         }
266 
267         // Find the expression graphs for each operand
268         List<Node<Value>> edges = new ArrayList<>();
269         for (Value operand : value.dependsOn()) {
270             if (operand instanceof Op.Result res && res.op() instanceof JavaOp.InvokeOp iop && iop.invokeDescriptor().name().toLowerCase().contains("arrayview")) continue;
271             edges.add(expressionGraph(operand));
272         }
273         Node<Value> node = new Node<>(value, edges);
274         visited.put(value, node);
275         return node;
276     }
277 
278     /*
279      * Helper functions:
280      */
281 
282     private HATVectorOp.HATVectorBinaryOp buildVectorBinaryOp(String opType, String varName, TypeElement resultType, List<Value> outputOperands) {
283         HATPhaseUtils.VectorMetaData md = HATPhaseUtils.getVectorTypeInfoWithCodeReflection(resultType);
284         return switch (opType) {
285             case "add" -> new HATVectorOp.HATVectorBinaryOp.HATVectorAddOp(varName, resultType, md.vectorTypeElement(), md.lanes(), outputOperands);
286             case "sub" -> new HATVectorOp.HATVectorBinaryOp.HATVectorSubOp(varName, resultType, md.vectorTypeElement(), md.lanes(), outputOperands);
287             case "mul" -> new HATVectorOp.HATVectorBinaryOp.HATVectorMulOp(varName, resultType, md.vectorTypeElement(), md.lanes(), outputOperands);
288             case "div" -> new HATVectorOp.HATVectorBinaryOp.HATVectorDivOp(varName, resultType, md.vectorTypeElement(), md.lanes(), outputOperands);
289             default -> throw new IllegalStateException("Unexpected value: " + opType);
290         };
291     }
292 
293     private boolean isVectorBinaryOperation(JavaOp.InvokeOp invokeOp) {
294         TypeElement typeElement = invokeOp.resultType();
295         boolean isHatVectorType = typeElement.toString().startsWith("hat.buffer.Float");
296         return isHatVectorType
297                 && (invokeOp.invokeDescriptor().name().equalsIgnoreCase("add")
298                 || invokeOp.invokeDescriptor().name().equalsIgnoreCase("sub")
299                 || invokeOp.invokeDescriptor().name().equalsIgnoreCase("mul")
300                 || invokeOp.invokeDescriptor().name().equalsIgnoreCase("div"));
301     }
302 
303     private Op findVarOpOrHATVarOP(Op op) {
304         return searchForOp(op, Set.of(CoreOp.VarOp.class, HATVectorOp.HATVectorVarOp.class));
305     }
306 
307     public boolean isVectorOp(Op op) {
308         if (op.operands().isEmpty()) return false;
309         TypeElement type = firstOperand(op).type();
310         if (type instanceof ArrayType at) type = at.componentType();
311         if (type instanceof ClassType ct) {
312             try {
313                 return _V.class.isAssignableFrom((Class<?>) ct.resolve(lookup()));
314             } catch (ReflectiveOperationException e) {
315                 throw new RuntimeException(e);
316             }
317         }
318         return false;
319     }
320 
321     public static Op.Result firstOperandAsRes(Op op) {
322         return (firstOperand(op) instanceof Op.Result res) ? res : null;
323     }
324 
325     public static Value firstOperand(Op op) {
326         return op.operands().getFirst();
327     }
328 
329     public static Value getValue(Block.Builder bb, Value value) {
330         return bb.context().getValueOrDefault(value, value);
331     }
332 
333     public boolean isBufferArray(Op op) {
334         JavaOp.InvokeOp iop = (JavaOp.InvokeOp) searchForOp(op, Set.of(JavaOp.InvokeOp.class));
335         return iop.invokeDescriptor().name().toLowerCase().contains("arrayview");
336     }
337 
338     public boolean notGlobalVarOp(Op op) {
339         JavaOp.InvokeOp iop = (JavaOp.InvokeOp) searchForOp(op, Set.of(JavaOp.InvokeOp.class));
340         return iop.invokeDescriptor().name().toLowerCase().contains("local") ||
341                 iop.invokeDescriptor().name().toLowerCase().contains("shared") ||
342                 iop.invokeDescriptor().name().toLowerCase().contains("private");
343     }
344 
345     public Op searchForOp(Op op, Set<Class<?>> opClasses) {
346         while (!(opClasses.contains(op.getClass()))) {
347             if (!op.operands().isEmpty() && firstOperand(op) instanceof Op.Result r) {
348                 op = r.op();
349             } else {
350                 return null;
351             }
352         }
353         return op;
354     }
355 
356     public boolean isBufferInitialize(Op op) {
357         // first check if the return is an array type
358         if (op instanceof CoreOp.VarOp vop) {
359             if (!(vop.varValueType() instanceof ArrayType)) return false;
360         } else if (!(op instanceof JavaOp.ArrayAccessOp)) {
361             if (!(op.resultType() instanceof ArrayType)) return false;
362         }
363 
364         return isBufferArray(op);
365     }
366 
367     public boolean isArrayView(CoreOp.FuncOp entry) {
368         var here = CallSite.of(HATArrayViewPhase.class, "isArrayView");
369         return elements(here, entry).anyMatch((element) -> (
370                 element instanceof JavaOp.InvokeOp iop &&
371                         iop.resultType() instanceof ArrayType &&
372                         iop.invokeDescriptor().refType() instanceof JavaType javaType &&
373                         (isAssignable(lookup(), javaType, MappableIface.class)
374                                 || isAssignable(lookup(), javaType, DeviceType.class))));
375     }
376 
377     public Class<?> typeElementToClass(TypeElement type) {
378         class PrimitiveHolder {
379             static final Map<PrimitiveType, Class<?>> primitiveToClass = Map.of(
380                     JavaType.BYTE, byte.class,
381                     JavaType.SHORT, short.class,
382                     JavaType.INT, int.class,
383                     JavaType.LONG, long.class,
384                     JavaType.FLOAT, float.class,
385                     JavaType.DOUBLE, double.class,
386                     JavaType.CHAR, char.class,
387                     JavaType.BOOLEAN, boolean.class
388             );
389         }
390         try {
391             if (type instanceof PrimitiveType primitiveType) {
392                 return PrimitiveHolder.primitiveToClass.get(primitiveType);
393             } else if (type instanceof ClassType classType) {
394                 return ((Class<?>) classType.resolve(lookup()));
395             } else {
396                 throw new IllegalArgumentException("given type cannot be converted to class");
397             }
398         } catch (ReflectiveOperationException e) {
399             throw new RuntimeException("given type cannot be converted to class");
400         }
401     }
402 
403     private String obtainVarNameFromInvoke(JavaOp.InvokeOp invokeOp) {
404         Op.Result invokeResult = invokeOp.result();
405         if (!invokeResult.uses().isEmpty()) {
406             Op.Result r = invokeResult.uses().stream().toList().getFirst();
407             if (r.op() instanceof CoreOp.VarOp varOp) {
408                 return varOp.varName();
409             }
410         }
411         return invokeOp.externalizeOpName();
412     }
413 }