1 /*
2 * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25 package hat.phases;
26
import hat.dialect.HATPtrOp;
import jdk.incubator.code.CodeType;
import jdk.incubator.code.dialect.java.ArrayType;
import jdk.incubator.code.dialect.java.ClassType;
import jdk.incubator.code.dialect.java.JavaOp;
import optkl.IfaceValue;
import optkl.OpHelper;
import optkl.Trxfmr;
import jdk.incubator.code.Op;
import jdk.incubator.code.Value;
import jdk.incubator.code.dialect.core.CoreOp;
import optkl.VarTable;
import optkl.util.ops.VarLikeOp;

import java.lang.invoke.MethodHandles;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;

import static hat.phases.HATPhaseUtils.findOpInResultFromFirstOperandsOrNull;
import static optkl.OpHelper.Invoke;
import static optkl.OpHelper.Invoke.invoke;
import static optkl.OpHelper.classTypeToTypeOrThrow;
import static optkl.OpHelper.copyLocation;
import static optkl.OpHelper.firstOperandOrThrow;
import static optkl.OpHelper.opFromFirstOperandOrNull;
import static optkl.OpHelper.opFromFirstOperandOrThrow;
import static optkl.OpHelper.resultFromFirstOperandOrNull;
import static optkl.OpHelper.resultFromFirstOperandOrThrow;
import static optkl.OpHelper.resultFromOperandN;
60
61 public record HATArrayViewPhase() implements HATPhase {
62 public static boolean isVectorOp(MethodHandles.Lookup lookup, Op op) {
63 if (!op.operands().isEmpty()) {
64 CodeType type = switch (op) {
65 case JavaOp.ArrayAccessOp.ArrayLoadOp load -> load.resultType();
66 case JavaOp.ArrayAccessOp.ArrayStoreOp store -> store.operands().getLast().type();
67 default -> OpHelper.firstOperandOrThrow(op).type();
68 };
69 if (type instanceof ArrayType at) {
70 type = at.componentType();
71 }
72 if (type instanceof ClassType ct) {
73 try {
74 return IfaceValue.Vector.class.isAssignableFrom((Class<?>) ct.resolve(lookup));
75 } catch (ReflectiveOperationException e) {
76 throw new IllegalStateException(e);
77 }
78 }
79 }
80 return false;
81 }
82
83 public static boolean isVectorBinaryOp(MethodHandles.Lookup lookup, OpHelper.Invoke invoke) {
84 return isVectorOp(lookup, invoke.op()) && invoke.nameMatchesRegex("(add|sub|mul|div)");
85 }
86
87 public static boolean isBufferArray(Op op) {
88 JavaOp.InvokeOp iop = (JavaOp.InvokeOp) findOpInResultFromFirstOperandsOrNull(op, JavaOp.InvokeOp.class);
89 return iop != null && iop.invokeReference().name().toLowerCase().contains("arrayview"); // we need a better way
90 }
91
92 public static boolean isBufferInitialize(Op op) {
93 // first check if the return is an array type
94 if (op instanceof CoreOp.VarOp vop && vop.varValueType() instanceof ArrayType
95 || op instanceof JavaOp.ArrayAccessOp
96 || op.resultType() instanceof ArrayType) return isBufferArray(op);
97 return false;
98 }
99
100 public static boolean isLocalSharedOrPrivate(Op op) {
101 JavaOp.InvokeOp iop = (JavaOp.InvokeOp) findOpInResultFromFirstOperandsOrNull(op, JavaOp.InvokeOp.class);
102 return iop != null
103 && (iop.invokeReference().name().toLowerCase().contains("shared")
104 || iop.invokeReference().name().toLowerCase().contains("local")
105 || iop.invokeReference().name().toLowerCase().contains("private")
106 );
107 }
108
109 static HATArrayViewPhase.ArrayAccessInfo arrayAccessInfo(Value value, Map<Op.Result, Op.Result> replaced) {
110 return expressionGraph(value).getInfo(replaced);
111 }
112
113 static HATArrayViewPhase.Node<Value> expressionGraph(Value value) {
114 return expressionGraph(new HashMap<>(), value);
115 }
116
117 static HATArrayViewPhase.Node<Value> expressionGraph(Map<Value, HATArrayViewPhase.Node<Value>> visited, Value value) {
118 // If value has already been visited return its node
119 if (visited.containsKey(value)) {
120 return visited.get(value);
121 }
122
123 // Find the expression graphs for each operand
124 List<HATArrayViewPhase.Node<Value>> edges = new ArrayList<>();
125
126 // looks like
127 for (Value operand : value.dependsOn()) {
128 if (operand instanceof Op.Result res &&
129 res.op() instanceof JavaOp.InvokeOp iop
130 && iop.invokeReference().name().toLowerCase().contains("arrayview")) { // We need to find a better way
131 continue;
132 }
133 edges.add(expressionGraph(operand));
134 }
135 HATArrayViewPhase.Node<Value> node = new HATArrayViewPhase.Node<>(value, edges);
136 visited.put(value, node);
137 return node;
138 }
139
140 @Override
141 public CoreOp.FuncOp transform(MethodHandles.Lookup lookup, CoreOp.FuncOp funcOp, VarTable varTable) {
142 if (Invoke.stream(lookup, funcOp).noneMatch(
143 invoke -> isBufferArray(invoke.op())
144 )) return funcOp;
145
146 funcOp = applyArrayView(lookup, funcOp);
147
148 if (funcOp.elements().filter(e -> e instanceof CoreOp.VarOp).anyMatch(
149 e -> isVectorOp(lookup, ((CoreOp.VarOp) e))
150 )) funcOp = applyVectorView(lookup, funcOp, varTable);
151 return funcOp;
152 }
153
154 public CoreOp.FuncOp applyVectorView(MethodHandles.Lookup lookup, CoreOp.FuncOp funcOp, VarTable varTable) {
155 return Trxfmr.of(lookup, funcOp).transform((blockBuilder, op) -> {
156 switch (op) {
157 case JavaOp.InvokeOp iOp when invoke(lookup, iOp) instanceof Invoke invoke && isVectorBinaryOp(invoke.lookup(), invoke) ->
158 blockBuilder.add(op);
159 case CoreOp.VarOp varOp when isVectorOp(lookup, varOp) -> {
160 Op.Result op1 = blockBuilder.add(varOp);
161 String functionName = funcOp.funcName();
162 varTable.addIfNeededOrThrow(functionName, op1.op(), VarTable.HATOpAttribute.VECTOR);
163 return blockBuilder;
164 }
165 case JavaOp.ArrayAccessOp.ArrayLoadOp arrayLoadOp -> {
166 if (isVectorOp(lookup, arrayLoadOp)) {
167 blockBuilder.add(op);
168 }
169 return blockBuilder;
170 }
171 case JavaOp.ArrayAccessOp.ArrayStoreOp arrayStoreOp -> {
172 if (isVectorOp(lookup, arrayStoreOp)) {
173 blockBuilder.add(op);
174 }
175 return blockBuilder;
176 }
177 default -> {
178 }
179 }
180 blockBuilder.add(op);
181 return blockBuilder;
182 }).funcOp();
183 }
184
185 public CoreOp.FuncOp applyArrayView(MethodHandles.Lookup lookup, CoreOp.FuncOp funcOp) {
186 Map<Op.Result, Op.Result> replaced = new HashMap<>(); // maps a result to the result it should be replaced by
187 Map<Op, CoreOp.VarAccessOp.VarLoadOp> bufferVarLoads = new HashMap<>();
188
189 return Trxfmr.of(lookup, funcOp).transform((blockBuilder, op) -> {
190 var context = blockBuilder.context();
191 switch (op) {
192 case JavaOp.InvokeOp invokeOp when invoke(lookup, invokeOp) instanceof Invoke invoke && isBufferArray(invoke.op()) -> {
193 Op.Result result = invoke.resultFromFirstOperandOrNull();
194 replaced.put(invoke.returnResult(), result);
195 // map buffer VarOp to its corresponding VarLoadOp
196 bufferVarLoads.put((opFromFirstOperandOrNull(result.op())), (CoreOp.VarAccessOp.VarLoadOp) result.op());
197 return blockBuilder;
198 }
199 case CoreOp.VarOp varOp when isBufferInitialize(varOp) -> {
200 Op bufferLoad = replaced.get(resultFromFirstOperandOrThrow(varOp)).op(); // gets VarLoadOp associated w/ og buffer
201 replaced.put(varOp.result(), resultFromFirstOperandOrNull(bufferLoad)); // gets VarOp associated w/ og buffer
202 return blockBuilder;
203 }
204 case CoreOp.VarAccessOp.VarLoadOp varLoadOp when (isBufferInitialize(varLoadOp)) -> {
205 Op.Result r = resultFromFirstOperandOrThrow(varLoadOp);
206 Op.Result replacement;
207 if (r.op() instanceof CoreOp.VarOp) { // if this is the VarLoadOp after the .arrayView() InvokeOp
208 replacement = (isLocalSharedOrPrivate(varLoadOp)) ?
209 resultFromFirstOperandOrNull(opFromFirstOperandOrThrow(r.op())) :
210 bufferVarLoads.get(replaced.get(r).op()).result();
211 } else { // if this is a VarLoadOp loading the buffer
212 CoreOp.VarAccessOp.VarLoadOp newVarLoad = CoreOp.VarAccessOp.varLoad(blockBuilder.context().getValue(replaced.get(r)));
213 replacement = blockBuilder.add(copyLocation(varLoadOp, newVarLoad));
214 context.mapValue(varLoadOp.result(), replacement);
215 }
216 replaced.put(varLoadOp.result(), replacement);
217 return blockBuilder;
218 }
219 case JavaOp.ArrayAccessOp.ArrayLoadOp arrayLoadOp when isBufferArray(arrayLoadOp) -> {
220 Op replacementOp;
221 if (isVectorOp(lookup, arrayLoadOp)) {
222 replacementOp = JavaOp.arrayLoadOp(
223 context.getValue(replaced.get((Op.Result) arrayLoadOp.operands().getFirst())),
224 context.getValue(arrayLoadOp.operands().getLast()),
225 arrayLoadOp.resultType()
226 );
227 } else if (((ArrayType) firstOperandOrThrow(op).type()).dimensions() == 1) {
228 var arrayAccessInfo = arrayAccessInfo(op.result(), replaced);
229 var operands = arrayAccessInfo.bufferAndIndicesAsValues();
230 replacementOp = new HATPtrOp.HATPtrLoadOp(
231 arrayAccessInfo.bufferName(),
232 arrayLoadOp.resultType(),
233 (Class<?>) classTypeToTypeOrThrow(lookup, (ClassType) arrayAccessInfo.buffer().type()),
234 context.getValues(operands)
235 );
236 } else { // we only use the last array load
237 return blockBuilder;
238 }
239 context.mapValue(arrayLoadOp.result(), blockBuilder.add(copyLocation(arrayLoadOp, replacementOp)));
240 return blockBuilder;
241 }
242 case JavaOp.ArrayAccessOp.ArrayStoreOp arrayStoreOp when isBufferArray(arrayStoreOp) -> {
243 Op replacementOp;
244 if (isVectorOp(lookup, arrayStoreOp)) {
245 replacementOp = JavaOp.arrayStoreOp(
246 context.getValue(replaced.get((Op.Result) arrayStoreOp.operands().getFirst())),
247 context.getValue(arrayStoreOp.operands().get(1)),
248 context.getValue(arrayStoreOp.operands().getLast())
249 );
250 } else if (((ArrayType) firstOperandOrThrow(op).type()).dimensions() == 1) { // we only use the last array load
251 var arrayAccessInfo = arrayAccessInfo(op.result(), replaced);
252 var operands = arrayAccessInfo.bufferAndIndicesAsValues();
253 operands.add(arrayStoreOp.operands().getLast());
254 replacementOp = new HATPtrOp.HATPtrStoreOp(
255 arrayAccessInfo.bufferName(),
256 arrayStoreOp.resultType(),
257 (Class<?>) classTypeToTypeOrThrow(lookup, (ClassType) arrayAccessInfo.buffer().type()),
258 context.getValues(operands)
259 );
260 } else {
261 return blockBuilder;
262 }
263 context.mapValue(arrayStoreOp.result(), blockBuilder.add(copyLocation(arrayStoreOp, replacementOp)));
264 return blockBuilder;
265 }
266 case JavaOp.ArrayLengthOp arrayLengthOp when
267 isBufferArray(arrayLengthOp) && resultFromFirstOperandOrThrow(arrayLengthOp) != null -> {
268 var arrayAccessInfo = arrayAccessInfo(op.result(), replaced);
269 var hatPtrLengthOp = new HATPtrOp.HATPtrLengthOp(
270 arrayAccessInfo.bufferName(),
271 arrayLengthOp.resultType(),
272 (Class<?>) OpHelper.classTypeToTypeOrThrow(lookup, (ClassType) arrayAccessInfo.buffer().type()),
273 context.getValues(List.of(arrayAccessInfo.buffer()))
274 );
275 context.mapValue(arrayLengthOp.result(), blockBuilder.add(copyLocation(arrayLengthOp, hatPtrLengthOp)));
276 return blockBuilder;
277 }
278 default -> {
279 }
280 }
281 blockBuilder.add(op);
282 return blockBuilder;
283 }).funcOp();
284 }
285
286 record ArrayAccessInfo(Op.Result buffer, String bufferName, List<Op.Result> indices) {
287 public List<Value> bufferAndIndicesAsValues() {
288 List<Value> operands = new ArrayList<>(List.of(buffer));
289 operands.addAll(indices);
290 return operands;
291 }
292 }
293
294 record Node<T>(T value, List<Node<T>> edges) {
295 ArrayAccessInfo getInfo(Map<Op.Result, Op.Result> replaced) {
296 List<Node<T>> nodeList = new ArrayList<>(List.of(this));
297 Set<Node<T>> handled = new HashSet<>();
298 Op.Result buffer = null;
299 List<Op.Result> indices = new ArrayList<>();
300 while (!nodeList.isEmpty()) {
301 Node<T> node = nodeList.removeFirst();
302 handled.add(node);
303 if (node.value instanceof Op.Result res &&
304 (res.op() instanceof JavaOp.ArrayAccessOp || res.op() instanceof JavaOp.ArrayLengthOp)) {
305 buffer = res;
306 // idx location differs between ArrayAccessOp and ArrayLengthOp
307 indices.addFirst(res.op() instanceof JavaOp.ArrayAccessOp
308 ? resultFromOperandN(res.op(), 1)
309 : resultFromOperandN(res.op(), 0)
310 );
311 }
312 if (!node.edges().isEmpty()) {
313 Node<T> next = node.edges().getFirst(); // we only traverse through the index-related ops
314 if (!handled.contains(next)) {
315 nodeList.add(next);
316 }
317 }
318 }
319 if (buffer != null) {
320 buffer = replaced.get(resultFromFirstOperandOrNull(buffer.op()));
321 String bufferName = hatPtrName(opFromFirstOperandOrNull(buffer.op()));
322 return new ArrayAccessInfo(buffer, bufferName, indices);
323 } else {
324 return null;
325 }
326 }
327 }
328
329 public static String hatPtrName(Op op) {
330 return switch (op) {
331 case CoreOp.VarOp varOp -> varOp.varName();
332 case VarLikeOp varLikeOp -> varLikeOp.varName();
333 case null, default -> "";
334 };
335 }
336 }