1 /*
2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25
26 package hat;
27
28 import java.util.Arrays;
29 import java.util.List;
30 import java.util.Optional;
31
/**
 * HAT runtime configuration packed into a single 32-bit {@code int}.
 *
 * <p>Each option is a {@link Bit}: a named field of one or more consecutive bits. A
 * {@code Config} is normally built from a textual spec (see {@link #fromSpec(String)}),
 * read from the {@code HAT} environment variable or the {@code HAT} system property.
 */
public class Config {

    /**
     * A named field of {@code size} consecutive bits starting at bit {@code index}.
     *
     * @param index       position of the field's least significant bit
     * @param size        number of bits in the field (at least 1)
     * @param name        long option name accepted by {@link Config#fromSpec(String)}
     * @param alt         short alias accepted by {@link Config#fromSpec(String)}; lookup takes the
     *                    first match in {@link Config#bitList}, so aliases must be unique
     * @param description one-line help text printed for the {@code HELP} spec
     */
    public record Bit(int index, int size, String name, String alt, String description) implements Comparable<Bit> {

        /**
         * Rejects fields that do not fit in a 32-bit config word. Without this check a field
         * placed at index 32 would silently alias bit 0, because Java masks int shift distances.
         *
         * @throws IllegalArgumentException if the field is empty or extends past bit 31
         */
        public Bit {
            if (index < 0 || size < 1 || size > Integer.SIZE - index) {
                throw new IllegalArgumentException(
                        "Bit " + name + " (index " + index + ", size " + size + ") does not fit in a 32-bit config");
            }
        }

        static Bit of(int index, int size, String name, String alt, String description) {
            return new Bit(index, size, name, alt, description);
        }

        /** Creates a single-bit field at {@code index}. */
        public static Bit of(int index, String name, String alt, String description) {
            return new Bit(index, 1, name, alt, description);
        }

        /** Creates a field of {@code size} bits placed immediately above {@code bit}. */
        public static Bit nextBit(Bit bit, int size, String name, String alt, String description) {
            return new Bit(bit.index + bit.size, size, name, alt, description);
        }

        /** Creates a single-bit field placed immediately above {@code bit}. */
        public static Bit nextBit(Bit bit, String name, String alt, String description) {
            return nextBit(bit, 1, name, alt, description);
        }

        /** Orders fields by bit position. */
        @Override
        public int compareTo(Bit bit) {
            return Integer.compare(index, bit.index);
        }

        /** Returns true when every bit of this field is set in {@code bits}. */
        public boolean isBitSet(int bits) {
            return (mask() & bits) == mask();
        }

        /** Returns true when every bit of this field is set in {@code config}. */
        public boolean isSet(Config config) {
            return isBitSet(config.bits());
        }

        /** Returns this field's bits, in position, within the config word. */
        public int mask() {
            // Computed in long so a full-width field does not wrap to an empty mask.
            return (int) (((1L << size) - 1) << index);
        }

        /** Largest value the field can hold: {@code 2^size - 1}. */
        private long maxValue() {
            return (1L << size) - 1;
        }

        /** Extracts this field's value from {@code bits}, shifted down to start at bit 0. */
        private int fieldValue(int bits) {
            return (bits & mask()) >>> index;
        }

        /** Returns {@link #mask()} as a binary string (leading zeros omitted). */
        public String maskString() {
            return Integer.toBinaryString(mask());
        }
    }

    // Bit positions must stay in sync with hat/backends/ffi/shared/include/config.h.
    // Running main() below prints every bit's name and mask.
    // Each Bit is chained onto the previous one, so declaration order defines the layout.

    public static final Bit PLATFORM = Bit.of(0, 4, "PLATFORM", "P", "FFI ONLY platform id (0-15)");
    public static final Bit DEVICE = Bit.nextBit(PLATFORM, 4, "DEVICE", "D", "FFI ONLY device id (0-15)");
    private static final Bit MINIMIZE_COPIES = Bit.nextBit(DEVICE, "MINIMIZE_COPIES", "MC", "FFI ONLY Try to minimize copies");

    public boolean minimizeCopies() {
        return MINIMIZE_COPIES.isSet(this);
    }

