1 /*
  2  * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package hat.phases;
 26 
 27 import hat.callgraph.KernelCallGraph;
 28 import hat.device.NonMappableIface;
 29 import hat.dialect.*;
 30 import optkl.OpHelper;
 31 import optkl.Trxfmr;
 32 import optkl.ifacemapper.MappableIface;
 33 import jdk.incubator.code.Op;
 34 import jdk.incubator.code.Value;
 35 import jdk.incubator.code.dialect.core.CoreOp;
 36 import jdk.incubator.code.dialect.core.CoreType;
 37 import jdk.incubator.code.dialect.java.*;
 38 import optkl.util.ops.VarLikeOp;
 39 
 40 import java.util.*;
 41 
 42 import static hat.phases.HATPhaseUtils.getVectorShape;
 43 import static optkl.OpHelper.*;
 44 import static optkl.OpHelper.Invoke.invoke;
 45 
 46 public record HATArrayViewPhase(KernelCallGraph kernelCallGraph) implements HATPhase {
 47 
 48     @Override
 49     public CoreOp.FuncOp apply(CoreOp.FuncOp funcOp) {
 50         if (Invoke.stream(lookup(), funcOp).anyMatch(invoke ->
 51                             invoke.returnsArray()
 52                         && invoke.refIs(MappableIface.class, NonMappableIface.class))) {
 53             Map<Op.Result, Op.Result> replaced = new HashMap<>(); // maps a result to the result it should be replaced by
 54             Map<Op, CoreOp.VarAccessOp.VarLoadOp> bufferVarLoads = new HashMap<>();
 55 
 56             return Trxfmr.of(this,funcOp).transform( (blockBuilder, op) -> {
 57                 var context = blockBuilder.context();
 58                 switch (op) {
 59                     case JavaOp.InvokeOp $ when invoke(lookup(), $) instanceof Invoke invoke -> {
 60                         if (invoke.nameMatchesRegex("(add|sub|mul|div)")){
 61                             var hatVectorBinaryOp = HATPhaseUtils.buildVectorBinaryOp(
 62                                     invoke.varOpFromFirstUseOrThrow().varName(),
 63                                     invoke.name(),// so mul, sub etc
 64                                     getVectorShape(lookup(),invoke.returnType()),
 65                                     blockBuilder.context().getValues(invoke.op().operands())
 66                             );
 67                             Op.Result binaryResult = blockBuilder.op(copyLocation(invoke.op(),hatVectorBinaryOp));
 68                            context.mapValue(invoke.returnResult(), binaryResult);
 69                             replaced.put(invoke.returnResult(), binaryResult);
 70                             return blockBuilder;
 71                         } else if (HATPhaseUtils.isBufferArray(invoke.op()) && invoke.resultFromFirstOperandOrNull() instanceof Op.Result result) { // ensures we can use iop as key for replaced vvv
 72                             replaced.put(invoke.returnResult(), result);
 73                             // map buffer VarOp to its corresponding VarLoadOp
 74                             bufferVarLoads.put((resultFromFirstOperandOrNull(result.op())).op(), (CoreOp.VarAccessOp.VarLoadOp) result.op());
 75                             return blockBuilder;
 76                         } else{
 77                             // we do get here.
 78                         }
 79                     }
 80                     case CoreOp.VarOp varOp -> {
 81                         if (HATPhaseUtils.isBufferInitialize(varOp) && resultFromFirstOperandOrThrow(varOp) instanceof Op.Result result) {
 82                             // makes sure we don't process a new int[] for example
 83                             Op bufferLoad = replaced.get(result).op(); // gets VarLoadOp associated w/ og buffer
 84                             replaced.put(varOp.result(), resultFromFirstOperandOrNull(bufferLoad)); // gets VarOp associated w/ og buffer
 85                             return blockBuilder;
 86                         } else if (HATPhaseUtils.isVectorOp(lookup(),varOp)) {
 87                             var hatVectorVarOp = new HATVectorOp.HATVectorVarOp(
 88                                     varOp.varName(),
 89                                     varOp.resultType(),
 90                                     getVectorShape(lookup(),varOp.resultType().valueType()),
 91                                    context.getValues(OpHelper.firstOperandAsListOrEmpty(varOp))
 92                             );
 93                             context.mapValue(varOp.result(), blockBuilder.op(copyLocation(varOp,hatVectorVarOp)));
 94                             return blockBuilder;
 95                         }else{
 96                             // we do get here.
 97                         }
 98                     }
 99                     case CoreOp.VarAccessOp.VarLoadOp varLoadOp -> {
100                         if ((HATPhaseUtils.isBufferInitialize(varLoadOp)) && resultFromFirstOperandOrThrow(varLoadOp) instanceof Op.Result r) {
101                             if (r.op() instanceof CoreOp.VarOp) { // if this is the VarLoadOp after the .arrayView() InvokeOp
102                                 Op.Result replacement = (HATPhaseUtils.isLocalSharedOrPrivate(varLoadOp)) ?
103                                         resultFromFirstOperandOrNull((resultFromFirstOperandOrNull(r.op())).op()) :
104                                         bufferVarLoads.get(replaced.get(r).op()).result();
105                                 replaced.put(varLoadOp.result(), replacement);
106                             } else { // if this is a VarLoadOp loading the buffer
107                                 CoreOp.VarAccessOp.VarLoadOp newVarLoad = CoreOp.VarAccessOp.varLoad(blockBuilder.context().getValue(replaced.get(r)));
108                                 Op.Result res = blockBuilder.op(copyLocation(varLoadOp,newVarLoad));
109                                 context.mapValue(varLoadOp.result(), res);
110                                 replaced.put(varLoadOp.result(), res);
111                             }
112                             return blockBuilder;
113                         }else{
114                            // we do get here
115                         }
116                     }
117                     case JavaOp.ArrayAccessOp.ArrayLoadOp arrayLoadOp -> {
118                         if (HATPhaseUtils.isBufferArray(arrayLoadOp) && resultFromFirstOperandOrNull(arrayLoadOp) instanceof Op.Result r) {
119                             Op.Result buffer = replaced.getOrDefault(r, r);
120                             if (HATPhaseUtils.isVectorOp(lookup(),arrayLoadOp)) {
121                                 Op vop = opFromFirstOperandOrThrow(buffer.op());
122                                 String name = switch (vop) {
123                                     case CoreOp.VarOp varOp -> varOp.varName();
124                                     case VarLikeOp varLikeOp -> varLikeOp.varName(); // HATMemoryVarOp.HATLocalVarOp &&  HATMemoryVarOp.HATPrivateVarOp
125                                     default -> throw new IllegalStateException("Unexpected value: " + vop);
126                                 };
127                                 var  vectorShape = getVectorShape(lookup(),arrayLoadOp.resultType());
128                                 HATVectorOp.HATVectorLoadOp vLoadOp = HATPhaseUtils.isLocalSharedOrPrivate(arrayLoadOp)
129                                         ? new HATVectorOp.HATVectorLoadOp.HATSharedVectorLoadOp(
130                                                 name,
131                                                 CoreType.varType(((ArrayType) OpHelper.firstOperandOrThrow(arrayLoadOp).type()).componentType()),
132                                                 vectorShape,
133                                                 context.getValues(List.of(buffer, arrayLoadOp.operands().getLast())))
134                                         :new HATVectorOp.HATVectorLoadOp.HATPrivateVectorLoadOp(
135                                                 name,
136                                                 CoreType.varType(((ArrayType) OpHelper.firstOperandOrThrow(arrayLoadOp).type()).componentType()),
137                                                 vectorShape,
138                                                 context.getValues(List.of(buffer, arrayLoadOp.operands().getLast()))
139                                         );
140                                 context.mapValue(arrayLoadOp.result(), blockBuilder.op(copyLocation(arrayLoadOp,vLoadOp)));
141                             } else if (OpHelper.firstOperandOrThrow(op).type() instanceof ArrayType arrayType && arrayType.dimensions() == 1) { // we only use the last array load
142                                 var arrayAccessInfo = HATPhaseUtils.arrayAccessInfo(op.result(), replaced);
143                                 var operands = arrayAccessInfo.bufferAndIndicesAsValues();
144                                 var hatPtrLoadOp = new HATPtrOp.HATPtrLoadOp(
145                                         arrayAccessInfo.bufferName(),
146                                         arrayLoadOp.resultType(),
147                                         (Class<?>) OpHelper.classTypeToTypeOrThrow(lookup(), (ClassType) arrayAccessInfo.buffer().type()),
148                                         context.getValues(operands)
149                                 );
150                                 context.mapValue(arrayLoadOp.result(), blockBuilder.op(copyLocation(arrayLoadOp,hatPtrLoadOp)));
151                             }else{
152                                 // or else
153                             }
154                         } else {
155                             // or else?
156                         }
157                         return blockBuilder;
158                     }
159                     case JavaOp.ArrayAccessOp.ArrayStoreOp arrayStoreOp -> {
160                         if (HATPhaseUtils.isBufferArray(arrayStoreOp) && resultFromFirstOperandOrThrow(arrayStoreOp) instanceof Op.Result r) {
161                             Op.Result buffer = replaced.getOrDefault(r, r);
162                             if (HATPhaseUtils.isVectorOp(lookup(),arrayStoreOp)) {
163                                 Op varOp =
164                                         HATPhaseUtils.findOpInResultFromFirstOperandsOrThrow(((Op.Result) arrayStoreOp.operands().getLast()).op(), CoreOp.VarOp.class, HATVectorOp.HATVectorVarOp.class);
165                                 var name = (varOp instanceof HATVectorOp.HATVectorVarOp)
166                                         ? ((HATVectorOp.HATVectorVarOp) varOp).varName()
167                                         : ((CoreOp.VarOp) varOp).varName();
168                                 var resultType = (varOp instanceof HATVectorOp.HATVectorVarOp)
169                                         ? (varOp).resultType()
170                                         : ((CoreOp.VarOp) varOp).resultType();
171                                 var classType = ((ClassType) ((ArrayType) OpHelper.firstOperandOrThrow(arrayStoreOp).type()).componentType());
172                                 var vectorShape = getVectorShape(lookup(),classType);
173                                 HATVectorOp.HATVectorStoreView vStoreOp = HATPhaseUtils.isLocalSharedOrPrivate(arrayStoreOp)
174                                         ? new HATVectorOp.HATVectorStoreView.HATSharedVectorStoreView(
175                                                 name,
176                                                 resultType,
177                                                 vectorShape,
178                                                 context.getValues(List.of(buffer, arrayStoreOp.operands().getLast(), arrayStoreOp.operands().get(1))))
179                                         : new HATVectorOp.HATVectorStoreView.HATPrivateVectorStoreView(
180                                                 name,
181                                                 resultType,
182                                                 vectorShape,
183                                                 context.getValues(List.of(buffer, arrayStoreOp.operands().getLast(), arrayStoreOp.operands().get(1)))
184                                         );
185                                 context.mapValue(arrayStoreOp.result(), blockBuilder.op(copyLocation(arrayStoreOp,vStoreOp)));
186                             } else if (((ArrayType) OpHelper.firstOperandOrThrow(op).type()).dimensions() == 1) { // we only use the last array load
187                                 var arrayAccessInfo = HATPhaseUtils.arrayAccessInfo(op.result(), replaced);
188                                 var operands = arrayAccessInfo.bufferAndIndicesAsValues();
189                                 operands.add(arrayStoreOp.operands().getLast());
190                                 HATPtrOp.HATPtrStoreOp ptrLoadOp = new HATPtrOp.HATPtrStoreOp(
191                                         arrayAccessInfo.bufferName(),
192                                         arrayStoreOp.resultType(),
193                                         (Class<?>) OpHelper.classTypeToTypeOrThrow(lookup(), (ClassType) arrayAccessInfo.buffer().type()),
194                                         context.getValues(operands)
195                                 );
196                                 context.mapValue(arrayStoreOp.result(), blockBuilder.op(copyLocation(arrayStoreOp,ptrLoadOp)));
197                             }else{
198                                 // or else
199                             }
200                         }else{
201                             // or else?
202                         }
203                         return blockBuilder;
204                     }
205                     case JavaOp.ArrayLengthOp arrayLengthOp  when
206                         HATPhaseUtils.isBufferArray(arrayLengthOp) && resultFromFirstOperandOrThrow(arrayLengthOp) != null ->{
207                             var arrayAccessInfo = HATPhaseUtils.arrayAccessInfo(op.result(), replaced);
208                             var hatPtrLengthOp = new HATPtrOp.HATPtrLengthOp(
209                                     arrayAccessInfo.bufferName(),
210                                     arrayLengthOp.resultType(),
211                                     (Class<?>) OpHelper.classTypeToTypeOrThrow(lookup(), (ClassType) arrayAccessInfo.buffer().type()),
212                                     context.getValues(List.of(arrayAccessInfo.buffer()))
213                             );
214                             context.mapValue(arrayLengthOp.result(), blockBuilder.op(copyLocation(arrayLengthOp,hatPtrLengthOp)));
215                             return blockBuilder;
216                     }
217                     default -> {
218                     }
219                 }
220                 blockBuilder.op(op);
221                 return blockBuilder;
222             }).funcOp();
223         }else {
224             return funcOp;
225         }
226     }
227 
228     record ArrayAccessInfo(Op.Result buffer, String bufferName, List<Op.Result> indices) {
229         public List<Value> bufferAndIndicesAsValues() {
230             List<Value> operands = new ArrayList<>(List.of(buffer));
231             operands.addAll(indices);
232             return operands;
233         }
234     };
235 
236     record Node<T>(T value, List<Node<T>> edges) {
237         ArrayAccessInfo getInfo(Map<Op.Result, Op.Result> replaced) {
238             List<Node<T>> nodeList = new ArrayList<>(List.of(this));
239             Set<Node<T>> handled = new HashSet<>();
240             Op.Result buffer = null;
241             List<Op.Result> indices = new ArrayList<>();
242             while (!nodeList.isEmpty()) {
243                 Node<T> node = nodeList.removeFirst();
244                 handled.add(node);
245                 if (node.value instanceof Op.Result res &&
246                         (res.op() instanceof JavaOp.ArrayAccessOp || res.op() instanceof JavaOp.ArrayLengthOp)) {
247                         buffer = res;
248                         // idx location differs between ArrayAccessOp and ArrayLengthOp
249                         indices.addFirst(res.op() instanceof JavaOp.ArrayAccessOp
250                                 ? resultFromOperandN(res.op(), 1)
251                                 : resultFromFirstOperandOrThrow(res.op()));
252                 }
253                 if (!node.edges().isEmpty()) {
254                     Node<T> next = node.edges().getFirst(); // we only traverse through the index-related ops
255                     if (!handled.contains(next)) {
256                         nodeList.add(next);
257                     }
258                 }
259             }
260             buffer = replaced.get(resultFromFirstOperandOrNull(buffer.op()));
261             String bufferName = hatPtrName(resultFromFirstOperandOrNull(buffer.op()).op());
262             return new ArrayAccessInfo(buffer, bufferName, indices);
263         }
264     }
265 
266     public static String hatPtrName(Op op) {
267         return switch (op) {
268             case CoreOp.VarOp varOp -> varOp.varName();
269             case HATMemoryVarOp.HATLocalVarOp hatLocalVarOp -> hatLocalVarOp.varName();
270             case HATMemoryVarOp.HATPrivateVarOp hatPrivateVarOp -> hatPrivateVarOp.varName();
271             case null, default -> "";
272         };
273     }
274 }