1 /*
2 * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25 package hat.dialect;
26
import jdk.incubator.code.Op;
import jdk.incubator.code.TypeElement;
import jdk.incubator.code.Value;
import jdk.incubator.code.dialect.core.CoreOp;
import jdk.incubator.code.dialect.java.JavaOp;
import jdk.incubator.code.dialect.java.JavaType;

import java.lang.reflect.Method;
import java.util.Locale;
import java.util.NoSuchElementException;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Stream;
39
40 public class HATPhaseUtils {
41
42 public static TypeElement getVectorElementType(String primitive) {
43 return switch (primitive) {
44 case "float" -> JavaType.FLOAT;
45 case "double" -> JavaType.DOUBLE;
46 case "int" -> JavaType.INT;
47 case "long" -> JavaType.LONG;
48 case "short" -> JavaType.SHORT;
49 case "byte" -> JavaType.BYTE;
50 case "char" -> JavaType.CHAR;
51 case "boolean" -> JavaType.BOOLEAN;
52 default -> null;
53 };
54 }
55
56 public record VectorMetaData(TypeElement vectorTypeElement, int lanes) {
57 }
58
59 public static VectorMetaData getVectorTypeInfo(JavaOp.InvokeOp invokeOp, int param) {
60 Value varValue = invokeOp.operands().get(param);
61 if (varValue instanceof Op.Result r && r.op() instanceof CoreOp.VarAccessOp.VarLoadOp varLoadOp) {
62 return getVectorTypeInfoWithCodeReflection(varLoadOp.resultType());
63 }
64 return null;
65 }
66
67 public static VectorMetaData getVectorTypeInfo(JavaOp.InvokeOp invokeOp) {
68 return getVectorTypeInfoWithCodeReflection(invokeOp.resultType());
69 }
70
71 private static CoreOp.FuncOp buildCodeModelFor(Class<?> klass, String methodName) {
72 Optional<Method> methodFunction = Stream.of(klass.getMethods())
73 .filter(m -> m.getName().equals(methodName))
74 .findFirst();
75 return Op.ofMethod(methodFunction.get()).get();
76 }
77
78 /**
79 * This method inspects the Vector Type Methods to obtain two methods for code-model:
80 * 1) Method `type` to obtain the primitive base type of the vector type.
81 * 2) Method `width` to obtain the number of lanes.
82 *
83 * @param typeElement
84 * {@link TypeElement}
85 * @return
86 * {@link VectorMetaData}
87 */
88 public static VectorMetaData getVectorTypeInfoWithCodeReflection(TypeElement typeElement) {
89 Class<?> aClass;
90 try {
91 aClass = Class.forName(typeElement.toString());
92 } catch (ClassNotFoundException e) {
93 // TODO: Add control for exceptions in HAT (HATExceptions Handler)
94 throw new RuntimeException(e);
95 }
96 CoreOp.FuncOp codeModelType = buildCodeModelFor(aClass, "type");
97 AtomicReference<TypeElement> vectorElement = new AtomicReference<>();
98 codeModelType.elements().forEach(codeElement -> {
99 if (codeElement instanceof CoreOp.ReturnOp returnOp) {
100 Value v = returnOp.operands().getFirst();
101 if (v instanceof Op.Result r && r.op() instanceof JavaOp.FieldAccessOp.FieldLoadOp fieldLoadOp) {
102 String primitiveTypeName = fieldLoadOp.fieldDescriptor().name();
103 vectorElement.set(getVectorElementType(primitiveTypeName.toLowerCase()));
104 }
105 }
106 });
107
108 AtomicInteger lanes = new AtomicInteger(1);
109 CoreOp.FuncOp codeModelWidth = buildCodeModelFor(aClass, "width");
110 codeModelWidth.elements().forEach(codeElement -> {
111 if (codeElement instanceof CoreOp.ReturnOp returnOp) {
112 Value v = returnOp.operands().getFirst();
113 if (v instanceof Op.Result r && r.op() instanceof CoreOp.ConstantOp constantOp) {
114 lanes.set((Integer) constantOp.value());
115 }
116 }
117 });
118 return new VectorMetaData(vectorElement.get(), lanes.get());
119 }
120
121 public static int getWitdh(CoreOp.VarAccessOp.VarLoadOp varLoadOp) {
122 return getWitdh(varLoadOp.operands().getFirst());
123 }
124
125 public static int getWitdh(Value v) {
126 if (v instanceof Op.Result r && r.op() instanceof CoreOp.VarAccessOp.VarLoadOp varLoadOp) {
127 return getWitdh(varLoadOp);
128 } else {
129 // Leaf of tree -
130 if (v instanceof CoreOp.Result r && r.op() instanceof HATVectorOp hatVectorOp) {
131 return hatVectorOp.vectorN();
132 }
133 return -1;
134 }
135 }
136
137 public static TypeElement findVectorTypeElement(CoreOp.VarAccessOp.VarLoadOp varLoadOp) {
138 return findVectorTypeElement(varLoadOp.operands().getFirst());
139 }
140
141 public static TypeElement findVectorTypeElement(Value v) {
142 if (v instanceof Op.Result r && r.op() instanceof CoreOp.VarAccessOp.VarLoadOp varLoadOp) {
143 return findVectorTypeElement(varLoadOp);
144 } else {
145 // Leaf of tree -
146 if (v instanceof CoreOp.Result r && r.op() instanceof HATVectorOp hatVectorOp) {
147 return hatVectorOp.vectorElementType;
148 }
149 return null;
150 }
151 }
152
153 public static String findNameVector(CoreOp.VarAccessOp.VarLoadOp varLoadOp) {
154 return findNameVector(varLoadOp.operands().getFirst());
155 }
156
157 public static String findNameVector(Value v) {
158 if (v instanceof Op.Result r && r.op() instanceof CoreOp.VarAccessOp.VarLoadOp varLoadOp) {
159 return findNameVector(varLoadOp);
160 } else {
161 // Leaf of tree -
162 if (v instanceof CoreOp.Result r && r.op() instanceof HATVectorOp hatVectorOp) {
163 return hatVectorOp.varName();
164 }
165 return null;
166 }
167 }
168
169 }