1 /* 2 * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. Oracle designates this 8 * particular file as subject to the "Classpath" exception as provided 9 * by Oracle in the LICENSE file that accompanied this code. 10 * 11 * This code is distributed in the hope that it will be useful, but WITHOUT 12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 14 * version 2 for more details (a copy is included in the LICENSE file that 15 * accompanied this code). 16 * 17 * You should have received a copy of the GNU General Public License version 18 * 2 along with this work; if not, write to the Free Software Foundation, 19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 20 * 21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 22 * or visit www.oracle.com if you need additional information or have any 23 * questions. 24 */ 25 package hat.phases; 26 27 import hat.Accelerator; 28 import hat.dialect.HATLocalVarOp; 29 import hat.dialect.HATPrivateVarOp; 30 import hat.dialect.HATVectorAddOp; 31 import hat.dialect.HATVectorDivOp; 32 import hat.dialect.HATVectorLoadOp; 33 import hat.dialect.HATVectorMulOp; 34 import hat.dialect.HATVectorSubOp; 35 import hat.dialect.HATVectorVarLoadOp; 36 import hat.dialect.HATVectorVarOp; 37 import hat.dialect.HATVectorViewOp; 38 import hat.dialect.HATVectorBinaryOp; 39 import hat.optools.OpTk; 40 import jdk.incubator.code.CodeElement; 41 import jdk.incubator.code.CopyContext; 42 import jdk.incubator.code.Op; 43 import jdk.incubator.code.TypeElement; 44 import jdk.incubator.code.Value; 45 import jdk.incubator.code.dialect.core.CoreOp; 46 import jdk.incubator.code.dialect.java.JavaOp; 47 48 import java.util.HashMap; 49 import java.util.List; 50 import java.util.Map; 51 import java.util.Objects; 52 import java.util.Set; 53 import java.util.stream.Collectors; 54 import java.util.stream.Stream; 55 56 public abstract class HATDialectifyVectorOpPhase implements HATDialect{ 57 58 protected final Accelerator accelerator; 59 @Override public Accelerator accelerator(){ 60 return this.accelerator; 61 } 62 private final OpView vectorOperation; 63 64 public HATDialectifyVectorOpPhase(Accelerator accelerator, OpView vectorOperation) { 65 this.accelerator = accelerator; 66 this.vectorOperation = vectorOperation; 67 } 68 69 70 71 private HATVectorBinaryOp.OpType getBinaryOpType(JavaOp.InvokeOp invokeOp) { 72 return switch (invokeOp.invokeDescriptor().name()) { 73 case "add" -> HATVectorBinaryOp.OpType.ADD; 74 case "sub" -> HATVectorBinaryOp.OpType.SUB; 75 case "mul" -> HATVectorBinaryOp.OpType.MUL; 76 case "div" -> HATVectorBinaryOp.OpType.DIV; 77 default -> throw new RuntimeException("Unknown binary op " + invokeOp.invokeDescriptor().name()); 78 }; 79 } 80 81 public enum OpView { 82 FLOAT4_LOAD("float4View"), 83 ADD("add"), 84 SUB("sub"), 85 MUL("mul"), 86 DIV("div"); 87 final String methodName; 88 OpView(String methodName) { 89 this.methodName = methodName; 90 } 91 } 92 93 private boolean isVectorOperation(JavaOp.InvokeOp invokeOp) { 94 TypeElement typeElement = invokeOp.resultType(); 95 boolean isHatVectorType = typeElement.toString().startsWith("hat.buffer.Float"); 96 return isHatVectorType 97 && OpTk.isIfaceBufferMethod(accelerator.lookup, invokeOp) 98 && isMethod(invokeOp, vectorOperation.methodName); 99 } 100 101 private String findNameVector(CoreOp.VarAccessOp.VarLoadOp varLoadOp) { 102 return findNameVector(varLoadOp.operands().get(0)); 103 } 104 105 private String findNameVector(Value v) { 106 if (v instanceof Op.Result r && r.op() instanceof CoreOp.VarAccessOp.VarLoadOp varLoadOp) { 107 return findNameVector(varLoadOp); 108 } else { 109 // Leaf of tree - 110 if (v instanceof CoreOp.Result r && r.op() instanceof HATVectorViewOp hatVectorViewOp) { 111 return hatVectorViewOp.varName(); 112 } 113 return null; 114 } 115 } 116 117 private boolean findIsSharedOrPrivate(CoreOp.VarAccessOp.VarLoadOp varLoadOp) { 118 return findIsSharedOrPrivate(varLoadOp.operands().get(0)); 119 } 120 121 private boolean findIsSharedOrPrivate(Value v) { 122 if (v instanceof Op.Result r && r.op() instanceof CoreOp.VarAccessOp.VarLoadOp varLoadOp) { 123 return findIsSharedOrPrivate(varLoadOp); 124 } else { 125 // Leaf of tree - 126 if (v instanceof CoreOp.Result r && (r.op() instanceof HATLocalVarOp || r.op() instanceof HATPrivateVarOp)) { 127 return true; 128 } 129 return false; 130 } 131 } 132 133 private HATVectorBinaryOp buildVectorBinaryOp(HATVectorBinaryOp.OpType opType, String varName, TypeElement resultType, List<Value> outputOperands) { 134 return switch (opType) { 135 case ADD -> new HATVectorAddOp(varName, resultType, outputOperands); 136 case SUB -> new HATVectorSubOp(varName, resultType, outputOperands); 137 case MUL -> new HATVectorMulOp(varName, resultType, outputOperands); 138 case DIV -> new HATVectorDivOp(varName, resultType, outputOperands); 139 }; 140 } 141 142 private CoreOp.FuncOp dialectifyVectorLoad(CoreOp.FuncOp funcOp) { 143 var here = OpTk.CallSite.of(this.getClass(), "dialectifyVectorLoad" ); 144 before(here,funcOp); 145 Stream<CodeElement<?, ?>> float4NodesInvolved = funcOp.elements() 146 .mapMulti((codeElement, consumer) -> { 147 if (codeElement instanceof CoreOp.VarOp varOp) { 148 List<Value> inputOperandsVarOp = varOp.operands(); 149 for (Value inputOperand : inputOperandsVarOp) { 150 if (inputOperand instanceof Op.Result result) { 151 if (result.op() instanceof JavaOp.InvokeOp invokeOp) { 152 if (isVectorOperation(invokeOp)) { 153 consumer.accept(invokeOp); 154 consumer.accept(varOp); 155 } 156 } 157 } 158 } 159 } 160 }); 161 162 Set<CodeElement<?, ?>> nodesInvolved = float4NodesInvolved.collect(Collectors.toSet()); 163 164 165 funcOp = OpTk.transform(here, funcOp,(blockBuilder, op) -> { 166 CopyContext context = blockBuilder.context(); 167 if (!nodesInvolved.contains(op)) { 168 blockBuilder.op(op); 169 } else if (op instanceof JavaOp.InvokeOp invokeOp) { 170 // Don't insert the invoke node 171 Op.Result result = invokeOp.result(); 172 List<Op.Result> collect = result.uses().stream().toList(); 173 boolean isShared = findIsSharedOrPrivate(invokeOp.operands().getFirst()); 174 for (Op.Result r : collect) { 175 if (r.op() instanceof CoreOp.VarOp varOp) { 176 List<Value> inputOperandsVarOp = invokeOp.operands(); 177 List<Value> outputOperandsVarOp = context.getValues(inputOperandsVarOp); 178 HATVectorViewOp memoryViewOp = new HATVectorLoadOp(varOp.varName(), varOp.resultType(), invokeOp.resultType(), 4, isShared, outputOperandsVarOp); 179 Op.Result hatLocalResult = blockBuilder.op(memoryViewOp); 180 memoryViewOp.setLocation(varOp.location()); 181 context.mapValue(invokeOp.result(), hatLocalResult); 182 } 183 } 184 } else if (op instanceof CoreOp.VarOp varOp) { 185 // pass value 186 //context.mapValue(varOp.result(), context.getValue(varOp.operands().getFirst())); 187 List<Value> inputOperandsVarOp = varOp.operands(); 188 List<Value> outputOperandsVarOp = context.getValues(inputOperandsVarOp); 189 HATVectorViewOp memoryViewOp = new HATVectorVarOp(varOp.varName(), varOp.resultType(), 4, outputOperandsVarOp); 190 Op.Result hatLocalResult = blockBuilder.op(memoryViewOp); 191 memoryViewOp.setLocation(varOp.location()); 192 context.mapValue(varOp.result(), hatLocalResult); 193 } else if (op instanceof CoreOp.VarAccessOp.VarLoadOp varLoadOp) { 194 // pass value 195 context.mapValue(varLoadOp.result(), context.getValue(varLoadOp.operands().getFirst())); 196 } 197 return blockBuilder; 198 }); 199 after(here, funcOp); 200 return funcOp; 201 } 202 203 private CoreOp.FuncOp dialectifyVectorBinaryOps(CoreOp.FuncOp funcOp) { 204 var here = OpTk.CallSite.of(this.getClass(), "dialectifyVectorBinaryOps"); 205 before(here, funcOp); 206 Map<JavaOp.InvokeOp, HATVectorBinaryOp.OpType> binaryOperation = new HashMap<>(); 207 208 Stream<CodeElement<?, ?>> float4NodesInvolved = funcOp.elements() 209 .mapMulti((codeElement, consumer) -> { 210 if (codeElement instanceof CoreOp.VarOp varOp) { 211 List<Value> inputOperandsVarOp = varOp.operands(); 212 for (Value inputOperand : inputOperandsVarOp) { 213 if (inputOperand instanceof Op.Result result) { 214 if (result.op() instanceof JavaOp.InvokeOp invokeOp) { 215 if (isVectorOperation(invokeOp)) { 216 HATVectorBinaryOp.OpType binaryOpType = getBinaryOpType(invokeOp); 217 binaryOperation.put(invokeOp, binaryOpType); 218 consumer.accept(invokeOp); 219 consumer.accept(varOp); 220 } 221 } 222 } 223 } 224 } 225 }); 226 227 Set<CodeElement<?, ?>> nodesInvolved = float4NodesInvolved.collect(Collectors.toSet()); 228 229 funcOp = OpTk.transform(here, funcOp, nodesInvolved::contains, (blockBuilder, op) -> { 230 CopyContext context = blockBuilder.context(); 231 // if (!nodesInvolved.contains(op)) { 232 // blockBuilder.op(op); 233 //} else 234 if (op instanceof JavaOp.InvokeOp invokeOp) { 235 Op.Result result = invokeOp.result(); 236 List<Value> inputOperands = invokeOp.operands(); 237 List<Value> outputOperands = context.getValues(inputOperands); 238 List<Op.Result> collect = result.uses().stream().toList(); 239 for (Op.Result r : collect) { 240 if (r.op() instanceof CoreOp.VarOp varOp) { 241 HATVectorBinaryOp.OpType binaryOpType = binaryOperation.get(invokeOp); 242 HATVectorViewOp memoryViewOp = buildVectorBinaryOp(binaryOpType, varOp.varName(), invokeOp.resultType(), outputOperands); 243 Op.Result hatVectorOpResult = blockBuilder.op(memoryViewOp); 244 memoryViewOp.setLocation(varOp.location()); 245 context.mapValue(invokeOp.result(), hatVectorOpResult); 246 break; 247 } 248 } 249 } else if (op instanceof CoreOp.VarOp varOp) { 250 List<Value> inputOperandsVarOp = varOp.operands(); 251 List<Value> outputOperandsVarOp = context.getValues(inputOperandsVarOp); 252 HATVectorViewOp memoryViewOp = new HATVectorVarOp(varOp.varName(), varOp.resultType(), 4, outputOperandsVarOp); 253 Op.Result hatVectorResult = blockBuilder.op(memoryViewOp); 254 memoryViewOp.setLocation(varOp.location()); 255 context.mapValue(varOp.result(), hatVectorResult); 256 } 257 return blockBuilder; 258 }); 259 after(here,funcOp); 260 return funcOp; 261 } 262 263 private CoreOp.FuncOp dialectifyVectorBinaryWithContatenationOps(CoreOp.FuncOp funcOp) { 264 var here = OpTk.CallSite.of(this.getClass(), "dialectifyBinaryWithConcatenation"); 265 before(here, funcOp); 266 Map<JavaOp.InvokeOp, HATVectorBinaryOp.OpType> binaryOperation = new HashMap<>(); 267 Stream<CodeElement<?, ?>> float4NodesInvolved = funcOp.elements() 268 .mapMulti((codeElement, consumer) -> { 269 if (codeElement instanceof JavaOp.InvokeOp invokeOp) { 270 if (isVectorOperation(invokeOp)) { 271 List<Value> inputOperandsInvoke = invokeOp.operands(); 272 for (Value inputOperand : inputOperandsInvoke) { 273 if (inputOperand instanceof Op.Result r && r.op() instanceof CoreOp.VarAccessOp.VarLoadOp varLoadOp) { 274 HATVectorBinaryOp.OpType binaryOpType = getBinaryOpType(invokeOp); 275 binaryOperation.put(invokeOp, binaryOpType); 276 consumer.accept(varLoadOp); 277 consumer.accept(invokeOp); 278 } 279 } 280 } 281 } else if (codeElement instanceof HATVectorBinaryOp hatVectorBinaryOp) { 282 List<Value> inputOperandsInvoke = hatVectorBinaryOp.operands(); 283 for (Value inputOperand : inputOperandsInvoke) { 284 if (inputOperand instanceof Op.Result r && r.op() instanceof CoreOp.VarAccessOp.VarLoadOp varLoadOp) { 285 consumer.accept(varLoadOp); 286 } 287 } 288 } 289 }); 290 291 Set<CodeElement<?, ?>> nodesInvolved = float4NodesInvolved.collect(Collectors.toSet()); 292 if (nodesInvolved.isEmpty()) { 293 return funcOp; 294 } 295 funcOp = OpTk.transform(here, funcOp, (blockBuilder, op) -> { 296 CopyContext context = blockBuilder.context(); 297 if (!nodesInvolved.contains(op)) { 298 blockBuilder.op(op); 299 } else if (op instanceof JavaOp.InvokeOp invokeOp) { 300 List<Value> inputOperands = invokeOp.operands(); 301 List<Value> outputOperands = context.getValues(inputOperands); 302 HATVectorViewOp memoryViewOp = buildVectorBinaryOp(binaryOperation.get(invokeOp), "null", invokeOp.resultType(), outputOperands); 303 Op.Result hatVectorOpResult = blockBuilder.op(memoryViewOp); 304 memoryViewOp.setLocation(invokeOp.location()); 305 context.mapValue(invokeOp.result(), hatVectorOpResult); 306 } else if (op instanceof CoreOp.VarAccessOp.VarLoadOp varLoadOp) { 307 List<Value> inputOperandsVarLoad = varLoadOp.operands(); 308 List<Value> outputOperandsVarLoad = context.getValues(inputOperandsVarLoad); 309 String varLoadName = findNameVector(varLoadOp); 310 HATVectorViewOp memoryViewOp = new HATVectorVarLoadOp(varLoadName, varLoadOp.resultType(), outputOperandsVarLoad); 311 Op.Result hatVectorResult = blockBuilder.op(memoryViewOp); 312 memoryViewOp.setLocation(varLoadOp.location()); 313 context.mapValue(varLoadOp.result(), hatVectorResult); 314 } 315 return blockBuilder; 316 }); 317 after(here,funcOp); 318 return funcOp; 319 } 320 321 @Override 322 public CoreOp.FuncOp apply(CoreOp.FuncOp funcOp) { 323 if (Objects.requireNonNull(vectorOperation) == OpView.FLOAT4_LOAD) { 324 funcOp = dialectifyVectorLoad(funcOp); 325 } else { 326 // Find binary operations 327 funcOp = dialectifyVectorBinaryOps(funcOp); 328 funcOp = dialectifyVectorBinaryWithContatenationOps(funcOp); 329 } 330 return funcOp; 331 } 332 333 public static class AddPhase extends HATDialectifyVectorOpPhase{ 334 335 public AddPhase(Accelerator accelerator) { 336 super(accelerator, OpView.ADD); 337 } 338 } 339 340 public static class DivPhase extends HATDialectifyVectorOpPhase{ 341 342 public DivPhase(Accelerator accelerator) { 343 super(accelerator, OpView.DIV); 344 } 345 } 346 347 public static class Float4LoadPhase extends HATDialectifyVectorOpPhase{ 348 349 public Float4LoadPhase(Accelerator accelerator) { 350 super(accelerator, OpView.FLOAT4_LOAD); 351 } 352 } 353 354 public static class MulPhase extends HATDialectifyVectorOpPhase{ 355 356 public MulPhase(Accelerator accelerator) { 357 super(accelerator, OpView.MUL); 358 } 359 } 360 361 public static class SubPhase extends HATDialectifyVectorOpPhase{ 362 363 public SubPhase(Accelerator accelerator) { 364 super(accelerator, OpView.SUB); 365 } 366 } 367 }