    private static final Bit TRACE = Bit.nextBit(MINIMIZE_COPIES, "TRACE", "T", "FFI ONLY trace code");
    // Alias was "P", which PLATFORM already owns, so PROFILE was unreachable by alias.
    private static final Bit PROFILE = Bit.nextBit(TRACE, "PROFILE", "PROF", "FFI ONLY Turn on profiling");
    private static final Bit SHOW_CODE = Bit.nextBit(PROFILE, "SHOW_CODE", "SC", "Show generated code (PTX/OpenCL/CUDA)");

    public boolean showCode() {
        return SHOW_CODE.isSet(this);
    }

    private static final Bit SHOW_KERNEL_MODEL = Bit.nextBit(SHOW_CODE, "SHOW_KERNEL_MODEL", "SKM", "Show (via OpWriter) Kernel Model");

    public boolean showKernelModel() {
        return SHOW_KERNEL_MODEL.isSet(this);
    }

    private static final Bit SHOW_COMPUTE_MODEL = Bit.nextBit(SHOW_KERNEL_MODEL, "SHOW_COMPUTE_MODEL", "SCM", "Show (via OpWriter) Compute Model");

    public boolean showComputeModel() {
        return SHOW_COMPUTE_MODEL.isSet(this);
    }

    private static final Bit SHOW_DEVICE_INFO = Bit.nextBit(SHOW_COMPUTE_MODEL, "SHOW_DEVICE_INFO", "SDI", "FFI show platform and device info");
    public static final Bit INFO = Bit.nextBit(SHOW_DEVICE_INFO, "INFO", "I", "INFO level logging");
    public static final Bit WARN = Bit.nextBit(INFO, "WARN", "W", "WARN(ing) level logging ");
    public static final Bit UNIT = Bit.nextBit(WARN, "UNIT", "U", "UNIT test level logging ");
    private static final Bit TRACE_COPIES = Bit.nextBit(UNIT, "TRACE_COPIES", "TC", "FFI ONLY trace copies");
    private static final Bit TRACE_SKIPPED_COPIES = Bit.nextBit(TRACE_COPIES, "TRACE_SKIPPED_COPIES", "TSC", "FFI ONLY Trace skipped copies (see MINIMIZE_COPIES) ");
    private static final Bit TRACE_ENQUEUES = Bit.nextBit(TRACE_SKIPPED_COPIES, "TRACE_ENQUEUES", "TE", "FFI ONLY trace enqueued tasks");
    private static final Bit TRACE_CALLS = Bit.nextBit(TRACE_ENQUEUES, "TRACE_CALLS", "TCALLS", "FFI ONLY trace calls (enter/leave)");
    private static final Bit SHOW_WHY = Bit.nextBit(TRACE_CALLS, "SHOW_WHY", "SW", "FFI ONLY show why we decided to copy buffer (H to D)");
    private static final Bit SHOW_STATE = Bit.nextBit(SHOW_WHY, "SHOW_STATE", "SS", "Show iface buffer state changes");

    public boolean showState() {
        return SHOW_STATE.isSet(this);
    }

    private static final Bit PTX = Bit.nextBit(SHOW_STATE, "PTX", "PTX", "FFI (NVIDIA) ONLY pass PTX rather than C99 CUDA code");

    public boolean ptx() {
        return PTX.isSet(this);
    }

    // Alias was "I", which INFO already owns, so INTERPRET was unreachable by alias.
    private static final Bit INTERPRET = Bit.nextBit(PTX, "INTERPRET", "INTERP", "Interpret the code model rather than converting to bytecode");

    public boolean interpret() {
        return INTERPRET.isSet(this);
    }

    private static final Bit HEADLESS = Bit.nextBit(INTERPRET, "HEADLESS", "H", "Don't show UI");

    /** True when the HEADLESS bit is set or the system property {@code headless} is {@code "true"}. */
    public boolean headless() {
        return HEADLESS.isSet(this) || Boolean.getBoolean("headless");
    }

    /** Like {@link #headless()}, but also true when {@code arg} is {@code "--headless"}. */
    public boolean headless(String arg) {
        return headless() || "--headless".equals(arg);
    }

