1 /*
2 * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25 package hat.phases;
26
27 import hat.Accelerator;
28 import hat.NDRange;
29 import hat.dialect.HATPhaseUtils;
30 import hat.dialect.HATVectorSelectLoadOp;
31 import hat.dialect.HATVectorSelectStoreOp;
32 import hat.dialect.HATVectorOp;
33 import hat.optools.OpTk;
34 import hat.types._V;
35 import jdk.incubator.code.CodeElement;
36 import jdk.incubator.code.CopyContext;
37 import jdk.incubator.code.Op;
38 import jdk.incubator.code.Value;
39 import jdk.incubator.code.dialect.core.CoreOp;
40 import jdk.incubator.code.dialect.java.JavaOp;
41 import jdk.incubator.code.dialect.java.JavaType;
42
43 import java.util.List;
44 import java.util.Set;
45 import java.util.stream.Collectors;
46 import java.util.stream.Stream;
47
48 public class HATDialectifyVectorSelectPhase implements HATDialect {
49
50 protected final Accelerator accelerator;
51 @Override public Accelerator accelerator(){
52 return this.accelerator;
53 }
54 public HATDialectifyVectorSelectPhase(Accelerator accelerator) {
55 this.accelerator = accelerator;
56 }
57
58 private boolean isVectorLane(JavaOp.InvokeOp invokeOp) {
59 return isMethod(invokeOp, "x")
60 || isMethod(invokeOp, "y")
61 || isMethod(invokeOp, "z")
62 || isMethod(invokeOp, "w");
63 }
64
65 int getLane(String fieldName) {
66 return switch (fieldName) {
67 case "x" -> 0;
68 case "y" -> 1;
69 case "z" -> 2;
70 case "w" -> 3;
71 default -> -1;
72 };
73 }
74
75 private boolean isVectorOperation(JavaOp.InvokeOp invokeOp) {
76 String typeElement = invokeOp.invokeDescriptor().refType().toString();
77 Set<Class<?>> interfaces;
78 try {
79 Class<?> aClass = Class.forName(typeElement);
80 interfaces = HATPhaseUtils.inspectAllInterfaces(aClass);
81 } catch (ClassNotFoundException _) {
82 return false;
83 }
84 return interfaces.contains(_V.class) && isVectorLane(invokeOp);
85 }
86
87 private String findNameVector(CoreOp.VarAccessOp.VarLoadOp varLoadOp) {
88 return findNameVector(varLoadOp.operands().get(0));
89 }
90
91 private String findNameVector(Value v) {
92 if (v instanceof Op.Result r && r.op() instanceof CoreOp.VarAccessOp.VarLoadOp varLoadOp) {
93 return findNameVector(varLoadOp);
94 } else {
95 if (v instanceof CoreOp.Result r && r.op() instanceof HATVectorOp vectorViewOp) {
96 return vectorViewOp.varName();
97 }
98 return null;
99 }
100 }
101
102 private CoreOp.VarOp findVarOp(CoreOp.VarAccessOp.VarLoadOp varLoadOp) {
103 return findVarOp(varLoadOp.operands().get(0));
104 }
105
106 private CoreOp.VarOp findVarOp(Value v) {
107 if (v instanceof Op.Result r && r.op() instanceof CoreOp.VarAccessOp.VarLoadOp varLoadOp) {
108 return findVarOp(varLoadOp);
109 } else {
110 if (v instanceof CoreOp.Result r && r.op() instanceof CoreOp.VarOp varOp) {
111 return varOp;
112 }
113 return null;
114 }
115 }
116
117 // Code Model Pattern:
118 // %16 : java.type:"hat.buffer.Float4" = var.load %15 @loc="63:28";
119 // %17 : java.type:"float" = invoke %16 @loc="63:28" @java.ref:"hat.buffer.Float4::x():float";
120 private CoreOp.FuncOp vloadSelectPhase(CoreOp.FuncOp funcOp) {
121 var here = OpTk.CallSite.of(this.getClass(), "vloadSelectPhase");
122 before(here, funcOp);
123 Stream<CodeElement<?, ?>> vectorSelectOps = funcOp.elements()
124 .mapMulti((codeElement, consumer) -> {
125 if (codeElement instanceof JavaOp.InvokeOp invokeOp) {
126 if (isVectorOperation(invokeOp) && (invokeOp.resultType() != JavaType.VOID)) {
127 Value inputOperand = invokeOp.operands().getFirst();
128 if (inputOperand instanceof Op.Result r && r.op() instanceof CoreOp.VarAccessOp.VarLoadOp varLoadOp) {
129 consumer.accept(invokeOp);
130 consumer.accept(varLoadOp);
131 }
132 }
133 }
134 });
135
136 Set<CodeElement<?, ?>> nodesInvolved = vectorSelectOps.collect(Collectors.toSet());
137 funcOp = OpTk.transform(here, funcOp, (blockBuilder, op) -> {
138 CopyContext context = blockBuilder.context();
139 if (!nodesInvolved.contains(op)) {
140 blockBuilder.op(op);
141 } else if (op instanceof JavaOp.InvokeOp invokeOp) {
142 List<Value> inputInvokeOp = invokeOp.operands();
143 for (Value v : inputInvokeOp) {
144 if (v instanceof Op.Result r && r.op() instanceof CoreOp.VarAccessOp.VarLoadOp varLoadOp) {
145 List<Value> outputOperandsInvokeOp = context.getValues(inputInvokeOp);
146 int lane = getLane(invokeOp.invokeDescriptor().name());
147 HATVectorOp vSelectOp;
148 String name = findNameVector(varLoadOp);
149 if (invokeOp.resultType() != JavaType.VOID) {
150 vSelectOp = new HATVectorSelectLoadOp(name, invokeOp.resultType(), lane, outputOperandsInvokeOp);
151 } else {
152 throw new RuntimeException("VSelect Load Op must return a value!");
153 }
154 Op.Result hatSelectResult = blockBuilder.op(vSelectOp);
155 vSelectOp.setLocation(invokeOp.location());
156 context.mapValue(invokeOp.result(), hatSelectResult);
157 }
158 }
159 } else if (op instanceof CoreOp.VarAccessOp.VarLoadOp varLoadOp) {
160 // Pass the value
161 context.mapValue(varLoadOp.result(), context.getValue(varLoadOp.operands().getFirst()));
162 }
163 return blockBuilder;
164 });
165
166 after(here, funcOp);
167 return funcOp;
168 }
169
170 // Pattern from the code mode:
171 // %20 : java.type:"hat.buffer.Float4" = var.load %15 @loc="64:13";
172 // %21 : java.type:"float" = var.load %19 @loc="64:18";
173 // invoke %20 %21 @loc="64:13" @java.ref:"hat.buffer.Float4::x(float):void";
174 private CoreOp.FuncOp vstoreSelectPhase(CoreOp.FuncOp funcOp) {
175 var here = OpTk.CallSite.of(this.getClass(),"vstoreSelectPhase");
176 before(here, funcOp);
177 //TODO is this side table safe?
178 Stream<CodeElement<?, ?>> float4NodesInvolved = OpTk.elements(here,funcOp)
179 .mapMulti((codeElement, consumer) -> {
180 if (codeElement instanceof JavaOp.InvokeOp invokeOp) {
181 if (isVectorOperation(invokeOp)) {
182 List<Value> inputOperandsInvoke = invokeOp.operands();
183 Value inputOperand = inputOperandsInvoke.getFirst();
184 if (inputOperand instanceof Op.Result r && r.op() instanceof CoreOp.VarAccessOp.VarLoadOp varLoadOp) {
185 consumer.accept(invokeOp);
186 consumer.accept(varLoadOp);
187 }
188 }
189 }
190 });
191
192 Set<CodeElement<?, ?>> nodesInvolved = float4NodesInvolved.collect(Collectors.toSet());
193 funcOp = OpTk.transform(here, funcOp, (blockBuilder, op) -> {
194 CopyContext context = blockBuilder.context();
195 if (!nodesInvolved.contains(op)) {
196 blockBuilder.op(op);
197 } else if (op instanceof JavaOp.InvokeOp invokeOp) {
198 List<Value> inputInvokeOp = invokeOp.operands();
199 Value v = inputInvokeOp.getFirst();
200
201 if (v instanceof Op.Result r && r.op() instanceof CoreOp.VarAccessOp.VarLoadOp varLoadOp) {
202 List<Value> outputOperandsInvokeOp = context.getValues(inputInvokeOp);
203 int lane = getLane(invokeOp.invokeDescriptor().name());
204 HATVectorOp vSelectOp;
205 String name = findNameVector(varLoadOp);
206 if (invokeOp.resultType() == JavaType.VOID) {
207 // The operand 1 in the store is the address (lane)
208 // The operand 1 in the store is the storeValue
209 CoreOp.VarOp resultOp = findVarOp(outputOperandsInvokeOp.get(1));
210 vSelectOp = new HATVectorSelectStoreOp(name, invokeOp.resultType(), lane, resultOp, outputOperandsInvokeOp);
211 } else {
212 throw new RuntimeException("VSelect Store Op must return a value!");
213 }
214 Op.Result resultVStore = blockBuilder.op(vSelectOp);
215 vSelectOp.setLocation(invokeOp.location());
216 context.mapValue(invokeOp.result(), resultVStore);
217 }
218
219 } else if (op instanceof CoreOp.VarAccessOp.VarLoadOp varLoadOp) {
220 // Pass the value
221 context.mapValue(varLoadOp.result(), context.getValue(varLoadOp.operands().getFirst()));
222 }
223 return blockBuilder;
224 });
225
226 after(here, funcOp);
227 return funcOp;
228 }
229
230 @Override
231 public CoreOp.FuncOp apply(CoreOp.FuncOp funcOp) {
232 funcOp = vloadSelectPhase(funcOp);
233 funcOp = vstoreSelectPhase(funcOp);
234 return funcOp;
235 }
236
237 }