1 /*
2 * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25 package hat.phases;
26
27 import hat.callgraph.KernelCallGraph;
28 import hat.dialect.*;
29 import optkl.OpHelper;
30 import optkl.Trxfmr;
31 import jdk.incubator.code.Op;
32 import jdk.incubator.code.Value;
33 import jdk.incubator.code.dialect.core.CoreOp;
34 import jdk.incubator.code.dialect.core.CoreType;
35 import jdk.incubator.code.dialect.java.*;
36 import optkl.util.ops.VarLikeOp;
37
38 import java.util.*;
39
40 import static hat.phases.HATPhaseUtils.getVectorShape;
41 import static optkl.OpHelper.*;
42 import static optkl.OpHelper.Invoke.invoke;
43
44 public record HATArrayViewPhase(KernelCallGraph kernelCallGraph) implements HATPhase {
45
46 @Override
47 public CoreOp.FuncOp apply(CoreOp.FuncOp funcOp) {
48 if (Invoke.stream(lookup(), funcOp).noneMatch(
49 invoke -> HATPhaseUtils.isBufferArray(lookup(), invoke.op())
50 )) return funcOp;
51
52 funcOp = applyArrayView(funcOp);
53
54 if (funcOp.elements().filter(e -> e instanceof CoreOp.VarOp).anyMatch(
55 e -> HATPhaseUtils.isVectorOp(lookup(), ((CoreOp.VarOp) e))
56 )) funcOp = applyVectorView(funcOp);
57 return funcOp;
58 }
59
60 public CoreOp.FuncOp applyVectorView(CoreOp.FuncOp funcOp) {
61 return Trxfmr.of(this,funcOp).transform((blockBuilder, op) -> {
62 var context = blockBuilder.context();
63 switch (op) {
64 case JavaOp.InvokeOp $ when invoke(lookup(), $) instanceof Invoke invoke -> {
65 if (HATPhaseUtils.isVectorBinaryOp(invoke.lookup(), invoke)){
66 var hatVectorBinaryOp = HATPhaseUtils.buildVectorBinaryOp(
67 invoke.varOpFromFirstUseOrThrow().varName(),
68 invoke.name(),// so mul, sub etc
69 getVectorShape(lookup(),invoke.returnType()),
70 blockBuilder.context().getValues(invoke.op().operands())
71 );
72 context.mapValue(invoke.returnResult(), blockBuilder.op(copyLocation(invoke.op(),hatVectorBinaryOp)));
73 return blockBuilder;
74 }
75 }
76 case CoreOp.VarOp varOp -> {
77 if (HATPhaseUtils.isVectorOp(lookup(),varOp)) {
78 var hatVectorVarOp = new HATVectorOp.HATVectorVarOp(
79 varOp.varName(),
80 varOp.resultType(),
81 getVectorShape(lookup(),varOp.resultType().valueType()),
82 context.getValues(OpHelper.firstOperandAsListOrEmpty(varOp))
83 );
84 context.mapValue(varOp.result(), blockBuilder.op(copyLocation(varOp,hatVectorVarOp)));
85 return blockBuilder;
86 }
87 }
88 case JavaOp.ArrayAccessOp.ArrayLoadOp arrayLoadOp -> {
89 if (HATPhaseUtils.isVectorOp(lookup(),arrayLoadOp)) {
90 Op.Result buffer = resultFromFirstOperandOrNull(arrayLoadOp);
91 String name = hatPtrName(opFromFirstOperandOrThrow(buffer.op()));
92 var resultType = CoreType.varType(arrayLoadOp.resultType());
93 var vectorShape = getVectorShape(lookup(),arrayLoadOp.resultType());
94 List<Value> operands = context.getValues(List.of(buffer, arrayLoadOp.operands().getLast()));
95 HATVectorOp vLoadOp = HATPhaseUtils.buildArrayViewVector(arrayLoadOp, name, resultType, vectorShape, operands);
96 context.mapValue(arrayLoadOp.result(), blockBuilder.op(copyLocation(arrayLoadOp,vLoadOp)));
97 }
98 return blockBuilder;
99 }
100 case JavaOp.ArrayAccessOp.ArrayStoreOp arrayStoreOp -> {
101 if (HATPhaseUtils.isVectorOp(lookup(),arrayStoreOp)) {
102 Op.Result buffer = resultFromFirstOperandOrThrow(arrayStoreOp);
103 Op varOp = opFromFirstOperandOrNull(((Op.Result) arrayStoreOp.operands().getLast()).op());
104 String name = hatPtrName(varOp);
105 var resultType = (varOp).resultType();
106 var vectorShape = getVectorShape(lookup(),arrayStoreOp.operands().getLast().type());
107 List<Value> operands = context.getValues(List.of(buffer, arrayStoreOp.operands().getLast(), arrayStoreOp.operands().get(1)));
108 HATVectorOp vStoreOp = HATPhaseUtils.buildArrayViewVector(arrayStoreOp, name, resultType, vectorShape, operands);
109 context.mapValue(arrayStoreOp.result(), blockBuilder.op(copyLocation(arrayStoreOp,vStoreOp)));
110 }
111 return blockBuilder;
112 }
113 default -> {
114 }
115 }
116 blockBuilder.op(op);
117 return blockBuilder;
118 }).funcOp();
119 }
120
121 public CoreOp.FuncOp applyArrayView(CoreOp.FuncOp funcOp) {
122 Map<Op.Result, Op.Result> replaced = new HashMap<>(); // maps a result to the result it should be replaced by
123 Map<Op, CoreOp.VarAccessOp.VarLoadOp> bufferVarLoads = new HashMap<>();
124
125 return Trxfmr.of(this,funcOp).transform((blockBuilder, op) -> {
126 var context = blockBuilder.context();
127 switch (op) {
128 case JavaOp.InvokeOp $ when invoke(lookup(), $) instanceof Invoke invoke -> {
129 if (HATPhaseUtils.isBufferArray(invoke.lookup(), invoke.op())) { // ensures we can use iop as key for replaced vvv
130 Op.Result result = invoke.resultFromFirstOperandOrNull();
131 replaced.put(invoke.returnResult(), result);
132 // map buffer VarOp to its corresponding VarLoadOp
133 bufferVarLoads.put((opFromFirstOperandOrNull(result.op())), (CoreOp.VarAccessOp.VarLoadOp) result.op());
134 return blockBuilder;
135 }
136 }
137 case CoreOp.VarOp varOp -> {
138 if (HATPhaseUtils.isBufferInitialize(lookup(), varOp)) {
139 // makes sure we don't process a new int[] for example
140 Op bufferLoad = replaced.get(resultFromFirstOperandOrThrow(varOp)).op(); // gets VarLoadOp associated w/ og buffer
141 replaced.put(varOp.result(), resultFromFirstOperandOrNull(bufferLoad)); // gets VarOp associated w/ og buffer
142 return blockBuilder;
143 }
144 }
145 case CoreOp.VarAccessOp.VarLoadOp varLoadOp -> {
146 if ((HATPhaseUtils.isBufferInitialize(lookup(), varLoadOp))) {
147 Op.Result r = resultFromFirstOperandOrThrow(varLoadOp);
148 Op.Result replacement;
149 if (r.op() instanceof CoreOp.VarOp) { // if this is the VarLoadOp after the .arrayView() InvokeOp
150 replacement = (HATPhaseUtils.isLocalSharedOrPrivate(varLoadOp)) ?
151 resultFromFirstOperandOrNull(opFromFirstOperandOrThrow(r.op())) :
152 bufferVarLoads.get(replaced.get(r).op()).result();
153 } else { // if this is a VarLoadOp loading the buffer
154 CoreOp.VarAccessOp.VarLoadOp newVarLoad = CoreOp.VarAccessOp.varLoad(blockBuilder.context().getValue(replaced.get(r)));
155 replacement = blockBuilder.op(copyLocation(varLoadOp,newVarLoad));
156 context.mapValue(varLoadOp.result(), replacement);
157 }
158 replaced.put(varLoadOp.result(), replacement);
159 return blockBuilder;
160 }
161 }
162 case JavaOp.ArrayAccessOp.ArrayLoadOp arrayLoadOp -> {
163 if (HATPhaseUtils.isBufferArray(lookup(), arrayLoadOp)) {
164 Op replacementOp;
165 if (HATPhaseUtils.isVectorOp(lookup(),arrayLoadOp)) {
166 replacementOp = JavaOp.arrayLoadOp(
167 context.getValue(replaced.get((Op.Result) arrayLoadOp.operands().getFirst())),
168 context.getValue(arrayLoadOp.operands().getLast()),
169 arrayLoadOp.resultType()
170 );
171 } else if (((ArrayType) firstOperandOrThrow(op).type()).dimensions() == 1) {
172 var arrayAccessInfo = HATPhaseUtils.arrayAccessInfo(op.result(), replaced);
173 var operands = arrayAccessInfo.bufferAndIndicesAsValues();
174 replacementOp = new HATPtrOp.HATPtrLoadOp(
175 arrayAccessInfo.bufferName(),
176 arrayLoadOp.resultType(),
177 (Class<?>) classTypeToTypeOrThrow(lookup(), (ClassType) arrayAccessInfo.buffer().type()),
178 context.getValues(operands)
179 );
180 } else { // we only use the last array load
181 return blockBuilder;
182 }
183 context.mapValue(arrayLoadOp.result(), blockBuilder.op(copyLocation(arrayLoadOp,replacementOp)));
184 return blockBuilder;
185 }
186 }
187 case JavaOp.ArrayAccessOp.ArrayStoreOp arrayStoreOp -> {
188 if (HATPhaseUtils.isBufferArray(lookup(), arrayStoreOp)) {
189 Op replacementOp;
190 if (HATPhaseUtils.isVectorOp(lookup(), arrayStoreOp)) {
191 replacementOp = JavaOp.arrayStoreOp(
192 context.getValue(replaced.get((Op.Result) arrayStoreOp.operands().getFirst())),
193 context.getValue(arrayStoreOp.operands().get(1)),
194 context.getValue(arrayStoreOp.operands().getLast())
195 );
196 } else if (((ArrayType) firstOperandOrThrow(op).type()).dimensions() == 1) { // we only use the last array load
197 var arrayAccessInfo = HATPhaseUtils.arrayAccessInfo(op.result(), replaced);
198 var operands = arrayAccessInfo.bufferAndIndicesAsValues();
199 operands.add(arrayStoreOp.operands().getLast());
200 replacementOp = new HATPtrOp.HATPtrStoreOp(
201 arrayAccessInfo.bufferName(),
202 arrayStoreOp.resultType(),
203 (Class<?>) classTypeToTypeOrThrow(lookup(), (ClassType) arrayAccessInfo.buffer().type()),
204 context.getValues(operands)
205 );
206 } else {
207 return blockBuilder;
208 }
209 context.mapValue(arrayStoreOp.result(), blockBuilder.op(copyLocation(arrayStoreOp, replacementOp)));
210 return blockBuilder;
211 }
212 }
213 case JavaOp.ArrayLengthOp arrayLengthOp when
214 HATPhaseUtils.isBufferArray(lookup(), arrayLengthOp) && resultFromFirstOperandOrThrow(arrayLengthOp) != null ->{
215 var arrayAccessInfo = HATPhaseUtils.arrayAccessInfo(op.result(), replaced);
216 var hatPtrLengthOp = new HATPtrOp.HATPtrLengthOp(
217 arrayAccessInfo.bufferName(),
218 arrayLengthOp.resultType(),
219 (Class<?>) OpHelper.classTypeToTypeOrThrow(lookup(), (ClassType) arrayAccessInfo.buffer().type()),
220 context.getValues(List.of(arrayAccessInfo.buffer()))
221 );
222 context.mapValue(arrayLengthOp.result(), blockBuilder.op(copyLocation(arrayLengthOp,hatPtrLengthOp)));
223 return blockBuilder;
224 }
225 default -> {
226 }
227 }
228 blockBuilder.op(op);
229 return blockBuilder;
230 }).funcOp();
231 }
232
233 record ArrayAccessInfo(Op.Result buffer, String bufferName, List<Op.Result> indices) {
234 public List<Value> bufferAndIndicesAsValues() {
235 List<Value> operands = new ArrayList<>(List.of(buffer));
236 operands.addAll(indices);
237 return operands;
238 }
239 }
240
241 record Node<T>(T value, List<Node<T>> edges) {
242 ArrayAccessInfo getInfo(Map<Op.Result, Op.Result> replaced) {
243 List<Node<T>> nodeList = new ArrayList<>(List.of(this));
244 Set<Node<T>> handled = new HashSet<>();
245 Op.Result buffer = null;
246 List<Op.Result> indices = new ArrayList<>();
247 while (!nodeList.isEmpty()) {
248 Node<T> node = nodeList.removeFirst();
249 handled.add(node);
250 if (node.value instanceof Op.Result res &&
251 (res.op() instanceof JavaOp.ArrayAccessOp || res.op() instanceof JavaOp.ArrayLengthOp)) {
252 buffer = res;
253 // idx location differs between ArrayAccessOp and ArrayLengthOp
254 indices.addFirst(res.op() instanceof JavaOp.ArrayAccessOp
255 ? resultFromOperandN(res.op(), 1)
256 : resultFromFirstOperandOrThrow(res.op()));
257 }
258 if (!node.edges().isEmpty()) {
259 Node<T> next = node.edges().getFirst(); // we only traverse through the index-related ops
260 if (!handled.contains(next)) {
261 nodeList.add(next);
262 }
263 }
264 }
265 if (buffer != null) {
266 buffer = replaced.get(resultFromFirstOperandOrNull(buffer.op()));
267 String bufferName = hatPtrName(opFromFirstOperandOrNull(buffer.op()));
268 return new ArrayAccessInfo(buffer, bufferName, indices);
269 } else {
270 return null;
271 }
272 }
273 }
274
275 public static String hatPtrName(Op op) {
276 return switch (op) {
277 case CoreOp.VarOp varOp -> varOp.varName();
278 case VarLikeOp varLikeOp -> varLikeOp.varName();
279 case null, default -> "";
280 };
281 }
282 }