    private static final Bit SHOW_LOWERED_KERNEL_MODEL = Bit.nextBit(HEADLESS, "SHOW_LOWERED_KERNEL_MODEL", "SLKM", "Show (via OpWriter) Lowered Kernel Model");

    public boolean showLoweredKernelModel() {
        return SHOW_LOWERED_KERNEL_MODEL.isSet(this);
    }

    private static final Bit SHOW_COMPILATION_PHASES = Bit.nextBit(SHOW_LOWERED_KERNEL_MODEL, "SHOW_COMPILATION_PHASES", "SCP", "Show HAT compilation phases");

    public boolean showCompilationPhases() {
        return SHOW_COMPILATION_PHASES.isSet(this);
    }

    private static final Bit PROFILE_CUDA_KERNEL = Bit.nextBit(SHOW_COMPILATION_PHASES, "PROFILE_CUDA_KERNEL", "PCK", "Add -lineinfo to CUDA kernel compilation for profiling and debugging");

    public boolean profileCUDAKernel() {
        return PROFILE_CUDA_KERNEL.isSet(this);
    }

    private static final Bit SHOW_COMPUTE_MODEL_JAVA_CODE = Bit.nextBit(PROFILE_CUDA_KERNEL, "SHOW_COMPUTE_MODEL_JAVA_CODE", "SCMJC", "Show java code view of compute model");

    public boolean showComputeModelJavaCode() {
        return SHOW_COMPUTE_MODEL_JAVA_CODE.isSet(this);
    }

    // Occupies bit 31: the config word is now full; Bit's constructor rejects any further field.
    public static final Bit CHECK_SSA_LOWERING = Bit.nextBit(SHOW_COMPUTE_MODEL_JAVA_CODE, "CHECK_SSA_LOWERING", "CSSAL", "Verify that code model can be lowered to SSA");

    public boolean checkSSALowering() {
        return CHECK_SSA_LOWERING.isSet(this);
    }

    /** Every known field, in bit order; used for spec parsing, HELP and toString. */
    public static final List<Bit> bitList = List.of(
            PLATFORM,
            DEVICE,
            MINIMIZE_COPIES,
            TRACE,
            PROFILE,
            SHOW_CODE,
            SHOW_KERNEL_MODEL,
            SHOW_COMPUTE_MODEL,
            SHOW_DEVICE_INFO,
            INFO,
            WARN,
            UNIT,
            TRACE_COPIES,
            TRACE_SKIPPED_COPIES,
            TRACE_ENQUEUES,
            TRACE_CALLS,
            SHOW_WHY,
            SHOW_STATE,
            PTX,
            INTERPRET,
            HEADLESS,
            SHOW_LOWERED_KERNEL_MODEL,
            SHOW_COMPILATION_PHASES,
            PROFILE_CUDA_KERNEL,
            SHOW_COMPUTE_MODEL_JAVA_CODE,
            CHECK_SSA_LOWERING
    );

    private final int bits;

    /** Returns the raw packed config word. */
    public int bits() {
        return bits;
    }

    Config(int bits) {
        this.bits = bits;
    }

    /**
     * Builds a config from the {@code HAT} environment variable, falling back to the {@code HAT}
     * system property, and finally to the empty (all-zero) config.
     */
    public static Config fromEnvOrProperty() {
        if (System.getenv("HAT") instanceof String opts) {
            return fromSpec(opts);
        } else if (System.getProperty("HAT") instanceof String opts) {
            return fromSpec(opts);
        } else {
            return fromSpec("");
        }
    }

    /** Wraps an already-packed config word. */
    public static Config fromIntBits(int bits) {
        return new Config(bits);
    }

    record BitValue(Bit bit, int value) {}

