1 /*
2 * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25 package hat.phases;
26
27 import hat.dialect.HATPtrOp;
28 import hat.dialect.HATVectorOp;
29 import jdk.incubator.code.CodeType;
30 import optkl.IfaceValue;
31 import optkl.OpHelper;
32 import optkl.Trxfmr;
33 import jdk.incubator.code.Op;
34 import jdk.incubator.code.Value;
35 import jdk.incubator.code.dialect.core.CoreOp;
36 import jdk.incubator.code.dialect.core.CoreType;
37 import jdk.incubator.code.dialect.java.*;
38 import optkl.VarTable;
39 import optkl.codebuilders.BabylonOpDispatcher;
40 import optkl.util.ops.VarLikeOp;
41
42 import java.lang.invoke.MethodHandles;
43 import java.util.*;
44
45 import static hat.phases.HATPhaseUtils.findOpInResultFromFirstOperandsOrNull;
46 import static optkl.IfaceValue.Vector.getVectorShape;
47 import static optkl.OpHelper.Invoke;
48 import static optkl.OpHelper.Invoke.invoke;
49 import static optkl.OpHelper.classTypeToTypeOrThrow;
50 import static optkl.OpHelper.copyLocation;
51 import static optkl.OpHelper.firstOperandOrThrow;
52 import static optkl.OpHelper.opFromFirstOperandOrNull;
53 import static optkl.OpHelper.opFromFirstOperandOrThrow;
54 import static optkl.OpHelper.resultFromFirstOperandOrNull;
55 import static optkl.OpHelper.resultFromFirstOperandOrThrow;
56 import static optkl.OpHelper.resultFromOperandN;
57
58 public record HATArrayViewPhase() implements HATPhase {
59 public static boolean isVectorOp(MethodHandles.Lookup lookup, Op op) {
60 if (!op.operands().isEmpty()) {
61 CodeType type = switch(op) {
62 case JavaOp.ArrayAccessOp.ArrayLoadOp load -> load.resultType();
63 case JavaOp.ArrayAccessOp.ArrayStoreOp store -> store.operands().getLast().type();
64 default -> OpHelper.firstOperandOrThrow(op).type();
65 };
66 if (type instanceof ArrayType at) {
67 type = at.componentType();
68 }
69 if (type instanceof ClassType ct) {
70 try {
71 return IfaceValue.Vector.class.isAssignableFrom((Class<?>) ct.resolve(lookup));
72 } catch (ReflectiveOperationException e) {
73 throw new RuntimeException(e);
74 }
75 }
76 }
77 return false;
78 }
79
80 public static boolean isVectorBinaryOp(MethodHandles.Lookup lookup, OpHelper.Invoke invoke) {
81 return isVectorOp(lookup, invoke.op()) && invoke.nameMatchesRegex("(add|sub|mul|div)");
82 }
83
84
85 static HATVectorOp.HATVectorBinaryOp buildVectorBinaryOp(String varName, CodeType codeType, String opType, IfaceValue.Vector.Shape vectorShape, List<Value> outputOperands) {
86 return switch (opType) {
87 case "add" -> new HATVectorOp.HATVectorBinaryOp.HATVectorAddOp(varName, codeType, vectorShape, outputOperands);
88 case "sub" -> new HATVectorOp.HATVectorBinaryOp.HATVectorSubOp(varName, codeType, vectorShape, outputOperands);
89 case "mul" -> new HATVectorOp.HATVectorBinaryOp.HATVectorMulOp(varName, codeType, vectorShape, outputOperands);
90 case "div" -> new HATVectorOp.HATVectorBinaryOp.HATVectorDivOp(varName, codeType, vectorShape, outputOperands);
91 default -> throw new IllegalStateException("Unexpected value: " + opType);
92 };
93 }
94
95 public static boolean isBufferArray(MethodHandles.Lookup lookup, Op op) {
96 JavaOp.InvokeOp iop = (JavaOp.InvokeOp) findOpInResultFromFirstOperandsOrNull(op, JavaOp.InvokeOp.class);
97 return iop != null && iop.invokeReference().name().toLowerCase().contains("arrayview"); // we need a better way
98 }
99
100 public static boolean isBufferInitialize(MethodHandles.Lookup lookup, Op op) {
101 // first check if the return is an array type
102 if (op instanceof CoreOp.VarOp vop && vop.varValueType() instanceof ArrayType
103 || op instanceof JavaOp.ArrayAccessOp
104 || op.resultType() instanceof ArrayType) return isBufferArray(lookup, op);
105 return false;
106 }
107
108 public static boolean isLocalSharedOrPrivate(Op op) {
109 JavaOp.InvokeOp iop = (JavaOp.InvokeOp) findOpInResultFromFirstOperandsOrNull(op, JavaOp.InvokeOp.class);
110 return iop != null
111 && (iop.invokeReference().name().toLowerCase().contains("shared")
112 || iop.invokeReference().name().toLowerCase().contains("local")
113 || iop.invokeReference().name().toLowerCase().contains("private")
114 );
115 }
116
117 public static HATVectorOp buildArrayViewVector(Op op, String name, CodeType resultType, IfaceValue.Vector.Shape vectorShape, List<Value> operands) {
118 if (isLocalSharedOrPrivate(op)) {
119 if (op instanceof JavaOp.ArrayAccessOp.ArrayLoadOp) {
120 return new HATVectorOp.HATVectorLoadOp.HATSharedVectorLoadOp(name, resultType, vectorShape, operands);
121 }
122 return new HATVectorOp.HATVectorStoreView.HATSharedVectorStoreView(name, resultType, vectorShape, operands);
123 } else {
124 if (op instanceof JavaOp.ArrayAccessOp.ArrayLoadOp) {
125 return new HATVectorOp.HATVectorLoadOp.HATPrivateVectorLoadOp(name, resultType, vectorShape, operands);
126 }
127 return new HATVectorOp.HATVectorStoreView.HATPrivateVectorStoreView(name, resultType, vectorShape, operands);
128 }
129 }
130
131
132
133 static HATArrayViewPhase.ArrayAccessInfo arrayAccessInfo(Value value, Map<Op.Result, Op.Result> replaced) {
134 return expressionGraph(value).getInfo(replaced);
135 }
136
137 static HATArrayViewPhase.Node<Value> expressionGraph(Value value) {
138 return expressionGraph(new HashMap<>(), value);
139 }
140
141 static HATArrayViewPhase.Node<Value> expressionGraph(Map<Value, HATArrayViewPhase.Node<Value>> visited, Value value) {
142 // If value has already been visited return its node
143 if (visited.containsKey(value)) {
144 return visited.get(value);
145 }
146
147 // Find the expression graphs for each operand
148 List<HATArrayViewPhase.Node<Value>> edges = new ArrayList<>();
149
150 // looks like
151 for (Value operand : value.dependsOn()) {
152 if (operand instanceof Op.Result res &&
153 res.op() instanceof JavaOp.InvokeOp iop
154 && iop.invokeReference().name().toLowerCase().contains("arrayview")){ // We need to find a better way
155 continue;
156 }
157 edges.add(expressionGraph(operand));
158 }
159 HATArrayViewPhase.Node<Value> node = new HATArrayViewPhase.Node<>(value, edges);
160 visited.put(value, node);
161 return node;
162 }
163 @Override
164 public CoreOp.FuncOp transform(MethodHandles.Lookup lookup,CoreOp.FuncOp funcOp, VarTable varTable) {
165 if (Invoke.stream(lookup, funcOp).noneMatch(
166 invoke -> isBufferArray(lookup, invoke.op())
167 )) return funcOp;
168
169 funcOp = applyArrayView(lookup,funcOp);
170
171 if (funcOp.elements().filter(e -> e instanceof CoreOp.VarOp).anyMatch(
172 e -> isVectorOp(lookup, ((CoreOp.VarOp) e))
173 )) funcOp = applyVectorView(lookup,funcOp, varTable);
174 return funcOp;
175 }
176
177 public CoreOp.FuncOp applyVectorView(MethodHandles.Lookup lookup,CoreOp.FuncOp funcOp, VarTable varTable) {
178 return Trxfmr.of(lookup,funcOp).transform((blockBuilder, op) -> {
179 var context = blockBuilder.context();
180 switch (op) {
181 case JavaOp.InvokeOp iOp when invoke(lookup, iOp) instanceof Invoke invoke -> {
182 if (isVectorBinaryOp(invoke.lookup(), invoke)){
183 var hatVectorBinaryOp = buildVectorBinaryOp(
184 invoke.varOpFromFirstUseOrThrow().varName(),
185 invoke.returnType(),
186 invoke.name(),// so mul, sub etc
187 getVectorShape(lookup,invoke.returnType()),
188 blockBuilder.context().getValues(invoke.op().operands())
189 );
190 context.mapValue(invoke.returnResult(), blockBuilder.add(copyLocation(invoke.op(),hatVectorBinaryOp)));
191 return blockBuilder;
192 }
193 }
194 case CoreOp.VarOp varOp -> {
195 if (isVectorOp(lookup,varOp)) {
196 Op.Result op1 = blockBuilder.add(varOp);
197 String functionName = funcOp.funcName();
198 varTable.addIfNeededOrThrow(functionName, op1.op(), VarTable.HATOpAttribute.VECTOR);
199 return blockBuilder;
200 }
201 }
202 case JavaOp.ArrayAccessOp.ArrayLoadOp arrayLoadOp -> {
203 if (isVectorOp(lookup,arrayLoadOp)) {
204 Op.Result buffer = resultFromFirstOperandOrNull(arrayLoadOp);
205 String name = hatPtrName(opFromFirstOperandOrThrow(buffer.op()));
206 var resultType = CoreType.varType(arrayLoadOp.resultType());
207 var vectorShape = getVectorShape(lookup,arrayLoadOp.resultType());
208 List<Value> operands = context.getValues(List.of(buffer, arrayLoadOp.operands().getLast()));
209 HATVectorOp vLoadOp = buildArrayViewVector(arrayLoadOp, name, resultType, vectorShape, operands);
210 context.mapValue(arrayLoadOp.result(), blockBuilder.add(copyLocation(arrayLoadOp,vLoadOp)));
211 }
212 return blockBuilder;
213 }
214 case JavaOp.ArrayAccessOp.ArrayStoreOp arrayStoreOp -> {
215 if (isVectorOp(lookup,arrayStoreOp)) {
216 Op.Result buffer = resultFromFirstOperandOrThrow(arrayStoreOp);
217 Op varOp = opFromFirstOperandOrNull(((Op.Result) arrayStoreOp.operands().getLast()).op());
218 String name = hatPtrName(varOp);
219 var resultType = (varOp).resultType();
220 var vectorShape = getVectorShape(lookup,arrayStoreOp.operands().getLast().type());
221 List<Value> operands = context.getValues(List.of(buffer, arrayStoreOp.operands().getLast(), arrayStoreOp.operands().get(1)));
222 HATVectorOp vStoreOp = buildArrayViewVector(arrayStoreOp, name, resultType, vectorShape, operands);
223 context.mapValue(arrayStoreOp.result(), blockBuilder.add(copyLocation(arrayStoreOp,vStoreOp)));
224 }
225 return blockBuilder;
226 }
227 default -> {
228 }
229 }
230 blockBuilder.add(op);
231 return blockBuilder;
232 }).funcOp();
233 }
234
235 public CoreOp.FuncOp applyArrayView(MethodHandles.Lookup lookup,CoreOp.FuncOp funcOp) {
236 Map<Op.Result, Op.Result> replaced = new HashMap<>(); // maps a result to the result it should be replaced by
237 Map<Op, CoreOp.VarAccessOp.VarLoadOp> bufferVarLoads = new HashMap<>();
238
239 return Trxfmr.of(lookup,funcOp).transform((blockBuilder, op) -> {
240 var context = blockBuilder.context();
241 switch (op) {
242 case JavaOp.InvokeOp invokeOp when invoke(lookup, invokeOp) instanceof Invoke invoke -> {
243 if (isBufferArray(lookup, invoke.op())) { // ensures we can use iop as key for replaced vvv
244 Op.Result result = invoke.resultFromFirstOperandOrNull();
245 replaced.put(invoke.returnResult(), result);
246 // map buffer VarOp to its corresponding VarLoadOp
247 bufferVarLoads.put((opFromFirstOperandOrNull(result.op())), (CoreOp.VarAccessOp.VarLoadOp) result.op());
248 return blockBuilder;
249 }
250 }
251 case CoreOp.VarOp varOp -> {
252 if (isBufferInitialize(lookup, varOp)) {
253 // makes sure we don't process a new int[] for example
254 Op bufferLoad = replaced.get(resultFromFirstOperandOrThrow(varOp)).op(); // gets VarLoadOp associated w/ og buffer
255 replaced.put(varOp.result(), resultFromFirstOperandOrNull(bufferLoad)); // gets VarOp associated w/ og buffer
256 return blockBuilder;
257 }
258 }
259 case CoreOp.VarAccessOp.VarLoadOp varLoadOp -> {
260 if ((isBufferInitialize(lookup, varLoadOp))) {
261 Op.Result r = resultFromFirstOperandOrThrow(varLoadOp);
262 Op.Result replacement;
263 if (r.op() instanceof CoreOp.VarOp) { // if this is the VarLoadOp after the .arrayView() InvokeOp
264 replacement = (isLocalSharedOrPrivate(varLoadOp)) ?
265 resultFromFirstOperandOrNull(opFromFirstOperandOrThrow(r.op())) :
266 bufferVarLoads.get(replaced.get(r).op()).result();
267 } else { // if this is a VarLoadOp loading the buffer
268 CoreOp.VarAccessOp.VarLoadOp newVarLoad = CoreOp.VarAccessOp.varLoad(blockBuilder.context().getValue(replaced.get(r)));
269 replacement = blockBuilder.add(copyLocation(varLoadOp,newVarLoad));
270 context.mapValue(varLoadOp.result(), replacement);
271 }
272 replaced.put(varLoadOp.result(), replacement);
273 return blockBuilder;
274 }
275 }
276 case JavaOp.ArrayAccessOp.ArrayLoadOp arrayLoadOp -> {
277 if (isBufferArray(lookup, arrayLoadOp)) {
278 Op replacementOp=null;
279 if (isVectorOp(lookup,arrayLoadOp)) {
280 replacementOp = JavaOp.arrayLoadOp(
281 context.getValue(replaced.get((Op.Result) arrayLoadOp.operands().getFirst())),
282 context.getValue(arrayLoadOp.operands().getLast()),
283 arrayLoadOp.resultType()
284 );
285 } else if (((ArrayType) firstOperandOrThrow(op).type()).dimensions() == 1) {
286 var arrayAccessInfo = arrayAccessInfo(op.result(), replaced);
287 var operands = arrayAccessInfo.bufferAndIndicesAsValues();
288 replacementOp = new HATPtrOp.HATPtrLoadOp(
289 arrayAccessInfo.bufferName(),
290 arrayLoadOp.resultType(),
291 (Class<?>) classTypeToTypeOrThrow(lookup, (ClassType) arrayAccessInfo.buffer().type()),
292 context.getValues(operands)
293 );
294 } else { // we only use the last array load
295 return blockBuilder;
296 }
297 context.mapValue(arrayLoadOp.result(), blockBuilder.add(copyLocation(arrayLoadOp,replacementOp)));
298 return blockBuilder;
299 }
300 }
301 case JavaOp.ArrayAccessOp.ArrayStoreOp arrayStoreOp -> {
302 if (isBufferArray(lookup, arrayStoreOp)) {
303 Op replacementOp;
304 if (isVectorOp(lookup, arrayStoreOp)) {
305 replacementOp = JavaOp.arrayStoreOp(
306 context.getValue(replaced.get((Op.Result) arrayStoreOp.operands().getFirst())),
307 context.getValue(arrayStoreOp.operands().get(1)),
308 context.getValue(arrayStoreOp.operands().getLast())
309 );
310 } else if (((ArrayType) firstOperandOrThrow(op).type()).dimensions() == 1) { // we only use the last array load
311 var arrayAccessInfo = arrayAccessInfo(op.result(), replaced);
312 var operands = arrayAccessInfo.bufferAndIndicesAsValues();
313 operands.add(arrayStoreOp.operands().getLast());
314 replacementOp = new HATPtrOp.HATPtrStoreOp(
315 arrayAccessInfo.bufferName(),
316 arrayStoreOp.resultType(),
317 (Class<?>) classTypeToTypeOrThrow(lookup, (ClassType) arrayAccessInfo.buffer().type()),
318 context.getValues(operands)
319 );
320 } else {
321 return blockBuilder;
322 }
323 context.mapValue(arrayStoreOp.result(), blockBuilder.add(copyLocation(arrayStoreOp, replacementOp)));
324 return blockBuilder;
325 }
326 }
327 case JavaOp.ArrayLengthOp arrayLengthOp when
328 isBufferArray(lookup, arrayLengthOp) && resultFromFirstOperandOrThrow(arrayLengthOp) != null ->{
329 var arrayAccessInfo = arrayAccessInfo(op.result(), replaced);
330 var hatPtrLengthOp = new HATPtrOp.HATPtrLengthOp(
331 arrayAccessInfo.bufferName(),
332 arrayLengthOp.resultType(),
333 (Class<?>) OpHelper.classTypeToTypeOrThrow(lookup, (ClassType) arrayAccessInfo.buffer().type()),
334 context.getValues(List.of(arrayAccessInfo.buffer()))
335 );
336 context.mapValue(arrayLengthOp.result(), blockBuilder.add(copyLocation(arrayLengthOp,hatPtrLengthOp)));
337 return blockBuilder;
338 }
339 default -> {
340 }
341 }
342 blockBuilder.add(op);
343 return blockBuilder;
344 }).funcOp();
345 }
346
347 record ArrayAccessInfo(Op.Result buffer, String bufferName, List<Op.Result> indices) {
348 public List<Value> bufferAndIndicesAsValues() {
349 List<Value> operands = new ArrayList<>(List.of(buffer));
350 operands.addAll(indices);
351 return operands;
352 }
353 }
354
355 record Node<T>(T value, List<Node<T>> edges) {
356 ArrayAccessInfo getInfo(Map<Op.Result, Op.Result> replaced) {
357 List<Node<T>> nodeList = new ArrayList<>(List.of(this));
358 Set<Node<T>> handled = new HashSet<>();
359 Op.Result buffer = null;
360 List<Op.Result> indices = new ArrayList<>();
361 while (!nodeList.isEmpty()) {
362 Node<T> node = nodeList.removeFirst();
363 handled.add(node);
364 if (node.value instanceof Op.Result res &&
365 (res.op() instanceof JavaOp.ArrayAccessOp || res.op() instanceof JavaOp.ArrayLengthOp)) {
366 buffer = res;
367 // idx location differs between ArrayAccessOp and ArrayLengthOp
368 indices.addFirst(res.op() instanceof JavaOp.ArrayAccessOp
369 ? resultFromOperandN(res.op(), 1)
370 : resultFromOperandN(res.op(), 0)
371 );
372 }
373 if (!node.edges().isEmpty()) {
374 Node<T> next = node.edges().getFirst(); // we only traverse through the index-related ops
375 if (!handled.contains(next)) {
376 nodeList.add(next);
377 }
378 }
379 }
380 if (buffer != null) {
381 buffer = replaced.get(resultFromFirstOperandOrNull(buffer.op()));
382 String bufferName = hatPtrName(opFromFirstOperandOrNull(buffer.op()));
383 return new ArrayAccessInfo(buffer, bufferName, indices);
384 } else {
385 return null;
386 }
387 }
388 }
389
390 public static String hatPtrName(Op op) {
391 return switch (op) {
392 case CoreOp.VarOp varOp -> varOp.varName();
393 case VarLikeOp varLikeOp -> varLikeOp.varName();
394 case null, default -> "";
395 };
396 }
397 }