1 /*
2 * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25 package hat.phases;
26
27 import hat.Accelerator;
28 import hat.dialect.HATVectorSelectLoadOp;
29 import hat.dialect.HATVectorSelectStoreOp;
30 import hat.dialect.HATVectorOp;
31 import hat.optools.OpTk;
32 import hat.types._V;
33 import jdk.incubator.code.CodeElement;
34 import jdk.incubator.code.CodeContext;
35 import jdk.incubator.code.Op;
36 import jdk.incubator.code.Value;
37 import jdk.incubator.code.dialect.core.CoreOp;
38 import jdk.incubator.code.dialect.java.JavaOp;
39 import jdk.incubator.code.dialect.java.JavaType;
40
41 import java.util.List;
42 import java.util.Set;
43 import java.util.stream.Collectors;
44 import java.util.stream.Stream;
45
46 public class HATDialectifyVectorSelectPhase implements HATDialect {
47
48 protected final Accelerator accelerator;
49 @Override public Accelerator accelerator(){
50 return this.accelerator;
51 }
52 public HATDialectifyVectorSelectPhase(Accelerator accelerator) {
53 this.accelerator = accelerator;
54 }
55
56 private boolean isVectorLane(JavaOp.InvokeOp invokeOp) {
57 return isMethod(invokeOp, "x")
58 || isMethod(invokeOp, "y")
59 || isMethod(invokeOp, "z")
60 || isMethod(invokeOp, "w");
61 }
62
63 int getLane(String fieldName) {
64 return switch (fieldName) {
65 case "x" -> 0;
66 case "y" -> 1;
67 case "z" -> 2;
68 case "w" -> 3;
69 default -> -1;
70 };
71 }
72
73 private boolean isVectorOperation(JavaOp.InvokeOp invokeOp) {
74 String typeElement = invokeOp.invokeDescriptor().refType().toString();
75 Set<Class<?>> interfaces;
76 try {
77 Class<?> aClass = Class.forName(typeElement);
78 interfaces = OpTk.inspectAllInterfaces(aClass);
79 } catch (ClassNotFoundException _) {
80 return false;
81 }
82 return interfaces.contains(_V.class) && isVectorLane(invokeOp);
83 }
84
85 private String findNameVector(CoreOp.VarAccessOp.VarLoadOp varLoadOp) {
86 return findNameVector(varLoadOp.operands().get(0));
87 }
88
89 private String findNameVector(Value v) {
90 if (v instanceof Op.Result r && r.op() instanceof CoreOp.VarAccessOp.VarLoadOp varLoadOp) {
91 return findNameVector(varLoadOp);
92 } else {
93 if (v instanceof CoreOp.Result r && r.op() instanceof HATVectorOp vectorViewOp) {
94 return vectorViewOp.varName();
95 }
96 return null;
97 }
98 }
99
100 private CoreOp.VarOp findVarOp(CoreOp.VarAccessOp.VarLoadOp varLoadOp) {
101 return findVarOp(varLoadOp.operands().get(0));
102 }
103
104 private CoreOp.VarOp findVarOp(Value v) {
105 if (v instanceof Op.Result r && r.op() instanceof CoreOp.VarAccessOp.VarLoadOp varLoadOp) {
106 return findVarOp(varLoadOp);
107 } else {
108 if (v instanceof CoreOp.Result r && r.op() instanceof CoreOp.VarOp varOp) {
109 return varOp;
110 }
111 return null;
112 }
113 }
114
115 // Code Model Pattern:
116 // %16 : java.type:"hat.buffer.Float4" = var.load %15 @loc="63:28";
117 // %17 : java.type:"float" = invoke %16 @loc="63:28" @java.ref:"hat.buffer.Float4::x():float";
118 private CoreOp.FuncOp vloadSelectPhase(CoreOp.FuncOp funcOp) {
119 var here = OpTk.CallSite.of(this.getClass(), "vloadSelectPhase");
120 before(here, funcOp);
121 Stream<CodeElement<?, ?>> vectorSelectOps = funcOp.elements()
122 .mapMulti((codeElement, consumer) -> {
123 if (codeElement instanceof JavaOp.InvokeOp invokeOp) {
124 if (isVectorOperation(invokeOp) && (invokeOp.resultType() != JavaType.VOID)) {
125 Value inputOperand = invokeOp.operands().getFirst();
126 if (inputOperand instanceof Op.Result r && r.op() instanceof CoreOp.VarAccessOp.VarLoadOp varLoadOp) {
127 consumer.accept(invokeOp);
128 consumer.accept(varLoadOp);
129 }
130 }
131 }
132 });
133
134 Set<CodeElement<?, ?>> nodesInvolved = vectorSelectOps.collect(Collectors.toSet());
135 funcOp = OpTk.transform(here, funcOp, (blockBuilder, op) -> {
136 CodeContext context = blockBuilder.context();
137 if (!nodesInvolved.contains(op)) {
138 blockBuilder.op(op);
139 } else if (op instanceof JavaOp.InvokeOp invokeOp) {
140 List<Value> inputInvokeOp = invokeOp.operands();
141 for (Value v : inputInvokeOp) {
142 if (v instanceof Op.Result r && r.op() instanceof CoreOp.VarAccessOp.VarLoadOp varLoadOp) {
143 List<Value> outputOperandsInvokeOp = context.getValues(inputInvokeOp);
144 int lane = getLane(invokeOp.invokeDescriptor().name());
145 HATVectorOp vSelectOp;
146 String name = findNameVector(varLoadOp);
147 if (invokeOp.resultType() != JavaType.VOID) {
148 vSelectOp = new HATVectorSelectLoadOp(name, invokeOp.resultType(), lane, outputOperandsInvokeOp);
149 } else {
150 throw new RuntimeException("VSelect Load Op must return a value!");
151 }
152 Op.Result hatSelectResult = blockBuilder.op(vSelectOp);
153 vSelectOp.setLocation(invokeOp.location());
154 context.mapValue(invokeOp.result(), hatSelectResult);
155 }
156 }
157 } else if (op instanceof CoreOp.VarAccessOp.VarLoadOp varLoadOp) {
158 // Pass the value
159 context.mapValue(varLoadOp.result(), context.getValue(varLoadOp.operands().getFirst()));
160 }
161 return blockBuilder;
162 });
163
164 after(here, funcOp);
165 return funcOp;
166 }
167
168 // Pattern from the code mode:
169 // %20 : java.type:"hat.buffer.Float4" = var.load %15 @loc="64:13";
170 // %21 : java.type:"float" = var.load %19 @loc="64:18";
171 // invoke %20 %21 @loc="64:13" @java.ref:"hat.buffer.Float4::x(float):void";
172 private CoreOp.FuncOp vstoreSelectPhase(CoreOp.FuncOp funcOp) {
173 var here = OpTk.CallSite.of(this.getClass(),"vstoreSelectPhase");
174 before(here, funcOp);
175 //TODO is this side table safe?
176 Stream<CodeElement<?, ?>> float4NodesInvolved = OpTk.elements(here,funcOp)
177 .mapMulti((codeElement, consumer) -> {
178 if (codeElement instanceof JavaOp.InvokeOp invokeOp) {
179 if (isVectorOperation(invokeOp)) {
180 List<Value> inputOperandsInvoke = invokeOp.operands();
181 Value inputOperand = inputOperandsInvoke.getFirst();
182 if (inputOperand instanceof Op.Result r && r.op() instanceof CoreOp.VarAccessOp.VarLoadOp varLoadOp) {
183 consumer.accept(invokeOp);
184 consumer.accept(varLoadOp);
185 }
186 }
187 }
188 });
189
190 Set<CodeElement<?, ?>> nodesInvolved = float4NodesInvolved.collect(Collectors.toSet());
191 funcOp = OpTk.transform(here, funcOp, (blockBuilder, op) -> {
192 CodeContext context = blockBuilder.context();
193 if (!nodesInvolved.contains(op)) {
194 blockBuilder.op(op);
195 } else if (op instanceof JavaOp.InvokeOp invokeOp) {
196 List<Value> inputInvokeOp = invokeOp.operands();
197 Value v = inputInvokeOp.getFirst();
198
199 if (v instanceof Op.Result r && r.op() instanceof CoreOp.VarAccessOp.VarLoadOp varLoadOp) {
200 List<Value> outputOperandsInvokeOp = context.getValues(inputInvokeOp);
201 int lane = getLane(invokeOp.invokeDescriptor().name());
202 HATVectorOp vSelectOp;
203 String name = findNameVector(varLoadOp);
204 if (invokeOp.resultType() == JavaType.VOID) {
205 // The operand 1 in the store is the address (lane)
206 // The operand 1 in the store is the storeValue
207 CoreOp.VarOp resultOp = findVarOp(outputOperandsInvokeOp.get(1));
208 vSelectOp = new HATVectorSelectStoreOp(name, invokeOp.resultType(), lane, resultOp, outputOperandsInvokeOp);
209 } else {
210 throw new RuntimeException("VSelect Store Op must return a value!");
211 }
212 Op.Result resultVStore = blockBuilder.op(vSelectOp);
213 vSelectOp.setLocation(invokeOp.location());
214 context.mapValue(invokeOp.result(), resultVStore);
215 }
216
217 } else if (op instanceof CoreOp.VarAccessOp.VarLoadOp varLoadOp) {
218 // Pass the value
219 context.mapValue(varLoadOp.result(), context.getValue(varLoadOp.operands().getFirst()));
220 }
221 return blockBuilder;
222 });
223
224 after(here, funcOp);
225 return funcOp;
226 }
227
228 @Override
229 public CoreOp.FuncOp apply(CoreOp.FuncOp funcOp) {
230 funcOp = vloadSelectPhase(funcOp);
231 funcOp = vstoreSelectPhase(funcOp);
232 return funcOp;
233 }
234
235 }