1 /*
  2  * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package hat.phases;
 26 
 27 import hat.callgraph.KernelCallGraph;
 28 import hat.device.DeviceType;
 29 import hat.dialect.*;
 30 import optkl.OpHelper;
 31 import optkl.Trxfmr;
 32 import optkl.ifacemapper.MappableIface;
 33 import jdk.incubator.code.Op;
 34 import jdk.incubator.code.Value;
 35 import jdk.incubator.code.dialect.core.CoreOp;
 36 import jdk.incubator.code.dialect.core.CoreType;
 37 import jdk.incubator.code.dialect.java.*;
 38 import optkl.util.ops.VarLikeOp;
 39 
 40 import java.util.*;
 41 
 42 import static optkl.OpHelper.*;
 43 import static optkl.OpHelper.Invoke.invoke;
 44 
 45 public record HATArrayViewPhase(KernelCallGraph kernelCallGraph) implements HATPhase {
 46 
 47     @Override
 48     public CoreOp.FuncOp apply(CoreOp.FuncOp funcOp) {
 49         if (Invoke.stream(lookup(), funcOp).anyMatch(invoke ->
 50                             invoke.returnsArray()
 51                         && invoke.refIs(MappableIface.class,DeviceType.class))) {
 52             Map<Op.Result, Op.Result> replaced = new HashMap<>(); // maps a result to the result it should be replaced by
 53             Map<Op, CoreOp.VarAccessOp.VarLoadOp> bufferVarLoads = new HashMap<>();
 54 
 55             return Trxfmr.of(this,funcOp).transform( (blockBuilder, op) -> {
 56                 var context = blockBuilder.context();
 57                 switch (op) {
 58                     case JavaOp.InvokeOp $ when invoke(lookup(), $) instanceof Invoke invoke -> {
 59                         if (invoke.namedIgnoreCase("add","sub","mul","div")) {
 60                             // catching HATVectorBinaryOps not stored in VarOps
 61                             var hatVectorBinaryOp = invoke.copyLocationTo(HATPhaseUtils.buildVectorBinaryOp(
 62                                     lookup(),
 63                                     invoke.name(),
 64                                     invoke.varOpFromFirstUseOrThrow().varName(),
 65                                     invoke.returnType(),
 66                                     blockBuilder.context().getValues(invoke.op().operands())
 67                             ));
 68                             Op.Result binaryResult = blockBuilder.op(hatVectorBinaryOp);
 69                            context.mapValue(invoke.returnResult(), binaryResult);
 70                             replaced.put(invoke.returnResult(), binaryResult);
 71                             return blockBuilder;
 72                         } else if (HATPhaseUtils.isBufferArray(invoke.op()) && invoke.resultFromFirstOperandOrNull() instanceof Op.Result result) { // ensures we can use iop as key for replaced vvv
 73                             replaced.put(invoke.returnResult(), result);
 74                             // map buffer VarOp to its corresponding VarLoadOp
 75                             bufferVarLoads.put((resultFromFirstOperandOrNull(result.op())).op(), (CoreOp.VarAccessOp.VarLoadOp) result.op());
 76                             return blockBuilder;
 77                         } else{
 78                             // we do get here.
 79                         }
 80                     }
 81                     case CoreOp.VarOp varOp -> {
 82                         if (HATPhaseUtils.isBufferInitialize(varOp) && resultFromFirstOperandOrThrow(varOp) instanceof Op.Result result) {
 83                             // makes sure we don't process a new int[] for example
 84                             Op bufferLoad = replaced.get(result).op(); // gets VarLoadOp associated w/ og buffer
 85                             replaced.put(varOp.result(), resultFromFirstOperandOrNull(bufferLoad)); // gets VarOp associated w/ og buffer
 86                             return blockBuilder;
 87                         } else if (HATPhaseUtils.isVectorOp(lookup(),varOp)) {
 88                             var vectorMetaData = HATPhaseUtils.getVectorTypeInfoWithCodeReflection(lookup(),varOp.resultType().valueType());
 89                             var hatVectorVarOp = copyLocation(varOp,new HATVectorOp.HATVectorVarOp(
 90                                     varOp.varName(),
 91                                     varOp.resultType(),
 92                                     vectorMetaData.vectorTypeElement(),
 93                                     vectorMetaData.lanes(),
 94                                    context.getValues(OpHelper.firstOperandAsListOrEmpty(varOp))
 95                             ));
 96                             context.mapValue(varOp.result(), blockBuilder.op(hatVectorVarOp));
 97                             return blockBuilder;
 98                         }else{
 99                             // we do get here.
100                         }
101                     }
102                     case CoreOp.VarAccessOp.VarLoadOp varLoadOp -> {
103                         if ((HATPhaseUtils.isBufferInitialize(varLoadOp)) && resultFromFirstOperandOrThrow(varLoadOp) instanceof Op.Result r) {
104                             if (r.op() instanceof CoreOp.VarOp) { // if this is the VarLoadOp after the .arrayView() InvokeOp
105                                 Op.Result replacement = (HATPhaseUtils.isLocalSharedOrPrivate(varLoadOp)) ?
106                                         resultFromFirstOperandOrNull((resultFromFirstOperandOrNull(r.op())).op()) :
107                                         bufferVarLoads.get(replaced.get(r).op()).result();
108                                 replaced.put(varLoadOp.result(), replacement);
109                             } else { // if this is a VarLoadOp loading the buffer
110                                 // is this not just bb.op(varLoadOp)?
111                                 CoreOp.VarAccessOp.VarLoadOp newVarLoad = copyLocation(varLoadOp,
112                                         CoreOp.VarAccessOp.varLoad(
113                                                 blockBuilder.context().getValue(replaced.get(r)))
114                                 );
115                                 Op.Result res = blockBuilder.op(newVarLoad);
116                                 context.mapValue(varLoadOp.result(), res);
117                                 replaced.put(varLoadOp.result(), res);
118                             }
119                             return blockBuilder;
120                         }else{
121                            // we do get here
122                         }
123                     }
124                     case JavaOp.ArrayAccessOp.ArrayLoadOp arrayLoadOp -> {
125                         if (HATPhaseUtils.isBufferArray(arrayLoadOp) && resultFromFirstOperandOrNull(arrayLoadOp) instanceof Op.Result r) {
126                             Op.Result buffer = replaced.getOrDefault(r, r);
127                             if (HATPhaseUtils.isVectorOp(lookup(),arrayLoadOp)) {
128                                 Op vop = opFromFirstOperandOrThrow(buffer.op());
129                                 String name = switch (vop) {
130                                     case CoreOp.VarOp varOp -> varOp.varName();
131                                     case VarLikeOp varLikeOp -> varLikeOp.varName(); // HATMemoryVarOp.HATLocalVarOp &&  HATMemoryVarOp.HATPrivateVarOp
132                                     default -> throw new IllegalStateException("Unexpected value: " + vop);
133                                 };
134                                 var  hatVectorMetaData = HATPhaseUtils.getVectorTypeInfoWithCodeReflection(lookup(),arrayLoadOp.resultType());
135                                 HATVectorOp.HATVectorLoadOp vLoadOp = copyLocation(arrayLoadOp,new HATVectorOp.HATVectorLoadOp(
136                                         name,
137                                         CoreType.varType(((ArrayType) OpHelper.firstOperandOrThrow(arrayLoadOp).type()).componentType()),
138                                         hatVectorMetaData.vectorTypeElement(), // seems like we might pass the hatVectorMetaData here...?
139                                         hatVectorMetaData.lanes(),
140                                         HATPhaseUtils.isLocalSharedOrPrivate(arrayLoadOp),
141                                         context.getValues(List.of(buffer, arrayLoadOp.operands().getLast()))
142                                 ));
143                                 context.mapValue(arrayLoadOp.result(), blockBuilder.op(vLoadOp));
144                             } else if (OpHelper.firstOperandOrThrow(op).type() instanceof ArrayType arrayType && arrayType.dimensions() == 1) { // we only use the last array load
145                                 var arrayAccessInfo = HATPhaseUtils.arrayAccessInfo(op.result(), replaced);
146                                 var operands = arrayAccessInfo.bufferAndIndicesAsValues();
147                                 var hatPtrLoadOp = copyLocation(arrayLoadOp,new HATPtrOp.HATPtrLoadOp(
148                                         arrayAccessInfo.bufferName(),
149                                         arrayLoadOp.resultType(),
150                                         (Class<?>) OpHelper.classTypeToTypeOrThrow(lookup(), (ClassType) arrayAccessInfo.buffer().type()),
151                                         context.getValues(operands)
152                                 ));
153                                 context.mapValue(arrayLoadOp.result(), blockBuilder.op(hatPtrLoadOp));
154                             }else{
155                                 // or else
156                             }
157                         } else {
158                             // or else?
159                         }
160                         return blockBuilder;
161                     }
162                     case JavaOp.ArrayAccessOp.ArrayStoreOp arrayStoreOp -> {
163                         if (HATPhaseUtils.isBufferArray(arrayStoreOp) && resultFromFirstOperandOrThrow(arrayStoreOp) instanceof Op.Result r) {
164                             Op.Result buffer = replaced.getOrDefault(r, r);
165                             if (HATPhaseUtils.isVectorOp(lookup(),arrayStoreOp)) {
166                                 Op varOp =
167                                         HATPhaseUtils.findOpInResultFromFirstOperandsOrThrow(((Op.Result) arrayStoreOp.operands().getLast()).op(), CoreOp.VarOp.class, HATVectorOp.HATVectorVarOp.class);
168                                 var name = (varOp instanceof HATVectorOp.HATVectorVarOp)
169                                         ? ((HATVectorOp.HATVectorVarOp) varOp).varName()
170                                         : ((CoreOp.VarOp) varOp).varName();
171                                 var resultType = (varOp instanceof HATVectorOp.HATVectorVarOp)
172                                         ? (varOp).resultType()
173                                         : ((CoreOp.VarOp) varOp).resultType();
174                                 var classType = ((ClassType) ((ArrayType) OpHelper.firstOperandOrThrow(arrayStoreOp).type()).componentType());
175                                 var vectorMetaData = HATPhaseUtils.getVectorTypeInfoWithCodeReflection(lookup(),classType);
176                                 HATVectorOp.HATVectorStoreView vStoreOp = copyLocation(arrayStoreOp,new HATVectorOp.HATVectorStoreView(
177                                         name,
178                                         resultType,
179                                         vectorMetaData.lanes(),
180                                         vectorMetaData.vectorTypeElement(),
181                                         HATPhaseUtils.isLocalSharedOrPrivate(arrayStoreOp),
182                                         context.getValues(List.of(buffer, arrayStoreOp.operands().getLast(), arrayStoreOp.operands().get(1)))
183                                 ));
184                                 context.mapValue(arrayStoreOp.result(), blockBuilder.op(vStoreOp));
185                             } else if (((ArrayType) OpHelper.firstOperandOrThrow(op).type()).dimensions() == 1) { // we only use the last array load
186                                 var arrayAccessInfo = HATPhaseUtils.arrayAccessInfo(op.result(), replaced);
187                                 var operands = arrayAccessInfo.bufferAndIndicesAsValues();
188                                 operands.add(arrayStoreOp.operands().getLast());
189                                 HATPtrOp.HATPtrStoreOp ptrLoadOp = copyLocation(arrayStoreOp,new HATPtrOp.HATPtrStoreOp(
190                                         arrayAccessInfo.bufferName(),
191                                         arrayStoreOp.resultType(),
192                                         (Class<?>) OpHelper.classTypeToTypeOrThrow(lookup(), (ClassType) arrayAccessInfo.buffer().type()),
193                                         context.getValues(operands)
194                                 ));
195                                 context.mapValue(arrayStoreOp.result(), blockBuilder.op(ptrLoadOp));
196                             }else{
197                                 // or else
198                             }
199                         }else{
200                             // or else?
201                         }
202                         return blockBuilder;
203                     }
204                     case JavaOp.ArrayLengthOp arrayLengthOp  when
205                         HATPhaseUtils.isBufferArray(arrayLengthOp) && resultFromFirstOperandOrThrow(arrayLengthOp) != null ->{
206                             var arrayAccessInfo = HATPhaseUtils.arrayAccessInfo(op.result(), replaced);
207                             var hatPtrLengthOp = copyLocation(arrayLengthOp,new HATPtrOp.HATPtrLengthOp(
208                                     arrayAccessInfo.bufferName(),
209                                     arrayLengthOp.resultType(),
210                                     (Class<?>) OpHelper.classTypeToTypeOrThrow(lookup(), (ClassType) arrayAccessInfo.buffer().type()),
211                                     context.getValues(List.of(arrayAccessInfo.buffer()))
212                             ));
213                             context.mapValue(arrayLengthOp.result(), blockBuilder.op(hatPtrLengthOp));
214                             return blockBuilder;
215                     }
216                     default -> {
217                     }
218                 }
219                 blockBuilder.op(op);
220                 return blockBuilder;
221             }).funcOp();
222         }else {
223             return funcOp;
224         }
225     }
226 
227     record ArrayAccessInfo(Op.Result buffer, String bufferName, List<Op.Result> indices) {
228         public List<Value> bufferAndIndicesAsValues() {
229             List<Value> operands = new ArrayList<>(List.of(buffer));
230             operands.addAll(indices);
231             return operands;
232         }
233     };
234 
235     record Node<T>(T value, List<Node<T>> edges) {
236         ArrayAccessInfo getInfo(Map<Op.Result, Op.Result> replaced) {
237             List<Node<T>> nodeList = new ArrayList<>(List.of(this));
238             Set<Node<T>> handled = new HashSet<>();
239             Op.Result buffer = null;
240             List<Op.Result> indices = new ArrayList<>();
241             while (!nodeList.isEmpty()) {
242                 Node<T> node = nodeList.removeFirst();
243                 handled.add(node);
244                 if (node.value instanceof Op.Result res &&
245                         (res.op() instanceof JavaOp.ArrayAccessOp || res.op() instanceof JavaOp.ArrayLengthOp)) {
246                         buffer = res;
247                         // idx location differs between ArrayAccessOp and ArrayLengthOp
248                         indices.addFirst(res.op() instanceof JavaOp.ArrayAccessOp
249                                 ? resultFromOperandN(res.op(), 1)
250                                 : resultFromFirstOperandOrThrow(res.op()));
251                 }
252                 if (!node.edges().isEmpty()) {
253                     Node<T> next = node.edges().getFirst(); // we only traverse through the index-related ops
254                     if (!handled.contains(next)) {
255                         nodeList.add(next);
256                     }
257                 }
258             }
259             buffer = replaced.get(resultFromFirstOperandOrNull(buffer.op()));
260             String bufferName = hatPtrName(resultFromFirstOperandOrNull(buffer.op()).op());
261             return new ArrayAccessInfo(buffer, bufferName, indices);
262         }
263     }
264 
265     public static String hatPtrName(Op op) {
266         return switch (op) {
267             case CoreOp.VarOp varOp -> varOp.varName();
268             case HATMemoryVarOp.HATLocalVarOp hatLocalVarOp -> hatLocalVarOp.varName();
269             case HATMemoryVarOp.HATPrivateVarOp hatPrivateVarOp -> hatPrivateVarOp.varName();
270             case null, default -> "";
271         };
272     }
273 }