    /**
     * Parses a textual config spec.
     *
     * <p>Accepted forms (surrounding whitespace is ignored):
     * <ul>
     *   <li>{@code null} or empty: the all-zero config;</li>
     *   <li>{@code HELP}: prints every option and exits the JVM with status 0;</li>
     *   <li>{@code A,B,...}: each part parsed recursively and OR-ed together;</li>
     *   <li>{@code NAME} or {@code ALT}: all bits of that field set;</li>
     *   <li>{@code NAME:value}: the field set to {@code value}, which must be in
     *       {@code [0, 2^size - 1]}.</li>
     * </ul>
     * Any other spec prints {@code Unexpected spec '...'} and exits the JVM with status 1.
     */
    public static Config fromSpec(String spec) {
        String trimmed = spec == null ? "" : spec.strip();
        if (trimmed.isEmpty()) {
            return fromIntBits(0); // default is good
        }
        if (trimmed.equals("HELP")) {
            printHelp();
            System.exit(0);
            return fromIntBits(0); // unreachable: System.exit does not return normally
        }
        Optional<Config> parsed;
        if (trimmed.contains(",")) {
            // Each part is parsed (and, if invalid, reported) on its own.
            parsed = Arrays.stream(trimmed.split(","))
                    .map(Config::fromSpec)
                    .reduce((lhs, rhs) -> fromIntBits(lhs.bits() | rhs.bits()));
        } else if (trimmed.contains(":")) {
            parsed = parseFieldAssignment(trimmed);
        } else {
            parsed = findBit(trimmed).map(bit -> fromIntBits(bit.mask()));
        }
        if (parsed.isPresent()) {
            return parsed.get();
        }
        System.out.println("Unexpected spec '" + spec + "'");
        System.exit(1);
        return null; // unreachable: System.exit does not return normally
    }

    /** Looks up a field by long name or short alias. */
    private static Optional<Bit> findBit(String token) {
        return bitList.stream()
                .filter(bit -> bit.name().equals(token) || bit.alt().equals(token))
                .findFirst();
    }

    /** Parses {@code NAME:value}; empty when the shape, name or value is invalid. */
    private static Optional<Config> parseFieldAssignment(String spec) {
        // limit -1 keeps trailing empty parts, so "P:" is rejected rather than ignored.
        String[] parts = spec.split(":", -1);
        if (parts.length != 2) {
            return Optional.empty();
        }
        return findBit(parts[0].strip()).flatMap(bit -> parseFieldValue(bit, parts[1].strip()));
    }

    /** Places {@code text} into {@code bit}'s field; empty if it is not an integer that fits. */
    private static Optional<Config> parseFieldValue(Bit bit, String text) {
        int value;
        try {
            value = Integer.parseInt(text);
        } catch (NumberFormatException ignored) {
            return Optional.empty(); // reported by fromSpec as an unexpected spec
        }
        // Out-of-range values would otherwise spill into neighbouring fields (or, if negative, set every bit above).
        if (value < 0 || value > bit.maxValue()) {
            return Optional.empty();
        }
        var bv = new BitValue(bit, value);
        return Optional.of(fromIntBits(bv.value() << bv.bit().index()));
    }

    /** Prints one line per option: name/alias, bit range and description. */
    private static void printHelp() {
        for (Bit bit : bitList) {
            int padding = Math.max(0, 40 - (bit.name().length() + bit.alt().length()));
            String range = bit.size() > 1 ? "-" + (bit.index() + bit.size() - 1) : " ";
            System.out.println(bit.name() + "/" + bit.alt() + " ".repeat(padding)
                    + (bit.index() < 10 ? " " : "") + bit.index() + range + " " + bit.description());
        }
    }

    /**
     * Lists the options that are on, separated by {@code |}. Single-bit options appear by name;
     * multi-bit fields with a non-zero value appear as {@code NAME:value}.
     */
    @Override
    public String toString() {
        StringBuilder builder = new StringBuilder();
        for (Bit bit : bitList) {
            String entry = null;
            if (bit.size() == 1) {
                if (bit.isBitSet(bits)) {
                    entry = bit.name();
                }
            } else {
                int value = bit.fieldValue(bits);
                if (value != 0) {
                    entry = bit.name() + ":" + value;
                }
            }
            if (entry != null) {
                if (!builder.isEmpty()) {
                    builder.append("|");
                }
                builder.append(entry);
            }
        }
        return builder.toString();
    }

    /** Prints each option's name and mask, for keeping config.h in sync. */
    public static void main() {
        bitList.forEach(b -> System.out.printf("%30s MASK= %32s\n", b.name(), b.maskString()));
    }
}