1 /*
2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25
26 package hat;
27
28 import java.util.List;
29
/**
 * A set of HAT configuration options packed into a single {@code int}.
 *
 * <p>Each option is described by a {@link Bit}, a named field occupying
 * {@code size} consecutive bits starting at {@code index}. The fields are laid
 * out back to back in declaration order via {@link Bit#nextBit}, and
 * {@link #bitList} lists them in that same order.
 */
public class Config {

    /**
     * A named bit field inside the packed configuration word.
     *
     * @param index       position of the field's least significant bit
     * @param size        width of the field in bits
     * @param name        option name as it appears in a spec string
     * @param description human readable description of the option
     */
    public record Bit(int index, int size, String name, String description) implements Comparable<Bit> {

        /** Creates a field of {@code size} bits starting at {@code index}. */
        static Bit of(int index, int size, String name, String description) {
            return new Bit(index, size, name, description);
        }

        /** Creates a single-bit field at {@code index}. */
        public static Bit of(int index, String name, String description) {
            return new Bit(index, 1, name, description);
        }

        /** Creates a field of {@code size} bits that starts immediately after {@code bit}. */
        public static Bit nextBit(Bit bit, int size, String name, String description) {
            return new Bit(bit.index + bit.size, size, name, description);
        }

        /** Creates a single-bit field that starts immediately after {@code bit}. */
        public static Bit nextBit(Bit bit, String name, String description) {
            return nextBit(bit, 1, name, description);
        }

        /** Orders fields by their starting {@code index}. */
        @Override
        public int compareTo(Bit bit) {
            return Integer.compare(index, bit.index);
        }

        /**
         * Returns {@code true} only when every bit covered by {@link #mask()} is set in
         * {@code bits}; for a multi-bit field this means all of its bits are ones.
         */
        public boolean isBitSet(int bits) {
            return (mask() & bits) == mask();
        }

        /** Same test as {@link #isBitSet(int)}, applied to the bits held by {@code config}. */
        public boolean isSet(Config config) {
            return (mask() & config.bits) == mask();
        }

        /** Returns {@code size} one-bits shifted left to {@code index}. */
        public int mask() {
            return ((1 << size) - 1) << index;
        }

        /** Returns {@link #mask()} rendered by {@link Integer#toBinaryString(int)}. */
        public String maskString() {
            return Integer.toBinaryString(mask());
        }
    }

    // These must sync with hat/backends/ffi/shared/include/config.h
    // We can create the above config by running main() below...

    public static final Bit PLATFORM = Bit.of(0, 4, "PLATFORM", "FFI ONLY platform id (0-15)");
    public static final Bit DEVICE = Bit.nextBit(PLATFORM, 4, "DEVICE", "FFI ONLY device id (0-15)");
    private static final Bit MINIMIZE_COPIES = Bit.nextBit(DEVICE, "MINIMIZE_COPIES", "FFI ONLY Try to minimize copies");
    private static final Bit TRACE = Bit.nextBit(MINIMIZE_COPIES, "TRACE", "FFI ONLY trace code");
    private static final Bit PROFILE = Bit.nextBit(TRACE, "PROFILE", "FFI ONLY Turn on profiling");
    private static final Bit SHOW_CODE = Bit.nextBit(PROFILE, "SHOW_CODE", "Show generated code (PTX/OpenCL/CUDA)");
    private static final Bit SHOW_KERNEL_MODEL = Bit.nextBit(SHOW_CODE, "SHOW_KERNEL_MODEL", "Show (via OpWriter) Kernel Model");
    private static final Bit SHOW_COMPUTE_MODEL = Bit.nextBit(SHOW_KERNEL_MODEL, "SHOW_COMPUTE_MODEL", "Show (via OpWriter) Compute Model");
    private static final Bit SHOW_DEVICE_INFO = Bit.nextBit(SHOW_COMPUTE_MODEL, "SHOW_DEVICE_INFO", "FFI show platform and device info");
    public static final Bit INFO = Bit.nextBit(SHOW_DEVICE_INFO, "INFO", "INFO level logging");
    public static final Bit WARN = Bit.nextBit(INFO, "WARN", "WARN(ing) level logging ");
    public static final Bit UNIT = Bit.nextBit(WARN, "UNIT", "UNIT test level logging ");
    private static final Bit TRACE_COPIES = Bit.nextBit(UNIT, "TRACE_COPIES", "FFI ONLY trace copies");
    private static final Bit TRACE_SKIPPED_COPIES = Bit.nextBit(TRACE_COPIES, "TRACE_SKIPPED_COPIES", "FFI ONLY Trace skipped copies (see MINIMIZE_COPIES) ");
    private static final Bit TRACE_ENQUEUES = Bit.nextBit(TRACE_SKIPPED_COPIES, "TRACE_ENQUEUES", "FFI ONLY trace enqueued tasks");
    private static final Bit TRACE_CALLS = Bit.nextBit(TRACE_ENQUEUES, "TRACE_CALLS", "FFI ONLY trace calls (enter/leave)");
    private static final Bit SHOW_WHY = Bit.nextBit(TRACE_CALLS, "SHOW_WHY", "FFI ONLY show why we decided to copy buffer (H to D)");
    private static final Bit SHOW_STATE = Bit.nextBit(SHOW_WHY, "SHOW_STATE", "Show iface buffer state changes");
    private static final Bit PTX = Bit.nextBit(SHOW_STATE, "PTX", "FFI (NVIDIA) ONLY pass PTX rather than C99 CUDA code");
    private static final Bit INTERPRET = Bit.nextBit(PTX, "INTERPRET", "Interpret the code model rather than converting to bytecode");
    private static final Bit HEADLESS = Bit.nextBit(INTERPRET, "HEADLESS", "Don't show UI");
    private static final Bit SHOW_LOWERED_KERNEL_MODEL = Bit.nextBit(HEADLESS, "SHOW_LOWERED_KERNEL_MODEL", "Show (via OpWriter) Lowered Kernel Model");
    private static final Bit SHOW_COMPILATION_PHASES = Bit.nextBit(SHOW_LOWERED_KERNEL_MODEL, "SHOW_COMPILATION_PHASES", "Show HAT compilation phases");
    private static final Bit PROFILE_CUDA_KERNEL = Bit.nextBit(SHOW_COMPILATION_PHASES, "PROFILE_CUDA_KERNEL", "Add -lineinfo to CUDA kernel compilation for profiling and debugging");

