1 /* 2 * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. Oracle designates this 8 * particular file as subject to the "Classpath" exception as provided 9 * by Oracle in the LICENSE file that accompanied this code. 10 * 11 * This code is distributed in the hope that it will be useful, but WITHOUT 12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 14 * version 2 for more details (a copy is included in the LICENSE file that 15 * accompanied this code). 16 * 17 * You should have received a copy of the GNU General Public License version 18 * 2 along with this work; if not, write to the Free Software Foundation, 19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 20 * 21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 22 * or visit www.oracle.com if you need additional information or have any 23 * questions. 24 */ 25 package hat.phases; 26 27 import hat.Accelerator; 28 import hat.dialect.HATVectorSelectLoadOp; 29 import hat.dialect.HATVectorSelectStoreOp; 30 import hat.dialect.HATVectorViewOp; 31 import hat.optools.OpTk; 32 import jdk.incubator.code.CodeElement; 33 import jdk.incubator.code.CopyContext; 34 import jdk.incubator.code.Op; 35 import jdk.incubator.code.Value; 36 import jdk.incubator.code.dialect.core.CoreOp; 37 import jdk.incubator.code.dialect.java.JavaOp; 38 import jdk.incubator.code.dialect.java.JavaType; 39 40 import java.util.List; 41 import java.util.Set; 42 import java.util.stream.Collectors; 43 import java.util.stream.Stream; 44 45 public class HATDialectifyVectorSelectPhase implements HATDialect { 46 47 protected final Accelerator accelerator; 48 @Override public Accelerator accelerator(){ 49 return this.accelerator; 50 } 51 public HATDialectifyVectorSelectPhase(Accelerator accelerator) { 52 this.accelerator = accelerator; 53 } 54 55 private boolean isVectorLane(JavaOp.InvokeOp invokeOp) { 56 return isMethod(invokeOp, "x") 57 || isMethod(invokeOp, "y") 58 || isMethod(invokeOp, "z") 59 || isMethod(invokeOp, "w"); 60 } 61 62 int getLane(String fieldName) { 63 return switch (fieldName) { 64 case "x" -> 0; 65 case "y" -> 1; 66 case "z" -> 2; 67 case "w" -> 3; 68 default -> -1; 69 }; 70 } 71 72 private boolean isVectorOperation(JavaOp.InvokeOp invokeOp) { 73 String invokeClass = invokeOp.invokeDescriptor().refType().toString(); 74 boolean isHatVectorType = invokeClass.startsWith("hat.buffer.Float"); 75 return isHatVectorType 76 && OpTk.isIfaceBufferMethod(accelerator.lookup, invokeOp) 77 && (isVectorLane(invokeOp)); 78 } 79 80 private String findNameVector(CoreOp.VarAccessOp.VarLoadOp varLoadOp) { 81 return findNameVector(varLoadOp.operands().get(0)); 82 } 83 84 private String findNameVector(Value v) { 85 if (v instanceof Op.Result r && r.op() instanceof CoreOp.VarAccessOp.VarLoadOp varLoadOp) { 86 return findNameVector(varLoadOp); 87 } else { 88 if (v instanceof CoreOp.Result r && r.op() instanceof HATVectorViewOp vectorViewOp) { 89 return vectorViewOp.varName(); 90 } 91 return null; 92 } 93 } 94 95 private CoreOp.VarOp findVarOp(CoreOp.VarAccessOp.VarLoadOp varLoadOp) { 96 return findVarOp(varLoadOp.operands().get(0)); 97 } 98 99 private CoreOp.VarOp findVarOp(Value v) { 100 if (v instanceof Op.Result r && r.op() instanceof CoreOp.VarAccessOp.VarLoadOp varLoadOp) { 101 return findVarOp(varLoadOp); 102 } else { 103 if (v instanceof CoreOp.Result r && r.op() instanceof CoreOp.VarOp varOp) { 104 return varOp; 105 } 106 return null; 107 } 108 } 109 110 111 // Code Model Pattern: 112 // %16 : java.type:"hat.buffer.Float4" = var.load %15 @loc="63:28"; 113 // %17 : java.type:"float" = invoke %16 @loc="63:28" @java.ref:"hat.buffer.Float4::x():float"; 114 115 private CoreOp.FuncOp vloadSelectPhase(CoreOp.FuncOp funcOp) { 116 var here = OpTk.CallSite.of(this.getClass(), "vloadSelectPhase"); 117 before(here, funcOp); 118 Stream<CodeElement<?, ?>> float4NodesInvolved = funcOp.elements() 119 .mapMulti((codeElement, consumer) -> { 120 if (codeElement instanceof JavaOp.InvokeOp invokeOp) { 121 if (isVectorOperation(invokeOp) && invokeOp.resultType() != JavaType.VOID) { 122 List<Value> inputOperandsInvoke = invokeOp.operands(); 123 Value inputOperand = inputOperandsInvoke.getFirst(); 124 if (inputOperand instanceof Op.Result r && r.op() instanceof CoreOp.VarAccessOp.VarLoadOp varLoadOp) { 125 consumer.accept(invokeOp); 126 consumer.accept(varLoadOp); 127 } 128 } 129 } 130 }); 131 132 Set<CodeElement<?, ?>> nodesInvolved = float4NodesInvolved.collect(Collectors.toSet()); 133 134 funcOp = OpTk.transform(here, funcOp,(blockBuilder, op) -> { 135 CopyContext context = blockBuilder.context(); 136 if (!nodesInvolved.contains(op)) { 137 blockBuilder.op(op); 138 } else if (op instanceof JavaOp.InvokeOp invokeOp) { 139 List<Value> inputInvokeOp = invokeOp.operands(); 140 for (Value v : inputInvokeOp) { 141 if (v instanceof Op.Result r && r.op() instanceof CoreOp.VarAccessOp.VarLoadOp varLoadOp) { 142 List<Value> outputOperandsInvokeOp = context.getValues(inputInvokeOp); 143 int lane = getLane(invokeOp.invokeDescriptor().name()); 144 HATVectorViewOp vSelectOp; 145 String name = findNameVector(varLoadOp); 146 if (invokeOp.resultType() != JavaType.VOID) { 147 vSelectOp = new HATVectorSelectLoadOp(name, invokeOp.resultType(), lane, outputOperandsInvokeOp); 148 } else { 149 throw new RuntimeException("VSelect Load Op must return a value!"); 150 } 151 Op.Result hatSelectResult = blockBuilder.op(vSelectOp); 152 vSelectOp.setLocation(invokeOp.location()); 153 context.mapValue(invokeOp.result(), hatSelectResult); 154 } 155 } 156 } else if (op instanceof CoreOp.VarAccessOp.VarLoadOp varLoadOp) { 157 // Pass the value 158 context.mapValue(varLoadOp.result(), context.getValue(varLoadOp.operands().getFirst())); 159 } 160 return blockBuilder; 161 }); 162 163 after(here,funcOp); 164 return funcOp; 165 } 166 167 // Pattern from the code mode: 168 // %20 : java.type:"hat.buffer.Float4" = var.load %15 @loc="64:13"; 169 // %21 : java.type:"float" = var.load %19 @loc="64:18"; 170 // invoke %20 %21 @loc="64:13" @java.ref:"hat.buffer.Float4::x(float):void"; 171 private CoreOp.FuncOp vstoreSelectPhase(CoreOp.FuncOp funcOp) { 172 var here = OpTk.CallSite.of(this.getClass(),"vstoreSelectPhase"); 173 before(here, funcOp); 174 //TODO is this side table safe? 175 Stream<CodeElement<?, ?>> float4NodesInvolved = OpTk.elements(here,funcOp) 176 .mapMulti((codeElement, consumer) -> { 177 if (codeElement instanceof JavaOp.InvokeOp invokeOp) { 178 if (isVectorOperation(invokeOp)) { 179 List<Value> inputOperandsInvoke = invokeOp.operands(); 180 Value inputOperand = inputOperandsInvoke.getFirst(); 181 if (inputOperand instanceof Op.Result r && r.op() instanceof CoreOp.VarAccessOp.VarLoadOp varLoadOp) { 182 consumer.accept(invokeOp); 183 consumer.accept(varLoadOp); 184 } 185 } 186 } 187 }); 188 189 Set<CodeElement<?, ?>> nodesInvolved = float4NodesInvolved.collect(Collectors.toSet()); 190 funcOp = OpTk.transform(here, funcOp, (blockBuilder, op) -> { 191 CopyContext context = blockBuilder.context(); 192 if (!nodesInvolved.contains(op)) { 193 blockBuilder.op(op); 194 } else if (op instanceof JavaOp.InvokeOp invokeOp) { 195 List<Value> inputInvokeOp = invokeOp.operands(); 196 Value v = inputInvokeOp.getFirst(); 197 198 if (v instanceof Op.Result r && r.op() instanceof CoreOp.VarAccessOp.VarLoadOp varLoadOp) { 199 List<Value> outputOperandsInvokeOp = context.getValues(inputInvokeOp); 200 int lane = getLane(invokeOp.invokeDescriptor().name()); 201 HATVectorViewOp vSelectOp; 202 String name = findNameVector(varLoadOp); 203 if (invokeOp.resultType() == JavaType.VOID) { 204 // The operand 1 in the store is the address (lane) 205 // The operand 1 in the store is the storeValue 206 CoreOp.VarOp resultOp = findVarOp(outputOperandsInvokeOp.get(1)); 207 vSelectOp = new HATVectorSelectStoreOp(name, invokeOp.resultType(), lane, resultOp, outputOperandsInvokeOp); 208 } else { 209 throw new RuntimeException("VSelect Store Op must return a value!"); 210 } 211 Op.Result resultVStore = blockBuilder.op(vSelectOp); 212 vSelectOp.setLocation(invokeOp.location()); 213 context.mapValue(invokeOp.result(), resultVStore); 214 } 215 216 } else if (op instanceof CoreOp.VarAccessOp.VarLoadOp varLoadOp) { 217 // Pass the value 218 context.mapValue(varLoadOp.result(), context.getValue(varLoadOp.operands().getFirst())); 219 } 220 return blockBuilder; 221 }); 222 223 after(here, funcOp); 224 return funcOp; 225 } 226 227 @Override 228 public CoreOp.FuncOp apply(CoreOp.FuncOp funcOp) { 229 funcOp = vloadSelectPhase(funcOp); 230 funcOp = vstoreSelectPhase(funcOp); 231 return funcOp; 232 } 233 234 }