1 /*
2 * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25 package hat.phases;
26
27 import hat.callgraph.KernelCallGraph;
28 import hat.device.DeviceType;
29 import hat.dialect.*;
30 import optkl.util.CallSite;
31 import optkl.ifacemapper.MappableIface;
32 import hat.types._V;
33 import jdk.incubator.code.Block;
34 import jdk.incubator.code.Op;
35 import jdk.incubator.code.TypeElement;
36 import jdk.incubator.code.Value;
37 import jdk.incubator.code.dialect.core.CoreOp;
38 import jdk.incubator.code.dialect.core.CoreType;
39 import jdk.incubator.code.dialect.java.*;
40
41 import java.util.*;
42
43 import static optkl.OpTkl.classTypeToTypeOrThrow;
44 import static optkl.OpTkl.elements;
45 import static optkl.OpTkl.isAssignable;
46
47 public record HATArrayViewPhase(KernelCallGraph kernelCallGraph) implements HATPhase {
48
49 @Override
50 public CoreOp.FuncOp apply(CoreOp.FuncOp entry) {
51
52 if (!isArrayView(entry)) return entry;
53
54 Map<Op.Result, Op.Result> replaced = new HashMap<>(); // maps a result to the result it should be replaced by
55 Map<Op, CoreOp.VarAccessOp.VarLoadOp> bufferVarLoads = new HashMap<>();
56
57 return entry.transform(entry.funcName(), (bb, op) -> {
58 switch (op) {
59 case JavaOp.InvokeOp invokeOp -> {
60 if (isVectorBinaryOperation(invokeOp)) {
61 // catching HATVectorBinaryOps not stored in VarOps
62 HATVectorOp.HATVectorBinaryOp vBinaryOp = buildVectorBinaryOp(
63 invokeOp.invokeDescriptor().name(),
64 obtainVarNameFromInvoke(invokeOp),
65 invokeOp.resultType(),
66 bb.context().getValues(invokeOp.operands())
67 );
68 vBinaryOp.setLocation(invokeOp.location());
69 Op.Result res = bb.op(vBinaryOp);
70 bb.context().mapValue(invokeOp.result(), res);
71 replaced.put(invokeOp.result(), res);
72 return bb;
73 } else if (isBufferArray(invokeOp) &&
74 firstOperand(invokeOp) instanceof Op.Result r) { // ensures we can use iop as key for replaced vvv
75 replaced.put(invokeOp.result(), r);
76 // map buffer VarOp to its corresponding VarLoadOp
77 bufferVarLoads.put((firstOperandAsRes(r.op())).op(), (CoreOp.VarAccessOp.VarLoadOp) r.op());
78 return bb;
79 }
80 }
81 case CoreOp.VarOp varOp -> {
82 if (isBufferInitialize(varOp) &&
83 firstOperand(varOp) instanceof Op.Result r) { // makes sure we don't process a new int[] for example
84 Op bufferLoad = replaced.get(r).op(); // gets VarLoadOp associated w/ og buffer
85 replaced.put(varOp.result(), firstOperandAsRes(bufferLoad)); // gets VarOp associated w/ og buffer
86 return bb;
87 } else if (isVectorOp(varOp)) {
88 List<Value> operands = (varOp.operands().isEmpty()) ? List.of() : List.of(firstOperand(varOp));
89 HATPhaseUtils.VectorMetaData md = HATPhaseUtils.getVectorTypeInfoWithCodeReflection(varOp.resultType().valueType());
90 HATVectorOp.HATVectorVarOp vVarOp = new HATVectorOp.HATVectorVarOp(
91 varOp.varName(),
92 varOp.resultType(),
93 md.vectorTypeElement(),
94 md.lanes(),
95 bb.context().getValues(operands)
96 );
97 vVarOp.setLocation(varOp.location());
98 Op.Result res = bb.op(vVarOp);
99 bb.context().mapValue(varOp.result(), res);
100 return bb;
101 }
102 }
103 case CoreOp.VarAccessOp.VarLoadOp varLoadOp -> {
104 if ((isBufferInitialize(varLoadOp)) &&
105 firstOperand(varLoadOp) instanceof Op.Result r) {
106 if (r.op() instanceof CoreOp.VarOp) { // if this is the VarLoadOp after the .arrayView() InvokeOp
107 Op.Result replacement = (notGlobalVarOp(varLoadOp)) ?
108 firstOperandAsRes((firstOperandAsRes(r.op())).op()) :
109 bufferVarLoads.get(replaced.get(r).op()).result();
110 replaced.put(varLoadOp.result(), replacement);
111 } else { // if this is a VarLoadOp loading the buffer
112 Value loaded = getValue(bb, replaced.get(r));
113 CoreOp.VarAccessOp.VarLoadOp newVarLoad = CoreOp.VarAccessOp.varLoad(loaded);
114 newVarLoad.setLocation(varLoadOp.location());
115 Op.Result res = bb.op(newVarLoad);
116 bb.context().mapValue(varLoadOp.result(), res);
117 replaced.put(varLoadOp.result(), res);
118 }
119 return bb;
120 }
121 }
122 case JavaOp.ArrayAccessOp.ArrayLoadOp arrayLoadOp -> {
123 if (isBufferArray(arrayLoadOp) &&
124 firstOperand(arrayLoadOp) instanceof Op.Result r) {
125 Op.Result buffer = replaced.getOrDefault(r, r);
126 if (isVectorOp(arrayLoadOp)) {
127 Op vop = (firstOperandAsRes(buffer.op())).op();
128 String name = switch (vop) {
129 case CoreOp.VarOp varOp -> varOp.varName();
130 case HATMemoryVarOp.HATLocalVarOp hatLocalVarOp -> hatLocalVarOp.varName();
131 case HATMemoryVarOp.HATPrivateVarOp hatPrivateVarOp -> hatPrivateVarOp.varName();
132 default -> throw new IllegalStateException("Unexpected value: " + vop);
133 };
134 HATPhaseUtils.VectorMetaData md = HATPhaseUtils.getVectorTypeInfoWithCodeReflection(arrayLoadOp.resultType());
135 HATVectorOp.HATVectorLoadOp vLoadOp = new HATVectorOp.HATVectorLoadOp(
136 name,
137 CoreType.varType(((ArrayType) firstOperand(arrayLoadOp).type()).componentType()),
138 md.vectorTypeElement(),
139 md.lanes(),
140 notGlobalVarOp(arrayLoadOp),
141 bb.context().getValues(List.of(buffer, arrayLoadOp.operands().getLast()))
142 );
143 vLoadOp.setLocation(arrayLoadOp.location());
144 Op.Result res = bb.op(vLoadOp);
145 bb.context().mapValue(arrayLoadOp.result(), res);
146 } else if (((ArrayType) firstOperand(op).type()).dimensions() == 1) { // we only use the last array load
147 ArrayAccessInfo info = arrayAccessInfo(op.result(), replaced);
148 List<Value> operands = new ArrayList<>();
149 operands.add(info.buffer);
150 operands.addAll(info.indices);
151 HATPtrOp.HATPtrLoadOp ptrLoadOp = new HATPtrOp.HATPtrLoadOp(
152 arrayLoadOp.resultType(),
153 (Class<?>) classTypeToTypeOrThrow(lookup(), (ClassType) info.buffer().type()),
154 bb.context().getValues(operands)
155 );
156 ptrLoadOp.setLocation(arrayLoadOp.location());
157 Op.Result res = bb.op(ptrLoadOp);
158 bb.context().mapValue(arrayLoadOp.result(), res);
159 }
160 }
161 return bb;
162 }
163 case JavaOp.ArrayAccessOp.ArrayStoreOp arrayStoreOp -> {
164 if (isBufferArray(arrayStoreOp) &&
165 firstOperand(arrayStoreOp) instanceof Op.Result r) {
166 Op.Result buffer = replaced.getOrDefault(r, r);
167 if (isVectorOp(arrayStoreOp)) {
168 Op varOp = findVarOpOrHATVarOP(((Op.Result) arrayStoreOp.operands().getLast()).op());
169 String name = (varOp instanceof HATVectorOp.HATVectorVarOp) ? ((HATVectorOp.HATVectorVarOp) varOp).varName() : ((CoreOp.VarOp) varOp).varName();
170 TypeElement resultType = (varOp instanceof HATVectorOp.HATVectorVarOp) ? (varOp).resultType() : ((CoreOp.VarOp) varOp).resultType();
171 ClassType classType = ((ClassType) ((ArrayType) firstOperand(arrayStoreOp).type()).componentType());
172 HATPhaseUtils.VectorMetaData md = HATPhaseUtils.getVectorTypeInfoWithCodeReflection(classType);
173 HATVectorOp.HATVectorStoreView vStoreOp = new HATVectorOp.HATVectorStoreView(
174 name,
175 resultType,
176 md.lanes(),
177 md.vectorTypeElement(),
178 notGlobalVarOp(arrayStoreOp),
179 bb.context().getValues(List.of(buffer, arrayStoreOp.operands().getLast(), arrayStoreOp.operands().get(1)))
180 );
181 vStoreOp.setLocation(arrayStoreOp.location());
182 Op.Result res = bb.op(vStoreOp);
183 bb.context().mapValue(arrayStoreOp.result(), res);
184 } else if (((ArrayType) firstOperand(op).type()).dimensions() == 1) { // we only use the last array load
185 ArrayAccessInfo info = arrayAccessInfo(op.result(), replaced);
186 List<Value> operands = new ArrayList<>();
187 operands.add(info.buffer());
188 operands.addAll(info.indices);
189 operands.add(arrayStoreOp.operands().getLast());
190 HATPtrOp.HATPtrStoreOp ptrLoadOp = new HATPtrOp.HATPtrStoreOp(
191 arrayStoreOp.resultType(),
192 (Class<?>) classTypeToTypeOrThrow(lookup(), (ClassType) info.buffer().type()),
193 bb.context().getValues(operands)
194 );
195 ptrLoadOp.setLocation(arrayStoreOp.location());
196 Op.Result res = bb.op(ptrLoadOp);
197 bb.context().mapValue(arrayStoreOp.result(), res);
198 }
199 }
200 return bb;
201 }
202 case JavaOp.ArrayLengthOp arrayLengthOp -> {
203 if (isBufferArray(arrayLengthOp) &&
204 firstOperand(arrayLengthOp) instanceof Op.Result r) {
205 ArrayAccessInfo info = arrayAccessInfo(op.result(), replaced);
206 HATPtrOp.HATPtrLengthOp ptrLengthOp = new HATPtrOp.HATPtrLengthOp(
207 arrayLengthOp.resultType(),
208 (Class<?>) classTypeToTypeOrThrow(lookup(), (ClassType) info.buffer().type()),
209 bb.context().getValues(List.of(info.buffer()))
210 );
211 ptrLengthOp.setLocation(arrayLengthOp.location());
212 Op.Result res = bb.op(ptrLengthOp);
213 bb.context().mapValue(arrayLengthOp.result(), res);
214 return bb;
215 }
216 }
217 default -> {
218 }
219 }
220 bb.op(op);
221 return bb;
222 });
223 }
224
225 record ArrayAccessInfo(Op.Result buffer, List<Op.Result> indices) {};
226
227 record Node<T>(T value, List<Node<T>> edges) {
228 ArrayAccessInfo getInfo(Map<Op.Result, Op.Result> replaced) {
229 List<Node<T>> wl = new ArrayList<>();
230 Set<Node<T>> seen = new HashSet<>();
231 Op.Result buffer = null;
232 List<Op.Result> indices = new ArrayList<>();
233 wl.add(this);
234 while (!wl.isEmpty()) {
235 Node<T> cur = wl.removeFirst();
236 seen.add(cur);
237 if (cur.value instanceof Op.Result res) {
238 if (res.op() instanceof JavaOp.ArrayAccessOp || res.op() instanceof JavaOp.ArrayLengthOp) {
239 buffer = res;
240 indices.addFirst(res.op() instanceof JavaOp.ArrayAccessOp ? ((Op.Result) res.op().operands().get(1)) : ((Op.Result) res.op().operands().get(0)));
241 }
242 }
243 if (!cur.edges().isEmpty()) {
244 Node<T> next = cur.edges().getFirst();
245 if (!seen.contains(next)) wl.add(next);
246 }
247 }
248 buffer = replaced.get((Op.Result) firstOperand(buffer.op()));
249 return new ArrayAccessInfo(buffer, indices);
250 }
251 }
252
253 static ArrayAccessInfo arrayAccessInfo(Value value, Map<Op.Result, Op.Result> replaced) {
254 return expressionGraph(value).getInfo(replaced);
255 }
256
257 static Node<Value> expressionGraph(Value value) {
258 return expressionGraph(new HashMap<>(), value);
259 }
260
261 static Node<Value> expressionGraph(Map<Value, Node<Value>> visited, Value value) {
262 // If value has already been visited return its node
263 if (visited.containsKey(value)) {
264 return visited.get(value);
265 }
266
267 // Find the expression graphs for each operand
268 List<Node<Value>> edges = new ArrayList<>();
269 for (Value operand : value.dependsOn()) {
270 if (operand instanceof Op.Result res && res.op() instanceof JavaOp.InvokeOp iop && iop.invokeDescriptor().name().toLowerCase().contains("arrayview")) continue;
271 edges.add(expressionGraph(operand));
272 }
273 Node<Value> node = new Node<>(value, edges);
274 visited.put(value, node);
275 return node;
276 }
277
278 /*
279 * Helper functions:
280 */
281
282 private HATVectorOp.HATVectorBinaryOp buildVectorBinaryOp(String opType, String varName, TypeElement resultType, List<Value> outputOperands) {
283 HATPhaseUtils.VectorMetaData md = HATPhaseUtils.getVectorTypeInfoWithCodeReflection(resultType);
284 return switch (opType) {
285 case "add" -> new HATVectorOp.HATVectorBinaryOp.HATVectorAddOp(varName, resultType, md.vectorTypeElement(), md.lanes(), outputOperands);
286 case "sub" -> new HATVectorOp.HATVectorBinaryOp.HATVectorSubOp(varName, resultType, md.vectorTypeElement(), md.lanes(), outputOperands);
287 case "mul" -> new HATVectorOp.HATVectorBinaryOp.HATVectorMulOp(varName, resultType, md.vectorTypeElement(), md.lanes(), outputOperands);
288 case "div" -> new HATVectorOp.HATVectorBinaryOp.HATVectorDivOp(varName, resultType, md.vectorTypeElement(), md.lanes(), outputOperands);
289 default -> throw new IllegalStateException("Unexpected value: " + opType);
290 };
291 }
292
293 private boolean isVectorBinaryOperation(JavaOp.InvokeOp invokeOp) {
294 TypeElement typeElement = invokeOp.resultType();
295 boolean isHatVectorType = typeElement.toString().startsWith("hat.buffer.Float");
296 return isHatVectorType
297 && (invokeOp.invokeDescriptor().name().equalsIgnoreCase("add")
298 || invokeOp.invokeDescriptor().name().equalsIgnoreCase("sub")
299 || invokeOp.invokeDescriptor().name().equalsIgnoreCase("mul")
300 || invokeOp.invokeDescriptor().name().equalsIgnoreCase("div"));
301 }
302
303 private Op findVarOpOrHATVarOP(Op op) {
304 return searchForOp(op, Set.of(CoreOp.VarOp.class, HATVectorOp.HATVectorVarOp.class));
305 }
306
307 public boolean isVectorOp(Op op) {
308 if (op.operands().isEmpty()) return false;
309 TypeElement type = firstOperand(op).type();
310 if (type instanceof ArrayType at) type = at.componentType();
311 if (type instanceof ClassType ct) {
312 try {
313 return _V.class.isAssignableFrom((Class<?>) ct.resolve(lookup()));
314 } catch (ReflectiveOperationException e) {
315 throw new RuntimeException(e);
316 }
317 }
318 return false;
319 }
320
321 public static Op.Result firstOperandAsRes(Op op) {
322 return (firstOperand(op) instanceof Op.Result res) ? res : null;
323 }
324
325 public static Value firstOperand(Op op) {
326 return op.operands().getFirst();
327 }
328
329 public static Value getValue(Block.Builder bb, Value value) {
330 return bb.context().getValueOrDefault(value, value);
331 }
332
333 public boolean isBufferArray(Op op) {
334 JavaOp.InvokeOp iop = (JavaOp.InvokeOp) searchForOp(op, Set.of(JavaOp.InvokeOp.class));
335 return iop.invokeDescriptor().name().toLowerCase().contains("arrayview");
336 }
337
338 public boolean notGlobalVarOp(Op op) {
339 JavaOp.InvokeOp iop = (JavaOp.InvokeOp) searchForOp(op, Set.of(JavaOp.InvokeOp.class));
340 return iop.invokeDescriptor().name().toLowerCase().contains("local") ||
341 iop.invokeDescriptor().name().toLowerCase().contains("shared") ||
342 iop.invokeDescriptor().name().toLowerCase().contains("private");
343 }
344
345 public Op searchForOp(Op op, Set<Class<?>> opClasses) {
346 while (!(opClasses.contains(op.getClass()))) {
347 if (!op.operands().isEmpty() && firstOperand(op) instanceof Op.Result r) {
348 op = r.op();
349 } else {
350 return null;
351 }
352 }
353 return op;
354 }
355
356 public boolean isBufferInitialize(Op op) {
357 // first check if the return is an array type
358 if (op instanceof CoreOp.VarOp vop) {
359 if (!(vop.varValueType() instanceof ArrayType)) return false;
360 } else if (!(op instanceof JavaOp.ArrayAccessOp)) {
361 if (!(op.resultType() instanceof ArrayType)) return false;
362 }
363
364 return isBufferArray(op);
365 }
366
367 public boolean isArrayView(CoreOp.FuncOp entry) {
368 var here = CallSite.of(HATArrayViewPhase.class, "isArrayView");
369 return elements(here, entry).anyMatch((element) -> (
370 element instanceof JavaOp.InvokeOp iop &&
371 iop.resultType() instanceof ArrayType &&
372 iop.invokeDescriptor().refType() instanceof JavaType javaType &&
373 (isAssignable(lookup(), javaType, MappableIface.class)
374 || isAssignable(lookup(), javaType, DeviceType.class))));
375 }
376
377 public Class<?> typeElementToClass(TypeElement type) {
378 class PrimitiveHolder {
379 static final Map<PrimitiveType, Class<?>> primitiveToClass = Map.of(
380 JavaType.BYTE, byte.class,
381 JavaType.SHORT, short.class,
382 JavaType.INT, int.class,
383 JavaType.LONG, long.class,
384 JavaType.FLOAT, float.class,
385 JavaType.DOUBLE, double.class,
386 JavaType.CHAR, char.class,
387 JavaType.BOOLEAN, boolean.class
388 );
389 }
390 try {
391 if (type instanceof PrimitiveType primitiveType) {
392 return PrimitiveHolder.primitiveToClass.get(primitiveType);
393 } else if (type instanceof ClassType classType) {
394 return ((Class<?>) classType.resolve(lookup()));
395 } else {
396 throw new IllegalArgumentException("given type cannot be converted to class");
397 }
398 } catch (ReflectiveOperationException e) {
399 throw new RuntimeException("given type cannot be converted to class");
400 }
401 }
402
403 private String obtainVarNameFromInvoke(JavaOp.InvokeOp invokeOp) {
404 Op.Result invokeResult = invokeOp.result();
405 if (!invokeResult.uses().isEmpty()) {
406 Op.Result r = invokeResult.uses().stream().toList().getFirst();
407 if (r.op() instanceof CoreOp.VarOp varOp) {
408 return varOp.varName();
409 }
410 }
411 return invokeOp.externalizeOpName();
412 }
413 }