    /** Every known option, in the order the fields are laid out in the configuration word. */
    public static final List<Bit> bitList = List.of(
            PLATFORM,
            DEVICE,
            MINIMIZE_COPIES,
            TRACE,
            PROFILE,
            SHOW_CODE,
            SHOW_KERNEL_MODEL,
            SHOW_COMPUTE_MODEL,
            SHOW_DEVICE_INFO,
            INFO,
            WARN,
            UNIT,
            TRACE_COPIES,
            TRACE_SKIPPED_COPIES,
            TRACE_ENQUEUES,
            TRACE_CALLS,
            SHOW_WHY,
            SHOW_STATE,
            PTX,
            INTERPRET,
            HEADLESS,
            SHOW_LOWERED_KERNEL_MODEL,
            SHOW_COMPILATION_PHASES,
            PROFILE_CUDA_KERNEL
    );

    /** The packed option word. */
    private int bits;

    /** Creates a configuration holding exactly {@code bits}. */
    Config(int bits) {
        // Assign directly rather than through the overridable setter.
        this.bits = bits;
    }

    /** Returns the packed option word. */
    public int bits() {
        return bits;
    }

    /** Replaces the packed option word. */
    public void bits(int bits) {
        this.bits = bits;
    }

    /** Returns whether the MINIMIZE_COPIES option is set. */
    public boolean minimizeCopies() {
        return MINIMIZE_COPIES.isSet(this);
    }

    /** Returns whether the SHOW_CODE option is set. */
    public boolean showCode() {
        return SHOW_CODE.isSet(this);
    }

    /** Returns whether the SHOW_KERNEL_MODEL option is set. */
    public boolean showKernelModel() {
        // Must test the kernel-model bit, not SHOW_COMPUTE_MODEL.
        return SHOW_KERNEL_MODEL.isSet(this);
    }

    /** Returns whether the SHOW_COMPUTE_MODEL option is set. */
    public boolean showComputeModel() {
        return SHOW_COMPUTE_MODEL.isSet(this);
    }

    /** Returns whether the SHOW_STATE option is set. */
    public boolean showState() {
        return SHOW_STATE.isSet(this);
    }

    /** Returns whether the PTX option is set. */
    public boolean ptx() {
        return PTX.isSet(this);
    }

    /** Returns whether the INTERPRET option is set. */
    public boolean interpret() {
        return INTERPRET.isSet(this);
    }

    /**
     * Returns whether the HEADLESS option is set or the {@code headless} system
     * property reads as {@code true} (see {@link Boolean#getBoolean(String)}).
     */
    public boolean headless() {
        return HEADLESS.isSet(this) || Boolean.getBoolean("headless");
    }

    /** Returns {@link #headless()} or whether {@code arg} is exactly {@code "--headless"}. */
    public boolean headless(String arg) {
        return headless() || "--headless".equals(arg);
    }

