1 /*
  2  * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package hat.phases;
 26 
 27 import hat.callgraph.KernelCallGraph;
 28 import hat.dialect.*;
 29 import optkl.OpHelper;
 30 import optkl.Trxfmr;
 31 import jdk.incubator.code.Op;
 32 import jdk.incubator.code.Value;
 33 import jdk.incubator.code.dialect.core.CoreOp;
 34 import jdk.incubator.code.dialect.core.CoreType;
 35 import jdk.incubator.code.dialect.java.*;
 36 import optkl.util.ops.VarLikeOp;
 37 
 38 import java.util.*;
 39 
 40 import static hat.phases.HATPhaseUtils.getVectorShape;
 41 import static optkl.OpHelper.*;
 42 import static optkl.OpHelper.Invoke.invoke;
 43 
 44 public record HATArrayViewPhase(KernelCallGraph kernelCallGraph) implements HATPhase {
 45 
 46     @Override
 47     public CoreOp.FuncOp apply(CoreOp.FuncOp funcOp) {
 48         if (Invoke.stream(lookup(), funcOp).noneMatch(
 49                 invoke -> HATPhaseUtils.isBufferArray(lookup(), invoke.op())
 50         )) return funcOp;
 51 
 52         funcOp = applyArrayView(funcOp);
 53 
 54         if (funcOp.elements().filter(e -> e instanceof CoreOp.VarOp).anyMatch(
 55                 e -> HATPhaseUtils.isVectorOp(lookup(), ((CoreOp.VarOp) e))
 56         )) funcOp = applyVectorView(funcOp);
 57         return funcOp;
 58     }
 59 
 60     public CoreOp.FuncOp applyVectorView(CoreOp.FuncOp funcOp) {
 61         return Trxfmr.of(this,funcOp).transform((blockBuilder, op) -> {
 62             var context = blockBuilder.context();
 63             switch (op) {
 64                 case JavaOp.InvokeOp $ when invoke(lookup(), $) instanceof Invoke invoke -> {
 65                     if (HATPhaseUtils.isVectorBinaryOp(invoke.lookup(), invoke)){
 66                         var hatVectorBinaryOp = HATPhaseUtils.buildVectorBinaryOp(
 67                                 invoke.varOpFromFirstUseOrThrow().varName(),
 68                                 invoke.name(),// so mul, sub etc
 69                                 getVectorShape(lookup(),invoke.returnType()),
 70                                 blockBuilder.context().getValues(invoke.op().operands())
 71                         );
 72                         context.mapValue(invoke.returnResult(), blockBuilder.op(copyLocation(invoke.op(),hatVectorBinaryOp)));
 73                         return blockBuilder;
 74                     }
 75                 }
 76                 case CoreOp.VarOp varOp -> {
 77                     if (HATPhaseUtils.isVectorOp(lookup(),varOp)) {
 78                         var hatVectorVarOp = new HATVectorOp.HATVectorVarOp(
 79                                 varOp.varName(),
 80                                 varOp.resultType(),
 81                                 getVectorShape(lookup(),varOp.resultType().valueType()),
 82                                 context.getValues(OpHelper.firstOperandAsListOrEmpty(varOp))
 83                         );
 84                         context.mapValue(varOp.result(), blockBuilder.op(copyLocation(varOp,hatVectorVarOp)));
 85                         return blockBuilder;
 86                     }
 87                 }
 88                 case JavaOp.ArrayAccessOp.ArrayLoadOp arrayLoadOp -> {
 89                     if (HATPhaseUtils.isVectorOp(lookup(),arrayLoadOp)) {
 90                         Op.Result buffer = resultFromFirstOperandOrNull(arrayLoadOp);
 91                         String name = hatPtrName(opFromFirstOperandOrThrow(buffer.op()));
 92                         var resultType = CoreType.varType(arrayLoadOp.resultType());
 93                         var vectorShape = getVectorShape(lookup(),arrayLoadOp.resultType());
 94                         List<Value> operands = context.getValues(List.of(buffer, arrayLoadOp.operands().getLast()));
 95                         HATVectorOp vLoadOp = HATPhaseUtils.buildArrayViewVector(arrayLoadOp, name, resultType, vectorShape, operands);
 96                         context.mapValue(arrayLoadOp.result(), blockBuilder.op(copyLocation(arrayLoadOp,vLoadOp)));
 97                     }
 98                     return blockBuilder;
 99                 }
100                 case JavaOp.ArrayAccessOp.ArrayStoreOp arrayStoreOp -> {
101                     if (HATPhaseUtils.isVectorOp(lookup(),arrayStoreOp)) {
102                         Op.Result buffer = resultFromFirstOperandOrThrow(arrayStoreOp);
103                         Op varOp = opFromFirstOperandOrNull(((Op.Result) arrayStoreOp.operands().getLast()).op());
104                         String name = hatPtrName(varOp);
105                         var resultType = (varOp).resultType();
106                         var vectorShape = getVectorShape(lookup(),arrayStoreOp.operands().getLast().type());
107                         List<Value> operands = context.getValues(List.of(buffer, arrayStoreOp.operands().getLast(), arrayStoreOp.operands().get(1)));
108                         HATVectorOp vStoreOp = HATPhaseUtils.buildArrayViewVector(arrayStoreOp, name, resultType, vectorShape, operands);
109                         context.mapValue(arrayStoreOp.result(), blockBuilder.op(copyLocation(arrayStoreOp,vStoreOp)));
110                     }
111                     return blockBuilder;
112                 }
113                 default -> {
114                 }
115             }
116             blockBuilder.op(op);
117             return blockBuilder;
118         }).funcOp();
119     }
120 
121     public CoreOp.FuncOp applyArrayView(CoreOp.FuncOp funcOp) {
122         Map<Op.Result, Op.Result> replaced = new HashMap<>(); // maps a result to the result it should be replaced by
123         Map<Op, CoreOp.VarAccessOp.VarLoadOp> bufferVarLoads = new HashMap<>();
124 
125         return Trxfmr.of(this,funcOp).transform((blockBuilder, op) -> {
126             var context = blockBuilder.context();
127             switch (op) {
128                 case JavaOp.InvokeOp $ when invoke(lookup(), $) instanceof Invoke invoke -> {
129                     if (HATPhaseUtils.isBufferArray(invoke.lookup(), invoke.op())) { // ensures we can use iop as key for replaced vvv
130                         Op.Result result = invoke.resultFromFirstOperandOrNull();
131                         replaced.put(invoke.returnResult(), result);
132                         // map buffer VarOp to its corresponding VarLoadOp
133                         bufferVarLoads.put((opFromFirstOperandOrNull(result.op())), (CoreOp.VarAccessOp.VarLoadOp) result.op());
134                         return blockBuilder;
135                     }
136                 }
137                 case CoreOp.VarOp varOp -> {
138                     if (HATPhaseUtils.isBufferInitialize(lookup(), varOp)) {
139                         // makes sure we don't process a new int[] for example
140                         Op bufferLoad = replaced.get(resultFromFirstOperandOrThrow(varOp)).op(); // gets VarLoadOp associated w/ og buffer
141                         replaced.put(varOp.result(), resultFromFirstOperandOrNull(bufferLoad)); // gets VarOp associated w/ og buffer
142                         return blockBuilder;
143                     }
144                 }
145                 case CoreOp.VarAccessOp.VarLoadOp varLoadOp -> {
146                     if ((HATPhaseUtils.isBufferInitialize(lookup(), varLoadOp))) {
147                         Op.Result r = resultFromFirstOperandOrThrow(varLoadOp);
148                         Op.Result replacement;
149                         if (r.op() instanceof CoreOp.VarOp) { // if this is the VarLoadOp after the .arrayView() InvokeOp
150                             replacement = (HATPhaseUtils.isLocalSharedOrPrivate(varLoadOp)) ?
151                                     resultFromFirstOperandOrNull(opFromFirstOperandOrThrow(r.op())) :
152                                     bufferVarLoads.get(replaced.get(r).op()).result();
153                         } else { // if this is a VarLoadOp loading the buffer
154                             CoreOp.VarAccessOp.VarLoadOp newVarLoad = CoreOp.VarAccessOp.varLoad(blockBuilder.context().getValue(replaced.get(r)));
155                             replacement = blockBuilder.op(copyLocation(varLoadOp,newVarLoad));
156                             context.mapValue(varLoadOp.result(), replacement);
157                         }
158                         replaced.put(varLoadOp.result(), replacement);
159                         return blockBuilder;
160                     }
161                 }
162                 case JavaOp.ArrayAccessOp.ArrayLoadOp arrayLoadOp -> {
163                     if (HATPhaseUtils.isBufferArray(lookup(), arrayLoadOp)) {
164                         Op replacementOp;
165                         if (HATPhaseUtils.isVectorOp(lookup(),arrayLoadOp)) {
166                             replacementOp = JavaOp.arrayLoadOp(
167                                     context.getValue(replaced.get((Op.Result) arrayLoadOp.operands().getFirst())),
168                                     context.getValue(arrayLoadOp.operands().getLast()),
169                                     arrayLoadOp.resultType()
170                             );
171                         } else if (((ArrayType) firstOperandOrThrow(op).type()).dimensions() == 1) {
172                             var arrayAccessInfo = HATPhaseUtils.arrayAccessInfo(op.result(), replaced);
173                             var operands = arrayAccessInfo.bufferAndIndicesAsValues();
174                             replacementOp = new HATPtrOp.HATPtrLoadOp(
175                                     arrayAccessInfo.bufferName(),
176                                     arrayLoadOp.resultType(),
177                                     (Class<?>) classTypeToTypeOrThrow(lookup(), (ClassType) arrayAccessInfo.buffer().type()),
178                                     context.getValues(operands)
179                             );
180                         } else { // we only use the last array load
181                             return blockBuilder;
182                         }
183                         context.mapValue(arrayLoadOp.result(), blockBuilder.op(copyLocation(arrayLoadOp,replacementOp)));
184                         return blockBuilder;
185                     }
186                 }
187                 case JavaOp.ArrayAccessOp.ArrayStoreOp arrayStoreOp -> {
188                     if (HATPhaseUtils.isBufferArray(lookup(), arrayStoreOp)) {
189                         Op replacementOp;
190                         if (HATPhaseUtils.isVectorOp(lookup(), arrayStoreOp)) {
191                             replacementOp = JavaOp.arrayStoreOp(
192                                     context.getValue(replaced.get((Op.Result) arrayStoreOp.operands().getFirst())),
193                                     context.getValue(arrayStoreOp.operands().get(1)),
194                                     context.getValue(arrayStoreOp.operands().getLast())
195                             );
196                         } else if (((ArrayType) firstOperandOrThrow(op).type()).dimensions() == 1) { // we only use the last array load
197                             var arrayAccessInfo = HATPhaseUtils.arrayAccessInfo(op.result(), replaced);
198                             var operands = arrayAccessInfo.bufferAndIndicesAsValues();
199                             operands.add(arrayStoreOp.operands().getLast());
200                             replacementOp = new HATPtrOp.HATPtrStoreOp(
201                                     arrayAccessInfo.bufferName(),
202                                     arrayStoreOp.resultType(),
203                                     (Class<?>) classTypeToTypeOrThrow(lookup(), (ClassType) arrayAccessInfo.buffer().type()),
204                                     context.getValues(operands)
205                             );
206                         } else {
207                             return blockBuilder;
208                         }
209                         context.mapValue(arrayStoreOp.result(), blockBuilder.op(copyLocation(arrayStoreOp, replacementOp)));
210                         return blockBuilder;
211                     }
212                 }
213                 case JavaOp.ArrayLengthOp arrayLengthOp  when
214                         HATPhaseUtils.isBufferArray(lookup(), arrayLengthOp) && resultFromFirstOperandOrThrow(arrayLengthOp) != null ->{
215                     var arrayAccessInfo = HATPhaseUtils.arrayAccessInfo(op.result(), replaced);
216                     var hatPtrLengthOp = new HATPtrOp.HATPtrLengthOp(
217                             arrayAccessInfo.bufferName(),
218                             arrayLengthOp.resultType(),
219                             (Class<?>) OpHelper.classTypeToTypeOrThrow(lookup(), (ClassType) arrayAccessInfo.buffer().type()),
220                             context.getValues(List.of(arrayAccessInfo.buffer()))
221                     );
222                     context.mapValue(arrayLengthOp.result(), blockBuilder.op(copyLocation(arrayLengthOp,hatPtrLengthOp)));
223                     return blockBuilder;
224                 }
225                 default -> {
226                 }
227             }
228             blockBuilder.op(op);
229             return blockBuilder;
230         }).funcOp();
231     }
232 
233     record ArrayAccessInfo(Op.Result buffer, String bufferName, List<Op.Result> indices) {
234         public List<Value> bufferAndIndicesAsValues() {
235             List<Value> operands = new ArrayList<>(List.of(buffer));
236             operands.addAll(indices);
237             return operands;
238         }
239     }
240 
241     record Node<T>(T value, List<Node<T>> edges) {
242         ArrayAccessInfo getInfo(Map<Op.Result, Op.Result> replaced) {
243             List<Node<T>> nodeList = new ArrayList<>(List.of(this));
244             Set<Node<T>> handled = new HashSet<>();
245             Op.Result buffer = null;
246             List<Op.Result> indices = new ArrayList<>();
247             while (!nodeList.isEmpty()) {
248                 Node<T> node = nodeList.removeFirst();
249                 handled.add(node);
250                 if (node.value instanceof Op.Result res &&
251                         (res.op() instanceof JavaOp.ArrayAccessOp || res.op() instanceof JavaOp.ArrayLengthOp)) {
252                     buffer = res;
253                     // idx location differs between ArrayAccessOp and ArrayLengthOp
254                     indices.addFirst(res.op() instanceof JavaOp.ArrayAccessOp
255                             ? resultFromOperandN(res.op(), 1)
256                             : resultFromFirstOperandOrThrow(res.op()));
257                 }
258                 if (!node.edges().isEmpty()) {
259                     Node<T> next = node.edges().getFirst(); // we only traverse through the index-related ops
260                     if (!handled.contains(next)) {
261                         nodeList.add(next);
262                     }
263                 }
264             }
265             if (buffer != null) {
266                 buffer = replaced.get(resultFromFirstOperandOrNull(buffer.op()));
267                 String bufferName = hatPtrName(opFromFirstOperandOrNull(buffer.op()));
268                 return new ArrayAccessInfo(buffer, bufferName, indices);
269             } else {
270                 return null;
271             }
272         }
273     }
274 
275     public static String hatPtrName(Op op) {
276         return switch (op) {
277             case CoreOp.VarOp varOp -> varOp.varName();
278             case VarLikeOp varLikeOp -> varLikeOp.varName();
279             case null, default -> "";
280         };
281     }
282 }