1 /*
2 * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25 package hat.phases;
26
27 import hat.callgraph.KernelCallGraph;
28 import hat.device.DeviceType;
29 import hat.dialect.*;
30 import optkl.OpHelper;
31 import optkl.Trxfmr;
32 import optkl.ifacemapper.MappableIface;
33 import jdk.incubator.code.Op;
34 import jdk.incubator.code.Value;
35 import jdk.incubator.code.dialect.core.CoreOp;
36 import jdk.incubator.code.dialect.core.CoreType;
37 import jdk.incubator.code.dialect.java.*;
38 import optkl.util.ops.VarLikeOp;
39
40 import java.util.*;
41
42 import static optkl.OpHelper.*;
43 import static optkl.OpHelper.Invoke.invoke;
44
45 public record HATArrayViewPhase(KernelCallGraph kernelCallGraph) implements HATPhase {
46
47 @Override
48 public CoreOp.FuncOp apply(CoreOp.FuncOp funcOp) {
49 if (Invoke.stream(lookup(), funcOp).anyMatch(invoke ->
50 invoke.returnsArray()
51 && invoke.refIs(MappableIface.class,DeviceType.class))) {
52 Map<Op.Result, Op.Result> replaced = new HashMap<>(); // maps a result to the result it should be replaced by
53 Map<Op, CoreOp.VarAccessOp.VarLoadOp> bufferVarLoads = new HashMap<>();
54
55 return Trxfmr.of(this,funcOp).transform( (blockBuilder, op) -> {
56 var context = blockBuilder.context();
57 switch (op) {
58 case JavaOp.InvokeOp $ when invoke(lookup(), $) instanceof Invoke invoke -> {
59 if (invoke.namedIgnoreCase("add","sub","mul","div")) {
60 // catching HATVectorBinaryOps not stored in VarOps
61 var hatVectorBinaryOp = invoke.copyLocationTo(HATPhaseUtils.buildVectorBinaryOp(
62 lookup(),
63 invoke.name(),
64 invoke.varOpFromFirstUseOrThrow().varName(),
65 invoke.returnType(),
66 blockBuilder.context().getValues(invoke.op().operands())
67 ));
68 Op.Result binaryResult = blockBuilder.op(hatVectorBinaryOp);
69 context.mapValue(invoke.returnResult(), binaryResult);
70 replaced.put(invoke.returnResult(), binaryResult);
71 return blockBuilder;
72 } else if (HATPhaseUtils.isBufferArray(invoke.op()) && invoke.resultFromFirstOperandOrNull() instanceof Op.Result result) { // ensures we can use iop as key for replaced vvv
73 replaced.put(invoke.returnResult(), result);
74 // map buffer VarOp to its corresponding VarLoadOp
75 bufferVarLoads.put((resultFromFirstOperandOrNull(result.op())).op(), (CoreOp.VarAccessOp.VarLoadOp) result.op());
76 return blockBuilder;
77 } else{
78 // we do get here.
79 }
80 }
81 case CoreOp.VarOp varOp -> {
82 if (HATPhaseUtils.isBufferInitialize(varOp) && resultFromFirstOperandOrThrow(varOp) instanceof Op.Result result) {
83 // makes sure we don't process a new int[] for example
84 Op bufferLoad = replaced.get(result).op(); // gets VarLoadOp associated w/ og buffer
85 replaced.put(varOp.result(), resultFromFirstOperandOrNull(bufferLoad)); // gets VarOp associated w/ og buffer
86 return blockBuilder;
87 } else if (HATPhaseUtils.isVectorOp(lookup(),varOp)) {
88 var vectorMetaData = HATPhaseUtils.getVectorTypeInfoWithCodeReflection(lookup(),varOp.resultType().valueType());
89 var hatVectorVarOp = copyLocation(varOp,new HATVectorOp.HATVectorVarOp(
90 varOp.varName(),
91 varOp.resultType(),
92 vectorMetaData.vectorTypeElement(),
93 vectorMetaData.lanes(),
94 context.getValues(OpHelper.firstOperandAsListOrEmpty(varOp))
95 ));
96 context.mapValue(varOp.result(), blockBuilder.op(hatVectorVarOp));
97 return blockBuilder;
98 }else{
99 // we do get here.
100 }
101 }
102 case CoreOp.VarAccessOp.VarLoadOp varLoadOp -> {
103 if ((HATPhaseUtils.isBufferInitialize(varLoadOp)) && resultFromFirstOperandOrThrow(varLoadOp) instanceof Op.Result r) {
104 if (r.op() instanceof CoreOp.VarOp) { // if this is the VarLoadOp after the .arrayView() InvokeOp
105 Op.Result replacement = (HATPhaseUtils.isLocalSharedOrPrivate(varLoadOp)) ?
106 resultFromFirstOperandOrNull((resultFromFirstOperandOrNull(r.op())).op()) :
107 bufferVarLoads.get(replaced.get(r).op()).result();
108 replaced.put(varLoadOp.result(), replacement);
109 } else { // if this is a VarLoadOp loading the buffer
110 // is this not just bb.op(varLoadOp)?
111 CoreOp.VarAccessOp.VarLoadOp newVarLoad = copyLocation(varLoadOp,
112 CoreOp.VarAccessOp.varLoad(
113 blockBuilder.context().getValue(replaced.get(r)))
114 );
115 Op.Result res = blockBuilder.op(newVarLoad);
116 context.mapValue(varLoadOp.result(), res);
117 replaced.put(varLoadOp.result(), res);
118 }
119 return blockBuilder;
120 }else{
121 // we do get here
122 }
123 }
124 case JavaOp.ArrayAccessOp.ArrayLoadOp arrayLoadOp -> {
125 if (HATPhaseUtils.isBufferArray(arrayLoadOp) && resultFromFirstOperandOrNull(arrayLoadOp) instanceof Op.Result r) {
126 Op.Result buffer = replaced.getOrDefault(r, r);
127 if (HATPhaseUtils.isVectorOp(lookup(),arrayLoadOp)) {
128 Op vop = opFromFirstOperandOrThrow(buffer.op());
129 String name = switch (vop) {
130 case CoreOp.VarOp varOp -> varOp.varName();
131 case VarLikeOp varLikeOp -> varLikeOp.varName(); // HATMemoryVarOp.HATLocalVarOp && HATMemoryVarOp.HATPrivateVarOp
132 default -> throw new IllegalStateException("Unexpected value: " + vop);
133 };
134 var hatVectorMetaData = HATPhaseUtils.getVectorTypeInfoWithCodeReflection(lookup(),arrayLoadOp.resultType());
135 HATVectorOp.HATVectorLoadOp vLoadOp = copyLocation(arrayLoadOp,new HATVectorOp.HATVectorLoadOp(
136 name,
137 CoreType.varType(((ArrayType) OpHelper.firstOperandOrThrow(arrayLoadOp).type()).componentType()),
138 hatVectorMetaData.vectorTypeElement(), // seems like we might pass the hatVectorMetaData here...?
139 hatVectorMetaData.lanes(),
140 HATPhaseUtils.isLocalSharedOrPrivate(arrayLoadOp),
141 context.getValues(List.of(buffer, arrayLoadOp.operands().getLast()))
142 ));
143 context.mapValue(arrayLoadOp.result(), blockBuilder.op(vLoadOp));
144 } else if (OpHelper.firstOperandOrThrow(op).type() instanceof ArrayType arrayType && arrayType.dimensions() == 1) { // we only use the last array load
145 var arrayAccessInfo = HATPhaseUtils.arrayAccessInfo(op.result(), replaced);
146 var operands = arrayAccessInfo.bufferAndIndicesAsValues();
147 var hatPtrLoadOp = copyLocation(arrayLoadOp,new HATPtrOp.HATPtrLoadOp(
148 arrayAccessInfo.bufferName(),
149 arrayLoadOp.resultType(),
150 (Class<?>) OpHelper.classTypeToTypeOrThrow(lookup(), (ClassType) arrayAccessInfo.buffer().type()),
151 context.getValues(operands)
152 ));
153 context.mapValue(arrayLoadOp.result(), blockBuilder.op(hatPtrLoadOp));
154 }else{
155 // or else
156 }
157 } else {
158 // or else?
159 }
160 return blockBuilder;
161 }
162 case JavaOp.ArrayAccessOp.ArrayStoreOp arrayStoreOp -> {
163 if (HATPhaseUtils.isBufferArray(arrayStoreOp) && resultFromFirstOperandOrThrow(arrayStoreOp) instanceof Op.Result r) {
164 Op.Result buffer = replaced.getOrDefault(r, r);
165 if (HATPhaseUtils.isVectorOp(lookup(),arrayStoreOp)) {
166 Op varOp =
167 HATPhaseUtils.findOpInResultFromFirstOperandsOrThrow(((Op.Result) arrayStoreOp.operands().getLast()).op(), CoreOp.VarOp.class, HATVectorOp.HATVectorVarOp.class);
168 var name = (varOp instanceof HATVectorOp.HATVectorVarOp)
169 ? ((HATVectorOp.HATVectorVarOp) varOp).varName()
170 : ((CoreOp.VarOp) varOp).varName();
171 var resultType = (varOp instanceof HATVectorOp.HATVectorVarOp)
172 ? (varOp).resultType()
173 : ((CoreOp.VarOp) varOp).resultType();
174 var classType = ((ClassType) ((ArrayType) OpHelper.firstOperandOrThrow(arrayStoreOp).type()).componentType());
175 var vectorMetaData = HATPhaseUtils.getVectorTypeInfoWithCodeReflection(lookup(),classType);
176 HATVectorOp.HATVectorStoreView vStoreOp = copyLocation(arrayStoreOp,new HATVectorOp.HATVectorStoreView(
177 name,
178 resultType,
179 vectorMetaData.lanes(),
180 vectorMetaData.vectorTypeElement(),
181 HATPhaseUtils.isLocalSharedOrPrivate(arrayStoreOp),
182 context.getValues(List.of(buffer, arrayStoreOp.operands().getLast(), arrayStoreOp.operands().get(1)))
183 ));
184 context.mapValue(arrayStoreOp.result(), blockBuilder.op(vStoreOp));
185 } else if (((ArrayType) OpHelper.firstOperandOrThrow(op).type()).dimensions() == 1) { // we only use the last array load
186 var arrayAccessInfo = HATPhaseUtils.arrayAccessInfo(op.result(), replaced);
187 var operands = arrayAccessInfo.bufferAndIndicesAsValues();
188 operands.add(arrayStoreOp.operands().getLast());
189 HATPtrOp.HATPtrStoreOp ptrLoadOp = copyLocation(arrayStoreOp,new HATPtrOp.HATPtrStoreOp(
190 arrayAccessInfo.bufferName(),
191 arrayStoreOp.resultType(),
192 (Class<?>) OpHelper.classTypeToTypeOrThrow(lookup(), (ClassType) arrayAccessInfo.buffer().type()),
193 context.getValues(operands)
194 ));
195 context.mapValue(arrayStoreOp.result(), blockBuilder.op(ptrLoadOp));
196 }else{
197 // or else
198 }
199 }else{
200 // or else?
201 }
202 return blockBuilder;
203 }
204 case JavaOp.ArrayLengthOp arrayLengthOp when
205 HATPhaseUtils.isBufferArray(arrayLengthOp) && resultFromFirstOperandOrThrow(arrayLengthOp) != null ->{
206 var arrayAccessInfo = HATPhaseUtils.arrayAccessInfo(op.result(), replaced);
207 var hatPtrLengthOp = copyLocation(arrayLengthOp,new HATPtrOp.HATPtrLengthOp(
208 arrayAccessInfo.bufferName(),
209 arrayLengthOp.resultType(),
210 (Class<?>) OpHelper.classTypeToTypeOrThrow(lookup(), (ClassType) arrayAccessInfo.buffer().type()),
211 context.getValues(List.of(arrayAccessInfo.buffer()))
212 ));
213 context.mapValue(arrayLengthOp.result(), blockBuilder.op(hatPtrLengthOp));
214 return blockBuilder;
215 }
216 default -> {
217 }
218 }
219 blockBuilder.op(op);
220 return blockBuilder;
221 }).funcOp();
222 }else {
223 return funcOp;
224 }
225 }
226
227 record ArrayAccessInfo(Op.Result buffer, String bufferName, List<Op.Result> indices) {
228 public List<Value> bufferAndIndicesAsValues() {
229 List<Value> operands = new ArrayList<>(List.of(buffer));
230 operands.addAll(indices);
231 return operands;
232 }
233 };
234
235 record Node<T>(T value, List<Node<T>> edges) {
236 ArrayAccessInfo getInfo(Map<Op.Result, Op.Result> replaced) {
237 List<Node<T>> nodeList = new ArrayList<>(List.of(this));
238 Set<Node<T>> handled = new HashSet<>();
239 Op.Result buffer = null;
240 List<Op.Result> indices = new ArrayList<>();
241 while (!nodeList.isEmpty()) {
242 Node<T> node = nodeList.removeFirst();
243 handled.add(node);
244 if (node.value instanceof Op.Result res &&
245 (res.op() instanceof JavaOp.ArrayAccessOp || res.op() instanceof JavaOp.ArrayLengthOp)) {
246 buffer = res;
247 // idx location differs between ArrayAccessOp and ArrayLengthOp
248 indices.addFirst(res.op() instanceof JavaOp.ArrayAccessOp
249 ? resultFromOperandN(res.op(), 1)
250 : resultFromFirstOperandOrThrow(res.op()));
251 }
252 if (!node.edges().isEmpty()) {
253 Node<T> next = node.edges().getFirst(); // we only traverse through the index-related ops
254 if (!handled.contains(next)) {
255 nodeList.add(next);
256 }
257 }
258 }
259 buffer = replaced.get(resultFromFirstOperandOrNull(buffer.op()));
260 String bufferName = hatPtrName(resultFromFirstOperandOrNull(buffer.op()).op());
261 return new ArrayAccessInfo(buffer, bufferName, indices);
262 }
263 }
264
265 public static String hatPtrName(Op op) {
266 return switch (op) {
267 case CoreOp.VarOp varOp -> varOp.varName();
268 case HATMemoryVarOp.HATLocalVarOp hatLocalVarOp -> hatLocalVarOp.varName();
269 case HATMemoryVarOp.HATPrivateVarOp hatPrivateVarOp -> hatPrivateVarOp.varName();
270 case null, default -> "";
271 };
272 }
273 }