    /** Returns whether the SHOW_LOWERED_KERNEL_MODEL option is set. */
    public boolean showLoweredKernelModel() {
        return SHOW_LOWERED_KERNEL_MODEL.isSet(this);
    }

    /** Returns whether the SHOW_COMPILATION_PHASES option is set. */
    public boolean showCompilationPhases() {
        return SHOW_COMPILATION_PHASES.isSet(this);
    }

    /** Returns whether the PROFILE_CUDA_KERNEL option is set. */
    public boolean isProfileCUDAKernel() {
        return PROFILE_CUDA_KERNEL.isSet(this);
    }

    /**
     * Builds a configuration from the {@code HAT} environment variable when it is
     * present, otherwise from the {@code HAT} system property, otherwise from the
     * empty spec. The chosen source and its value are echoed to standard output.
     */
    public static Config fromEnvOrProperty() {
        if (System.getenv("HAT") instanceof String opts) {
            System.out.println("From env " + opts);
            return fromSpec(opts);
        }
        if (System.getProperty("HAT") instanceof String opts) {
            System.out.println("From prop " + opts);
            return fromSpec(opts);
        }
        return fromSpec("");
    }

    /** Wraps a raw option word. */
    public static Config fromIntBits(int bits) {
        return new Config(bits);
    }

    /** Builds a configuration in which every listed field has all of its bits set. */
    public static Config fromBits(List<Bit> configBits) {
        int allBits = 0;
        for (Bit configBit : configBits) {
            allBits |= configBit.mask();
        }
        return new Config(allBits);
    }

    /** Varargs form of {@link #fromBits(List)}. */
    public static Config fromBits(Bit... configBits) {
        return fromBits(List.of(configBits));
    }

    /** Returns a new configuration keeping only those of this one's bits covered by {@code configBits}. */
    public Config and(Bit... configBits) {
        return Config.fromIntBits(Config.fromBits(configBits).bits & bits);
    }

    /** Returns a new configuration with the masks of {@code configBits} added to this one's bits. */
    public Config or(Bit... configBits) {
        return Config.fromIntBits(Config.fromBits(configBits).bits | bits);
    }

    /** Pairs a field with the integer value requested for it in a spec string. */
    public record BitValue(Bit bit, int value) {}

    /**
     * Parses a textual option spec.
     *
     * <ul>
     *   <li>{@code null} or {@code ""} yields a configuration with no bits set.</li>
     *   <li>A spec equal to one option name yields that field's full mask.</li>
     *   <li>Otherwise a comma separated list of {@code NAME} or {@code NAME:value}
     *       entries is accepted; a bare {@code NAME} counts as value 1. A single
     *       {@code NAME:value} entry needs no comma. Each value is shifted to the
     *       field's index and OR-ed in; it is not range checked against the field
     *       width.</li>
     *   <li>Anything else is reported on standard output and the JVM exits with
     *       status 1.</li>
     * </ul>
     *
     * @param spec the option spec, may be {@code null}
     * @return the parsed configuration
     * @throws java.util.NoSuchElementException if a list entry names an unknown option
     * @throws NumberFormatException            if a value is not a decimal integer
     */
    public static Config fromSpec(String spec) {
        if (spec == null || spec.isEmpty()) {
            return Config.fromIntBits(0);
        }
        for (Bit bit : bitList) {
            if (bit.name().equals(spec)) {
                return new Config(bit.mask());
            }
        }
        if (spec.contains(",") || spec.contains(":")) {
            int bits = 0;
            for (String opt : spec.split(",")) {
                String[] split = opt.split(":");
                String valName = split[0];
                int value = split.length == 1 ? 1 : Integer.parseInt(split[1]);
                BitValue bitValue = bitList.stream()
                        .filter(bit -> bit.name().equals(valName))
                        .map(bit -> new BitValue(bit, value))
                        .findFirst()
                        .orElseThrow();
                bits |= bitValue.value() << bitValue.bit().index();
            }
            return fromIntBits(bits);
        }
        System.out.println("Unexpected spec '" + spec + "'");
        System.exit(1);
        return Config.fromIntBits(0);
    }

    /** Returns the names of all fields whose mask is fully set, joined by {@code |}. */
    @Override
    public String toString() {
        StringBuilder builder = new StringBuilder();
        for (Bit bit : bitList) {
            if (bit.isBitSet(bits)) {
                if (!builder.isEmpty()) {
                    builder.append("|");
                }
                builder.append(bit.name());
            }
        }
        return builder.toString();
    }

    /** Prints each option name alongside its mask in binary. */
    public static void main(String[] args) {
        bitList.forEach(b -> {
            System.out.printf("%30s MASK= %32s\n", b.name(), b.maskString());
        });
    }
}