1 /*
  2  * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package hat.dialect;
 26 
import jdk.incubator.code.Op;
import jdk.incubator.code.TypeElement;
import jdk.incubator.code.Value;
import jdk.incubator.code.dialect.core.CoreOp;
import jdk.incubator.code.dialect.java.JavaOp;
import jdk.incubator.code.dialect.java.JavaType;

import java.lang.reflect.Method;
import java.util.Comparator;
import java.util.Locale;
import java.util.NoSuchElementException;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Stream;
 39 
 40 public class HATPhaseUtils {
 41 
 42     public static TypeElement getVectorElementType(String primitive) {
 43         return switch (primitive) {
 44             case "float" -> JavaType.FLOAT;
 45             case "double" -> JavaType.DOUBLE;
 46             case "int" -> JavaType.INT;
 47             case "long" -> JavaType.LONG;
 48             case "short" -> JavaType.SHORT;
 49             case "byte" -> JavaType.BYTE;
 50             case "char" -> JavaType.CHAR;
 51             case "boolean" -> JavaType.BOOLEAN;
 52             default -> null;
 53         };
 54     }
 55 
 56     public record VectorMetaData(TypeElement vectorTypeElement, int lanes) {
 57     }
 58 
 59     public static VectorMetaData getVectorTypeInfo(JavaOp.InvokeOp invokeOp, int param) {
 60         Value varValue = invokeOp.operands().get(param);
 61         if (varValue instanceof Op.Result r && r.op() instanceof CoreOp.VarAccessOp.VarLoadOp varLoadOp) {
 62             return getVectorTypeInfoWithCodeReflection(varLoadOp.resultType());
 63         }
 64         return null;
 65     }
 66 
 67     public static VectorMetaData getVectorTypeInfo(JavaOp.InvokeOp invokeOp) {
 68         return getVectorTypeInfoWithCodeReflection(invokeOp.resultType());
 69     }
 70 
 71     private static CoreOp.FuncOp buildCodeModelFor(Class<?> klass, String methodName) {
 72         Optional<Method> methodFunction = Stream.of(klass.getMethods())
 73                 .filter(m -> m.getName().equals(methodName))
 74                 .findFirst();
 75         return Op.ofMethod(methodFunction.get()).get();
 76     }
 77 
 78     /**
 79      * This method inspects the Vector Type Methods to obtain two methods for code-model:
 80      * 1) Method `type` to obtain the primitive base type of the vector type.
 81      * 2) Method `width` to obtain the number of lanes.
 82      *
 83      * @param typeElement
 84      *  {@link TypeElement}
 85      * @return
 86      * {@link VectorMetaData}
 87      */
 88     public static VectorMetaData getVectorTypeInfoWithCodeReflection(TypeElement typeElement) {
 89         Class<?> aClass;
 90         try {
 91             aClass = Class.forName(typeElement.toString());
 92         } catch (ClassNotFoundException e) {
 93             // TODO: Add control for exceptions in HAT (HATExceptions Handler)
 94             throw new RuntimeException(e);
 95         }
 96         CoreOp.FuncOp codeModelType = buildCodeModelFor(aClass, "type");
 97         AtomicReference<TypeElement> vectorElement = new AtomicReference<>();
 98         codeModelType.elements().forEach(codeElement -> {
 99             if (codeElement instanceof CoreOp.ReturnOp returnOp) {
100                 Value v = returnOp.operands().getFirst();
101                 if (v instanceof Op.Result r && r.op() instanceof JavaOp.FieldAccessOp.FieldLoadOp fieldLoadOp) {
102                     String primitiveTypeName = fieldLoadOp.fieldDescriptor().name();
103                     vectorElement.set(getVectorElementType(primitiveTypeName.toLowerCase()));
104                 }
105             }
106         });
107 
108         AtomicInteger lanes = new AtomicInteger(1);
109         CoreOp.FuncOp codeModelWidth = buildCodeModelFor(aClass, "width");
110         codeModelWidth.elements().forEach(codeElement -> {
111             if (codeElement instanceof CoreOp.ReturnOp returnOp) {
112                 Value v = returnOp.operands().getFirst();
113                 if (v instanceof Op.Result r && r.op() instanceof CoreOp.ConstantOp constantOp) {
114                     lanes.set((Integer) constantOp.value());
115                 }
116             }
117         });
118         return new VectorMetaData(vectorElement.get(), lanes.get());
119     }
120 
121     public static int getWitdh(CoreOp.VarAccessOp.VarLoadOp varLoadOp) {
122         return getWitdh(varLoadOp.operands().getFirst());
123     }
124 
125     public static int getWitdh(Value v) {
126         if (v instanceof Op.Result r && r.op() instanceof CoreOp.VarAccessOp.VarLoadOp varLoadOp) {
127             return getWitdh(varLoadOp);
128         } else {
129             // Leaf of tree -
130             if (v instanceof CoreOp.Result r && r.op() instanceof HATVectorOp hatVectorOp) {
131                 return hatVectorOp.vectorN();
132             }
133             return -1;
134         }
135     }
136 
137     public static TypeElement findVectorTypeElement(CoreOp.VarAccessOp.VarLoadOp varLoadOp) {
138         return findVectorTypeElement(varLoadOp.operands().getFirst());
139     }
140 
141     public static TypeElement findVectorTypeElement(Value v) {
142         if (v instanceof Op.Result r && r.op() instanceof CoreOp.VarAccessOp.VarLoadOp varLoadOp) {
143             return findVectorTypeElement(varLoadOp);
144         } else {
145             // Leaf of tree -
146             if (v instanceof CoreOp.Result r && r.op() instanceof HATVectorOp hatVectorOp) {
147                 return hatVectorOp.vectorElementType;
148             }
149             return null;
150         }
151     }
152 
153     public static String findNameVector(CoreOp.VarAccessOp.VarLoadOp varLoadOp) {
154         return findNameVector(varLoadOp.operands().getFirst());
155     }
156 
157     public static String findNameVector(Value v) {
158         if (v instanceof Op.Result r && r.op() instanceof CoreOp.VarAccessOp.VarLoadOp varLoadOp) {
159             return findNameVector(varLoadOp);
160         } else {
161             // Leaf of tree -
162             if (v instanceof CoreOp.Result r && r.op() instanceof HATVectorOp hatVectorOp) {
163                 return hatVectorOp.varName();
164             }
165             return null;
166         }
167     }
168 
169 }