1 /*
2 * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25 package hat.phases;
26
27 import hat.dialect.*;
28 import jdk.incubator.code.CodeType;
29 import optkl.IfaceValue;
30 import optkl.OpHelper;
31 import optkl.Trxfmr;
32 import jdk.incubator.code.Op;
33 import jdk.incubator.code.Value;
34 import jdk.incubator.code.dialect.core.CoreOp;
35 import jdk.incubator.code.dialect.core.CoreType;
36 import jdk.incubator.code.dialect.java.*;
37 import optkl.util.ops.VarLikeOp;
38
39 import java.lang.invoke.MethodHandles;
40 import java.util.*;
41
42 import static hat.phases.HATPhaseUtils.findOpInResultFromFirstOperandsOrNull;
43 import static optkl.IfaceValue.Vector.getVectorShape;
44 import static optkl.OpHelper.*;
45 import static optkl.OpHelper.Invoke.invoke;
46
47 public record HATArrayViewPhase() implements HATPhase {
48 static public boolean isVectorOp(MethodHandles.Lookup lookup, Op op) {
49 if (!op.operands().isEmpty()) {
50 CodeType type = switch(op) {
51 case JavaOp.ArrayAccessOp.ArrayLoadOp load -> load.resultType();
52 case JavaOp.ArrayAccessOp.ArrayStoreOp store -> store.operands().getLast().type();
53 default -> OpHelper.firstOperandOrThrow(op).type();
54 };
55 if (type instanceof ArrayType at) {
56 type = at.componentType();
57 }
58 if (type instanceof ClassType ct) {
59 try {
60 return IfaceValue.Vector.class.isAssignableFrom((Class<?>) ct.resolve(lookup));
61 } catch (ReflectiveOperationException e) {
62 throw new RuntimeException(e);
63 }
64 }
65 }
66 return false;
67 }
68
69 static public boolean isVectorBinaryOp(MethodHandles.Lookup lookup, OpHelper.Invoke invoke) {
70 return isVectorOp(lookup, invoke.op()) && invoke.nameMatchesRegex("(add|sub|mul|div)");
71 }
72
73
74 static HATVectorOp.HATVectorBinaryOp buildVectorBinaryOp(String varName, String opType, IfaceValue.Vector.Shape vectorShape, List<Value> outputOperands) {
75 return switch (opType) {
76 case "add" -> new HATVectorOp.HATVectorBinaryOp.HATVectorAddOp(varName, vectorShape, outputOperands);
77 case "sub" -> new HATVectorOp.HATVectorBinaryOp.HATVectorSubOp(varName, vectorShape, outputOperands);
78 case "mul" -> new HATVectorOp.HATVectorBinaryOp.HATVectorMulOp(varName, vectorShape, outputOperands);
79 case "div" -> new HATVectorOp.HATVectorBinaryOp.HATVectorDivOp(varName, vectorShape, outputOperands);
80 default -> throw new IllegalStateException("Unexpected value: " + opType);
81 };
82 }
83 static public boolean isBufferArray(MethodHandles.Lookup lookup, Op op) {
84 JavaOp.InvokeOp iop = (JavaOp.InvokeOp) findOpInResultFromFirstOperandsOrNull(op, JavaOp.InvokeOp.class);
85 return iop != null && iop.invokeReference().name().toLowerCase().contains("arrayview"); // we need a better way
86 }
87
88 static public boolean isBufferInitialize(MethodHandles.Lookup lookup, Op op) {
89 // first check if the return is an array type
90 if (op instanceof CoreOp.VarOp vop && vop.varValueType() instanceof ArrayType
91 || op instanceof JavaOp.ArrayAccessOp
92 || op.resultType() instanceof ArrayType) return isBufferArray(lookup, op);
93 return false;
94 }
95 static public boolean isLocalSharedOrPrivate(Op op) {
96 JavaOp.InvokeOp iop = (JavaOp.InvokeOp) findOpInResultFromFirstOperandsOrNull(op, JavaOp.InvokeOp.class);
97 return iop != null
98 && (iop.invokeReference().name().toLowerCase().contains("shared")
99 || iop.invokeReference().name().toLowerCase().contains("local")
100 || iop.invokeReference().name().toLowerCase().contains("private")
101 );
102 }
103
104 static public HATVectorOp buildArrayViewVector(Op op, String name, CodeType resultType, IfaceValue.Vector.Shape vectorShape, List<Value> operands) {
105 if (isLocalSharedOrPrivate(op)) {
106 if (op instanceof JavaOp.ArrayAccessOp.ArrayLoadOp) {
107 return new HATVectorOp.HATVectorLoadOp.HATSharedVectorLoadOp(name, resultType, vectorShape, operands);
108 }
109 return new HATVectorOp.HATVectorStoreView.HATSharedVectorStoreView(name, resultType, vectorShape, operands);
110 } else {
111 if (op instanceof JavaOp.ArrayAccessOp.ArrayLoadOp) {
112 return new HATVectorOp.HATVectorLoadOp.HATPrivateVectorLoadOp(name, resultType, vectorShape, operands);
113 }
114 return new HATVectorOp.HATVectorStoreView.HATPrivateVectorStoreView(name, resultType, vectorShape, operands);
115 }
116 }
117
118
119
120 static HATArrayViewPhase.ArrayAccessInfo arrayAccessInfo(Value value, Map<Op.Result, Op.Result> replaced) {
121 return expressionGraph(value).getInfo(replaced);
122 }
123
124 static HATArrayViewPhase.Node<Value> expressionGraph(Value value) {
125 return expressionGraph(new HashMap<>(), value);
126 }
127
128 static HATArrayViewPhase.Node<Value> expressionGraph(Map<Value, HATArrayViewPhase.Node<Value>> visited, Value value) {
129 // If value has already been visited return its node
130 if (visited.containsKey(value)) {
131 return visited.get(value);
132 }
133
134 // Find the expression graphs for each operand
135 List<HATArrayViewPhase.Node<Value>> edges = new ArrayList<>();
136
137 // looks like
138 for (Value operand : value.dependsOn()) {
139 if (operand instanceof Op.Result res &&
140 res.op() instanceof JavaOp.InvokeOp iop
141 && iop.invokeReference().name().toLowerCase().contains("arrayview")){ // We need to find a better way
142 continue;
143 }
144 edges.add(expressionGraph(operand));
145 }
146 HATArrayViewPhase.Node<Value> node = new HATArrayViewPhase.Node<>(value, edges);
147 visited.put(value, node);
148 return node;
149 }
150 @Override
151 public CoreOp.FuncOp transform(MethodHandles.Lookup lookup,CoreOp.FuncOp funcOp) {
152 if (Invoke.stream(lookup, funcOp).noneMatch(
153 invoke -> isBufferArray(lookup, invoke.op())
154 )) return funcOp;
155
156 funcOp = applyArrayView(lookup,funcOp);
157
158 if (funcOp.elements().filter(e -> e instanceof CoreOp.VarOp).anyMatch(
159 e -> isVectorOp(lookup, ((CoreOp.VarOp) e))
160 )) funcOp = applyVectorView(lookup,funcOp);
161 return funcOp;
162 }
163
164 public CoreOp.FuncOp applyVectorView(MethodHandles.Lookup lookup,CoreOp.FuncOp funcOp) {
165 return Trxfmr.of(lookup,funcOp).transform((blockBuilder, op) -> {
166 var context = blockBuilder.context();
167 switch (op) {
168 case JavaOp.InvokeOp $ when invoke(lookup, $) instanceof Invoke invoke -> {
169 if (isVectorBinaryOp(invoke.lookup(), invoke)){
170 var hatVectorBinaryOp = buildVectorBinaryOp(
171 invoke.varOpFromFirstUseOrThrow().varName(),
172 invoke.name(),// so mul, sub etc
173 getVectorShape(lookup,invoke.returnType()),
174 blockBuilder.context().getValues(invoke.op().operands())
175 );
176 context.mapValue(invoke.returnResult(), blockBuilder.op(copyLocation(invoke.op(),hatVectorBinaryOp)));
177 return blockBuilder;
178 }
179 }
180 case CoreOp.VarOp varOp -> {
181 if (isVectorOp(lookup,varOp)) {
182 var hatVectorVarOp = new HATVectorOp.HATVectorVarOp(
183 varOp.varName(),
184 varOp.resultType(),
185 getVectorShape(lookup,varOp.resultType().valueType()),
186 context.getValues(OpHelper.firstOperandAsListOrEmpty(varOp))
187 );
188 context.mapValue(varOp.result(), blockBuilder.op(copyLocation(varOp,hatVectorVarOp)));
189 return blockBuilder;
190 }
191 }
192 case JavaOp.ArrayAccessOp.ArrayLoadOp arrayLoadOp -> {
193 if (isVectorOp(lookup,arrayLoadOp)) {
194 Op.Result buffer = resultFromFirstOperandOrNull(arrayLoadOp);
195 String name = hatPtrName(opFromFirstOperandOrThrow(buffer.op()));
196 var resultType = CoreType.varType(arrayLoadOp.resultType());
197 var vectorShape = getVectorShape(lookup,arrayLoadOp.resultType());
198 List<Value> operands = context.getValues(List.of(buffer, arrayLoadOp.operands().getLast()));
199 HATVectorOp vLoadOp = buildArrayViewVector(arrayLoadOp, name, resultType, vectorShape, operands);
200 context.mapValue(arrayLoadOp.result(), blockBuilder.op(copyLocation(arrayLoadOp,vLoadOp)));
201 }
202 return blockBuilder;
203 }
204 case JavaOp.ArrayAccessOp.ArrayStoreOp arrayStoreOp -> {
205 if (isVectorOp(lookup,arrayStoreOp)) {
206 Op.Result buffer = resultFromFirstOperandOrThrow(arrayStoreOp);
207 Op varOp = opFromFirstOperandOrNull(((Op.Result) arrayStoreOp.operands().getLast()).op());
208 String name = hatPtrName(varOp);
209 var resultType = (varOp).resultType();
210 var vectorShape = getVectorShape(lookup,arrayStoreOp.operands().getLast().type());
211 List<Value> operands = context.getValues(List.of(buffer, arrayStoreOp.operands().getLast(), arrayStoreOp.operands().get(1)));
212 HATVectorOp vStoreOp = buildArrayViewVector(arrayStoreOp, name, resultType, vectorShape, operands);
213 context.mapValue(arrayStoreOp.result(), blockBuilder.op(copyLocation(arrayStoreOp,vStoreOp)));
214 }
215 return blockBuilder;
216 }
217 default -> {
218 }
219 }
220 blockBuilder.op(op);
221 return blockBuilder;
222 }).funcOp();
223 }
224
225 public CoreOp.FuncOp applyArrayView(MethodHandles.Lookup lookup,CoreOp.FuncOp funcOp) {
226 Map<Op.Result, Op.Result> replaced = new HashMap<>(); // maps a result to the result it should be replaced by
227 Map<Op, CoreOp.VarAccessOp.VarLoadOp> bufferVarLoads = new HashMap<>();
228
229 return Trxfmr.of(lookup,funcOp).transform((blockBuilder, op) -> {
230 var context = blockBuilder.context();
231 switch (op) {
232 case JavaOp.InvokeOp $ when invoke(lookup, $) instanceof Invoke invoke -> {
233 if (isBufferArray(lookup, invoke.op())) { // ensures we can use iop as key for replaced vvv
234 Op.Result result = invoke.resultFromFirstOperandOrNull();
235 replaced.put(invoke.returnResult(), result);
236 // map buffer VarOp to its corresponding VarLoadOp
237 bufferVarLoads.put((opFromFirstOperandOrNull(result.op())), (CoreOp.VarAccessOp.VarLoadOp) result.op());
238 return blockBuilder;
239 }
240 }
241 case CoreOp.VarOp varOp -> {
242 if (isBufferInitialize(lookup, varOp)) {
243 // makes sure we don't process a new int[] for example
244 Op bufferLoad = replaced.get(resultFromFirstOperandOrThrow(varOp)).op(); // gets VarLoadOp associated w/ og buffer
245 replaced.put(varOp.result(), resultFromFirstOperandOrNull(bufferLoad)); // gets VarOp associated w/ og buffer
246 return blockBuilder;
247 }
248 }
249 case CoreOp.VarAccessOp.VarLoadOp varLoadOp -> {
250 if ((isBufferInitialize(lookup, varLoadOp))) {
251 Op.Result r = resultFromFirstOperandOrThrow(varLoadOp);
252 Op.Result replacement;
253 if (r.op() instanceof CoreOp.VarOp) { // if this is the VarLoadOp after the .arrayView() InvokeOp
254 replacement = (isLocalSharedOrPrivate(varLoadOp)) ?
255 resultFromFirstOperandOrNull(opFromFirstOperandOrThrow(r.op())) :
256 bufferVarLoads.get(replaced.get(r).op()).result();
257 } else { // if this is a VarLoadOp loading the buffer
258 CoreOp.VarAccessOp.VarLoadOp newVarLoad = CoreOp.VarAccessOp.varLoad(blockBuilder.context().getValue(replaced.get(r)));
259 replacement = blockBuilder.op(copyLocation(varLoadOp,newVarLoad));
260 context.mapValue(varLoadOp.result(), replacement);
261 }
262 replaced.put(varLoadOp.result(), replacement);
263 return blockBuilder;
264 }
265 }
266 case JavaOp.ArrayAccessOp.ArrayLoadOp arrayLoadOp -> {
267 if (isBufferArray(lookup, arrayLoadOp)) {
268 Op replacementOp=null;
269 if (isVectorOp(lookup,arrayLoadOp)) {
270 replacementOp = JavaOp.arrayLoadOp(
271 context.getValue(replaced.get((Op.Result) arrayLoadOp.operands().getFirst())),
272 context.getValue(arrayLoadOp.operands().getLast()),
273 arrayLoadOp.resultType()
274 );
275 } else if (((ArrayType) firstOperandOrThrow(op).type()).dimensions() == 1) {
276 var arrayAccessInfo = arrayAccessInfo(op.result(), replaced);
277 var operands = arrayAccessInfo.bufferAndIndicesAsValues();
278 replacementOp = new HATPtrOp.HATPtrLoadOp(
279 arrayAccessInfo.bufferName(),
280 arrayLoadOp.resultType(),
281 (Class<?>) classTypeToTypeOrThrow(lookup, (ClassType) arrayAccessInfo.buffer().type()),
282 context.getValues(operands)
283 );
284 } else { // we only use the last array load
285 return blockBuilder;
286 }
287 context.mapValue(arrayLoadOp.result(), blockBuilder.op(copyLocation(arrayLoadOp,replacementOp)));
288 return blockBuilder;
289 }
290 }
291 case JavaOp.ArrayAccessOp.ArrayStoreOp arrayStoreOp -> {
292 if (isBufferArray(lookup, arrayStoreOp)) {
293 Op replacementOp;
294 if (isVectorOp(lookup, arrayStoreOp)) {
295 replacementOp = JavaOp.arrayStoreOp(
296 context.getValue(replaced.get((Op.Result) arrayStoreOp.operands().getFirst())),
297 context.getValue(arrayStoreOp.operands().get(1)),
298 context.getValue(arrayStoreOp.operands().getLast())
299 );
300 } else if (((ArrayType) firstOperandOrThrow(op).type()).dimensions() == 1) { // we only use the last array load
301 var arrayAccessInfo = arrayAccessInfo(op.result(), replaced);
302 var operands = arrayAccessInfo.bufferAndIndicesAsValues();
303 operands.add(arrayStoreOp.operands().getLast());
304 replacementOp = new HATPtrOp.HATPtrStoreOp(
305 arrayAccessInfo.bufferName(),
306 arrayStoreOp.resultType(),
307 (Class<?>) classTypeToTypeOrThrow(lookup, (ClassType) arrayAccessInfo.buffer().type()),
308 context.getValues(operands)
309 );
310 } else {
311 return blockBuilder;
312 }
313 context.mapValue(arrayStoreOp.result(), blockBuilder.op(copyLocation(arrayStoreOp, replacementOp)));
314 return blockBuilder;
315 }
316 }
317 case JavaOp.ArrayLengthOp arrayLengthOp when
318 isBufferArray(lookup, arrayLengthOp) && resultFromFirstOperandOrThrow(arrayLengthOp) != null ->{
319 var arrayAccessInfo = arrayAccessInfo(op.result(), replaced);
320 var hatPtrLengthOp = new HATPtrOp.HATPtrLengthOp(
321 arrayAccessInfo.bufferName(),
322 arrayLengthOp.resultType(),
323 (Class<?>) OpHelper.classTypeToTypeOrThrow(lookup, (ClassType) arrayAccessInfo.buffer().type()),
324 context.getValues(List.of(arrayAccessInfo.buffer()))
325 );
326 context.mapValue(arrayLengthOp.result(), blockBuilder.op(copyLocation(arrayLengthOp,hatPtrLengthOp)));
327 return blockBuilder;
328 }
329 default -> {
330 }
331 }
332 blockBuilder.op(op);
333 return blockBuilder;
334 }).funcOp();
335 }
336
337 record ArrayAccessInfo(Op.Result buffer, String bufferName, List<Op.Result> indices) {
338 public List<Value> bufferAndIndicesAsValues() {
339 List<Value> operands = new ArrayList<>(List.of(buffer));
340 operands.addAll(indices);
341 return operands;
342 }
343 }
344
345 record Node<T>(T value, List<Node<T>> edges) {
346 ArrayAccessInfo getInfo(Map<Op.Result, Op.Result> replaced) {
347 List<Node<T>> nodeList = new ArrayList<>(List.of(this));
348 Set<Node<T>> handled = new HashSet<>();
349 Op.Result buffer = null;
350 List<Op.Result> indices = new ArrayList<>();
351 while (!nodeList.isEmpty()) {
352 Node<T> node = nodeList.removeFirst();
353 handled.add(node);
354 if (node.value instanceof Op.Result res &&
355 (res.op() instanceof JavaOp.ArrayAccessOp || res.op() instanceof JavaOp.ArrayLengthOp)) {
356 buffer = res;
357 // idx location differs between ArrayAccessOp and ArrayLengthOp
358 indices.addFirst(res.op() instanceof JavaOp.ArrayAccessOp
359 ? resultFromOperandN(res.op(), 1)
360 : resultFromOperandN(res.op(), 0)
361 );
362 }
363 if (!node.edges().isEmpty()) {
364 Node<T> next = node.edges().getFirst(); // we only traverse through the index-related ops
365 if (!handled.contains(next)) {
366 nodeList.add(next);
367 }
368 }
369 }
370 if (buffer != null) {
371 buffer = replaced.get(resultFromFirstOperandOrNull(buffer.op()));
372 String bufferName = hatPtrName(opFromFirstOperandOrNull(buffer.op()));
373 return new ArrayAccessInfo(buffer, bufferName, indices);
374 } else {
375 return null;
376 }
377 }
378 }
379
380 public static String hatPtrName(Op op) {
381 return switch (op) {
382 case CoreOp.VarOp varOp -> varOp.varName();
383 case VarLikeOp varLikeOp -> varLikeOp.varName();
384 case null, default -> "";
385 };
386 }
387 }