1 /*
2 * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25 package hat.phases;
26
27 import hat.callgraph.KernelCallGraph;
28 import hat.device.NonMappableIface;
29 import hat.dialect.*;
30 import optkl.OpHelper;
31 import optkl.Trxfmr;
32 import optkl.ifacemapper.MappableIface;
33 import jdk.incubator.code.Op;
34 import jdk.incubator.code.Value;
35 import jdk.incubator.code.dialect.core.CoreOp;
36 import jdk.incubator.code.dialect.core.CoreType;
37 import jdk.incubator.code.dialect.java.*;
38 import optkl.util.ops.VarLikeOp;
39
40 import java.util.*;
41
42 import static hat.phases.HATPhaseUtils.getVectorShape;
43 import static optkl.OpHelper.*;
44 import static optkl.OpHelper.Invoke.invoke;
45
46 public record HATArrayViewPhase(KernelCallGraph kernelCallGraph) implements HATPhase {
47
48 @Override
49 public CoreOp.FuncOp apply(CoreOp.FuncOp funcOp) {
50 if (Invoke.stream(lookup(), funcOp).anyMatch(invoke ->
51 invoke.returnsArray()
52 && invoke.refIs(MappableIface.class, NonMappableIface.class))) {
53 Map<Op.Result, Op.Result> replaced = new HashMap<>(); // maps a result to the result it should be replaced by
54 Map<Op, CoreOp.VarAccessOp.VarLoadOp> bufferVarLoads = new HashMap<>();
55
56 return Trxfmr.of(this,funcOp).transform( (blockBuilder, op) -> {
57 var context = blockBuilder.context();
58 switch (op) {
59 case JavaOp.InvokeOp $ when invoke(lookup(), $) instanceof Invoke invoke -> {
60 if (invoke.nameMatchesRegex("(add|sub|mul|div)")){
61 var hatVectorBinaryOp = HATPhaseUtils.buildVectorBinaryOp(
62 invoke.varOpFromFirstUseOrThrow().varName(),
63 invoke.name(),// so mul, sub etc
64 getVectorShape(lookup(),invoke.returnType()),
65 blockBuilder.context().getValues(invoke.op().operands())
66 );
67 Op.Result binaryResult = blockBuilder.op(copyLocation(invoke.op(),hatVectorBinaryOp));
68 context.mapValue(invoke.returnResult(), binaryResult);
69 replaced.put(invoke.returnResult(), binaryResult);
70 return blockBuilder;
71 } else if (HATPhaseUtils.isBufferArray(invoke.op()) && invoke.resultFromFirstOperandOrNull() instanceof Op.Result result) { // ensures we can use iop as key for replaced vvv
72 replaced.put(invoke.returnResult(), result);
73 // map buffer VarOp to its corresponding VarLoadOp
74 bufferVarLoads.put((resultFromFirstOperandOrNull(result.op())).op(), (CoreOp.VarAccessOp.VarLoadOp) result.op());
75 return blockBuilder;
76 } else{
77 // we do get here.
78 }
79 }
80 case CoreOp.VarOp varOp -> {
81 if (HATPhaseUtils.isBufferInitialize(varOp) && resultFromFirstOperandOrThrow(varOp) instanceof Op.Result result) {
82 // makes sure we don't process a new int[] for example
83 Op bufferLoad = replaced.get(result).op(); // gets VarLoadOp associated w/ og buffer
84 replaced.put(varOp.result(), resultFromFirstOperandOrNull(bufferLoad)); // gets VarOp associated w/ og buffer
85 return blockBuilder;
86 } else if (HATPhaseUtils.isVectorOp(lookup(),varOp)) {
87 var hatVectorVarOp = new HATVectorOp.HATVectorVarOp(
88 varOp.varName(),
89 varOp.resultType(),
90 getVectorShape(lookup(),varOp.resultType().valueType()),
91 context.getValues(OpHelper.firstOperandAsListOrEmpty(varOp))
92 );
93 context.mapValue(varOp.result(), blockBuilder.op(copyLocation(varOp,hatVectorVarOp)));
94 return blockBuilder;
95 }else{
96 // we do get here.
97 }
98 }
99 case CoreOp.VarAccessOp.VarLoadOp varLoadOp -> {
100 if ((HATPhaseUtils.isBufferInitialize(varLoadOp)) && resultFromFirstOperandOrThrow(varLoadOp) instanceof Op.Result r) {
101 if (r.op() instanceof CoreOp.VarOp) { // if this is the VarLoadOp after the .arrayView() InvokeOp
102 Op.Result replacement = (HATPhaseUtils.isLocalSharedOrPrivate(varLoadOp)) ?
103 resultFromFirstOperandOrNull((resultFromFirstOperandOrNull(r.op())).op()) :
104 bufferVarLoads.get(replaced.get(r).op()).result();
105 replaced.put(varLoadOp.result(), replacement);
106 } else { // if this is a VarLoadOp loading the buffer
107 CoreOp.VarAccessOp.VarLoadOp newVarLoad = CoreOp.VarAccessOp.varLoad(blockBuilder.context().getValue(replaced.get(r)));
108 Op.Result res = blockBuilder.op(copyLocation(varLoadOp,newVarLoad));
109 context.mapValue(varLoadOp.result(), res);
110 replaced.put(varLoadOp.result(), res);
111 }
112 return blockBuilder;
113 }else{
114 // we do get here
115 }
116 }
117 case JavaOp.ArrayAccessOp.ArrayLoadOp arrayLoadOp -> {
118 if (HATPhaseUtils.isBufferArray(arrayLoadOp) && resultFromFirstOperandOrNull(arrayLoadOp) instanceof Op.Result r) {
119 Op.Result buffer = replaced.getOrDefault(r, r);
120 if (HATPhaseUtils.isVectorOp(lookup(),arrayLoadOp)) {
121 Op vop = opFromFirstOperandOrThrow(buffer.op());
122 String name = switch (vop) {
123 case CoreOp.VarOp varOp -> varOp.varName();
124 case VarLikeOp varLikeOp -> varLikeOp.varName(); // HATMemoryVarOp.HATLocalVarOp && HATMemoryVarOp.HATPrivateVarOp
125 default -> throw new IllegalStateException("Unexpected value: " + vop);
126 };
127 var vectorShape = getVectorShape(lookup(),arrayLoadOp.resultType());
128 HATVectorOp.HATVectorLoadOp vLoadOp = HATPhaseUtils.isLocalSharedOrPrivate(arrayLoadOp)
129 ? new HATVectorOp.HATVectorLoadOp.HATSharedVectorLoadOp(
130 name,
131 CoreType.varType(((ArrayType) OpHelper.firstOperandOrThrow(arrayLoadOp).type()).componentType()),
132 vectorShape,
133 context.getValues(List.of(buffer, arrayLoadOp.operands().getLast())))
134 :new HATVectorOp.HATVectorLoadOp.HATPrivateVectorLoadOp(
135 name,
136 CoreType.varType(((ArrayType) OpHelper.firstOperandOrThrow(arrayLoadOp).type()).componentType()),
137 vectorShape,
138 context.getValues(List.of(buffer, arrayLoadOp.operands().getLast()))
139 );
140 context.mapValue(arrayLoadOp.result(), blockBuilder.op(copyLocation(arrayLoadOp,vLoadOp)));
141 } else if (OpHelper.firstOperandOrThrow(op).type() instanceof ArrayType arrayType && arrayType.dimensions() == 1) { // we only use the last array load
142 var arrayAccessInfo = HATPhaseUtils.arrayAccessInfo(op.result(), replaced);
143 var operands = arrayAccessInfo.bufferAndIndicesAsValues();
144 var hatPtrLoadOp = new HATPtrOp.HATPtrLoadOp(
145 arrayAccessInfo.bufferName(),
146 arrayLoadOp.resultType(),
147 (Class<?>) OpHelper.classTypeToTypeOrThrow(lookup(), (ClassType) arrayAccessInfo.buffer().type()),
148 context.getValues(operands)
149 );
150 context.mapValue(arrayLoadOp.result(), blockBuilder.op(copyLocation(arrayLoadOp,hatPtrLoadOp)));
151 }else{
152 // or else
153 }
154 } else {
155 // or else?
156 }
157 return blockBuilder;
158 }
159 case JavaOp.ArrayAccessOp.ArrayStoreOp arrayStoreOp -> {
160 if (HATPhaseUtils.isBufferArray(arrayStoreOp) && resultFromFirstOperandOrThrow(arrayStoreOp) instanceof Op.Result r) {
161 Op.Result buffer = replaced.getOrDefault(r, r);
162 if (HATPhaseUtils.isVectorOp(lookup(),arrayStoreOp)) {
163 Op varOp =
164 HATPhaseUtils.findOpInResultFromFirstOperandsOrThrow(((Op.Result) arrayStoreOp.operands().getLast()).op(), CoreOp.VarOp.class, HATVectorOp.HATVectorVarOp.class);
165 var name = (varOp instanceof HATVectorOp.HATVectorVarOp)
166 ? ((HATVectorOp.HATVectorVarOp) varOp).varName()
167 : ((CoreOp.VarOp) varOp).varName();
168 var resultType = (varOp instanceof HATVectorOp.HATVectorVarOp)
169 ? (varOp).resultType()
170 : ((CoreOp.VarOp) varOp).resultType();
171 var classType = ((ClassType) ((ArrayType) OpHelper.firstOperandOrThrow(arrayStoreOp).type()).componentType());
172 var vectorShape = getVectorShape(lookup(),classType);
173 HATVectorOp.HATVectorStoreView vStoreOp = HATPhaseUtils.isLocalSharedOrPrivate(arrayStoreOp)
174 ? new HATVectorOp.HATVectorStoreView.HATSharedVectorStoreView(
175 name,
176 resultType,
177 vectorShape,
178 context.getValues(List.of(buffer, arrayStoreOp.operands().getLast(), arrayStoreOp.operands().get(1))))
179 : new HATVectorOp.HATVectorStoreView.HATPrivateVectorStoreView(
180 name,
181 resultType,
182 vectorShape,
183 context.getValues(List.of(buffer, arrayStoreOp.operands().getLast(), arrayStoreOp.operands().get(1)))
184 );
185 context.mapValue(arrayStoreOp.result(), blockBuilder.op(copyLocation(arrayStoreOp,vStoreOp)));
186 } else if (((ArrayType) OpHelper.firstOperandOrThrow(op).type()).dimensions() == 1) { // we only use the last array load
187 var arrayAccessInfo = HATPhaseUtils.arrayAccessInfo(op.result(), replaced);
188 var operands = arrayAccessInfo.bufferAndIndicesAsValues();
189 operands.add(arrayStoreOp.operands().getLast());
190 HATPtrOp.HATPtrStoreOp ptrLoadOp = new HATPtrOp.HATPtrStoreOp(
191 arrayAccessInfo.bufferName(),
192 arrayStoreOp.resultType(),
193 (Class<?>) OpHelper.classTypeToTypeOrThrow(lookup(), (ClassType) arrayAccessInfo.buffer().type()),
194 context.getValues(operands)
195 );
196 context.mapValue(arrayStoreOp.result(), blockBuilder.op(copyLocation(arrayStoreOp,ptrLoadOp)));
197 }else{
198 // or else
199 }
200 }else{
201 // or else?
202 }
203 return blockBuilder;
204 }
205 case JavaOp.ArrayLengthOp arrayLengthOp when
206 HATPhaseUtils.isBufferArray(arrayLengthOp) && resultFromFirstOperandOrThrow(arrayLengthOp) != null ->{
207 var arrayAccessInfo = HATPhaseUtils.arrayAccessInfo(op.result(), replaced);
208 var hatPtrLengthOp = new HATPtrOp.HATPtrLengthOp(
209 arrayAccessInfo.bufferName(),
210 arrayLengthOp.resultType(),
211 (Class<?>) OpHelper.classTypeToTypeOrThrow(lookup(), (ClassType) arrayAccessInfo.buffer().type()),
212 context.getValues(List.of(arrayAccessInfo.buffer()))
213 );
214 context.mapValue(arrayLengthOp.result(), blockBuilder.op(copyLocation(arrayLengthOp,hatPtrLengthOp)));
215 return blockBuilder;
216 }
217 default -> {
218 }
219 }
220 blockBuilder.op(op);
221 return blockBuilder;
222 }).funcOp();
223 }else {
224 return funcOp;
225 }
226 }
227
228 record ArrayAccessInfo(Op.Result buffer, String bufferName, List<Op.Result> indices) {
229 public List<Value> bufferAndIndicesAsValues() {
230 List<Value> operands = new ArrayList<>(List.of(buffer));
231 operands.addAll(indices);
232 return operands;
233 }
234 };
235
236 record Node<T>(T value, List<Node<T>> edges) {
237 ArrayAccessInfo getInfo(Map<Op.Result, Op.Result> replaced) {
238 List<Node<T>> nodeList = new ArrayList<>(List.of(this));
239 Set<Node<T>> handled = new HashSet<>();
240 Op.Result buffer = null;
241 List<Op.Result> indices = new ArrayList<>();
242 while (!nodeList.isEmpty()) {
243 Node<T> node = nodeList.removeFirst();
244 handled.add(node);
245 if (node.value instanceof Op.Result res &&
246 (res.op() instanceof JavaOp.ArrayAccessOp || res.op() instanceof JavaOp.ArrayLengthOp)) {
247 buffer = res;
248 // idx location differs between ArrayAccessOp and ArrayLengthOp
249 indices.addFirst(res.op() instanceof JavaOp.ArrayAccessOp
250 ? resultFromOperandN(res.op(), 1)
251 : resultFromFirstOperandOrThrow(res.op()));
252 }
253 if (!node.edges().isEmpty()) {
254 Node<T> next = node.edges().getFirst(); // we only traverse through the index-related ops
255 if (!handled.contains(next)) {
256 nodeList.add(next);
257 }
258 }
259 }
260 buffer = replaced.get(resultFromFirstOperandOrNull(buffer.op()));
261 String bufferName = hatPtrName(resultFromFirstOperandOrNull(buffer.op()).op());
262 return new ArrayAccessInfo(buffer, bufferName, indices);
263 }
264 }
265
266 public static String hatPtrName(Op op) {
267 return switch (op) {
268 case CoreOp.VarOp varOp -> varOp.varName();
269 case HATMemoryVarOp.HATLocalVarOp hatLocalVarOp -> hatLocalVarOp.varName();
270 case HATMemoryVarOp.HATPrivateVarOp hatPrivateVarOp -> hatPrivateVarOp.varName();
271 case null, default -> "";
272 };
273 }
274 }