1 /*
2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25
26 package hat;
27
28 import java.util.Arrays;
29 import java.util.List;
30 import java.util.Optional;
31
/**
 * HAT runtime configuration packed into a single {@code int}.
 *
 * <p>Each option is described by a {@link Bit}: a field of {@code size} bits starting at
 * {@code index}. A configuration is normally built from a textual spec (see
 * {@link #fromSpec(String)}), typically read from the {@code HAT} environment variable or
 * system property (see {@link #fromEnvOrProperty()}).
 */
public class Config {

    /**
     * A named bit field inside the config word.
     *
     * @param index       position of the field's lowest bit
     * @param size        width of the field in bits (1 for a simple flag)
     * @param name        long name accepted by {@link Config#fromSpec(String)}
     * @param alt         short alias accepted by {@link Config#fromSpec(String)}
     * @param description help text printed for the {@code HELP} spec
     */
    public record Bit(int index, int size, String name, String alt, String description) implements Comparable<Bit> {
        static Bit of(int index, int size, String name, String alt, String description) {
            return new Bit(index, size, name, alt, description);
        }

        /** Creates a single-bit flag at {@code index}. */
        public static Bit of(int index, String name, String alt, String description) {
            return new Bit(index, 1, name, alt, description);
        }

        /** Creates a {@code size}-bit field placed immediately after {@code bit}. */
        public static Bit nextBit(Bit bit, int size, String name, String alt, String description) {
            return new Bit(bit.index + bit.size, size, name, alt, description);
        }

        /** Creates a single-bit flag placed immediately after {@code bit}. */
        public static Bit nextBit(Bit bit, String name, String alt, String description) {
            return nextBit(bit, 1, name, alt, description);
        }

        /** Orders bits by their position in the config word. */
        @Override
        public int compareTo(Bit bit) {
            return Integer.compare(index, bit.index);
        }

        /** Returns true when every bit of this field is set in {@code bits}. */
        public boolean isBitSet(int bits) {
            return (mask() & bits) == mask();
        }

        /** Returns true when every bit of this field is set in {@code config}. */
        public boolean isSet(Config config) {
            return isBitSet(config.bits);
        }

        /** Extracts this field's value (0 .. 2^size - 1) from {@code bits}. */
        public int fieldValue(int bits) {
            return (bits & mask()) >>> index;
        }

        /** Returns the mask covering this field's {@code size} bits, shifted up to {@code index}. */
        public int mask() {
            return ((1 << size) - 1) << index;
        }

        /** Returns {@link #mask()} as a binary string. */
        public String maskString() {
            return Integer.toBinaryString(mask());
        }
    }

    // Bit positions must stay in sync with hat/backends/ffi/shared/include/config.h.
    // main() below prints every bit's mask, which helps when updating that header.
    // Each field is placed directly after the previous one (Bit.nextBit), so inserting or
    // reordering an entry shifts every later bit.
    public static final Bit PLATFORM = Bit.of(0, 4, "PLATFORM", "P", "FFI ONLY platform id (0-15)");
    public static final Bit DEVICE = Bit.nextBit(PLATFORM, 4, "DEVICE", "D", "FFI ONLY device id (0-15)");
    private static final Bit MINIMIZE_COPIES = Bit.nextBit(DEVICE, "MINIMIZE_COPIES", "MC", "FFI ONLY Try to minimize copies");

    public boolean minimizeCopies() {
        return MINIMIZE_COPIES.isSet(this);
    }

    private static final Bit TRACE = Bit.nextBit(MINIMIZE_COPIES, "TRACE", "T", "FFI ONLY trace code");
    // NOTE(review): alias "P" is also PLATFORM's alias; PLATFORM comes first in bitList, so
    // fromSpec("P") always selects PLATFORM and this alias is effectively unusable.
    private static final Bit PROFILE = Bit.nextBit(TRACE, "PROFILE", "P", "FFI ONLY Turn on profiling");
    private static final Bit SHOW_CODE = Bit.nextBit(PROFILE, "SHOW_CODE", "SC", "Show generated code (PTX/OpenCL/CUDA)");

    public boolean showCode() {
        return SHOW_CODE.isSet(this);
    }

    private static final Bit SHOW_KERNEL_MODEL = Bit.nextBit(SHOW_CODE, "SHOW_KERNEL_MODEL", "SKM", "Show (via OpWriter) Kernel Model");

    public boolean showKernelModel() {
        return SHOW_KERNEL_MODEL.isSet(this);
    }

    private static final Bit SHOW_COMPUTE_MODEL = Bit.nextBit(SHOW_KERNEL_MODEL, "SHOW_COMPUTE_MODEL", "SCM", "Show (via OpWriter) Compute Model");

    public boolean showComputeModel() {
        return SHOW_COMPUTE_MODEL.isSet(this);
    }

    private static final Bit SHOW_DEVICE_INFO = Bit.nextBit(SHOW_COMPUTE_MODEL, "SHOW_DEVICE_INFO", "SDI", "FFI show platform and device info");
    public static final Bit INFO = Bit.nextBit(SHOW_DEVICE_INFO, "INFO", "I", "INFO level logging");
    public static final Bit WARN = Bit.nextBit(INFO, "WARN", "W", "WARN(ing) level logging ");
    public static final Bit UNIT = Bit.nextBit(WARN, "UNIT", "U", "UNIT test level logging ");
    private static final Bit TRACE_COPIES = Bit.nextBit(UNIT, "TRACE_COPIES", "TC", "FFI ONLY trace copies");
    private static final Bit TRACE_SKIPPED_COPIES = Bit.nextBit(TRACE_COPIES, "TRACE_SKIPPED_COPIES", "TSC", "FFI ONLY Trace skipped copies (see MINIMIZE_COPIES) ");
    private static final Bit TRACE_ENQUEUES = Bit.nextBit(TRACE_SKIPPED_COPIES, "TRACE_ENQUEUES", "TE", "FFI ONLY trace enqueued tasks");
    private static final Bit TRACE_CALLS = Bit.nextBit(TRACE_ENQUEUES, "TRACE_CALLS", "TCALLS", "FFI ONLY trace calls (enter/leave)");
    private static final Bit SHOW_WHY = Bit.nextBit(TRACE_CALLS, "SHOW_WHY", "SW", "FFI ONLY show why we decided to copy buffer (H to D)");
    private static final Bit SHOW_STATE = Bit.nextBit(SHOW_WHY, "SHOW_STATE", "SS", "Show iface buffer state changes");

    public boolean showState() {
        return SHOW_STATE.isSet(this);
    }

    private static final Bit PTX = Bit.nextBit(SHOW_STATE, "PTX", "PTX", "FFI (NVIDIA) ONLY pass PTX rather than C99 CUDA code");

    public boolean ptx() {
        return PTX.isSet(this);
    }

    // NOTE(review): alias "I" is also INFO's alias; INFO comes first in bitList, so
    // fromSpec("I") always selects INFO and this alias is effectively unusable.
    private static final Bit INTERPRET = Bit.nextBit(PTX, "INTERPRET", "I", "Interpret the code model rather than converting to bytecode");

    public boolean interpret() {
        return INTERPRET.isSet(this);
    }

    private static final Bit HEADLESS = Bit.nextBit(INTERPRET, "HEADLESS", "H", "Don't show UI");

    /** Returns true when the HEADLESS bit is set or the {@code headless} system property is "true". */
    public boolean headless() {
        return HEADLESS.isSet(this) || Boolean.getBoolean("headless");
    }

    /** Like {@link #headless()}, but also true when {@code arg} is {@code "--headless"}. */
    public boolean headless(String arg) {
        return headless() || "--headless".equals(arg);
    }

    private static final Bit SHOW_LOWERED_KERNEL_MODEL = Bit.nextBit(HEADLESS, "SHOW_LOWERED_KERNEL_MODEL", "SLKM", "Show (via OpWriter) Lowered Kernel Model");

    public boolean showLoweredKernelModel() {
        return SHOW_LOWERED_KERNEL_MODEL.isSet(this);
    }

    private static final Bit SHOW_COMPILATION_PHASES = Bit.nextBit(SHOW_LOWERED_KERNEL_MODEL, "SHOW_COMPILATION_PHASES", "SCP", "Show HAT compilation phases");

    public boolean showCompilationPhases() {
        return SHOW_COMPILATION_PHASES.isSet(this);
    }

    private static final Bit PROFILE_CUDA_KERNEL = Bit.nextBit(SHOW_COMPILATION_PHASES, "PROFILE_CUDA_KERNEL", "PCK", "Add -lineinfo to CUDA kernel compilation for profiling and debugging");

    public boolean profileCUDAKernel() {
        return PROFILE_CUDA_KERNEL.isSet(this);
    }

    private static final Bit SHOW_COMPUTE_MODEL_JAVA_CODE = Bit.nextBit(PROFILE_CUDA_KERNEL, "SHOW_COMPUTE_MODEL_JAVA_CODE", "SCMJC", "Show java code view of compute model");

    public boolean showComputeModelJavaCode() {
        return SHOW_COMPUTE_MODEL_JAVA_CODE.isSet(this);
    }

    /** All known bits, in config-word order; also the lookup order used by {@link #fromSpec(String)}. */
    public static final List<Bit> bitList = List.of(
            PLATFORM,
            DEVICE,
            MINIMIZE_COPIES,
            TRACE,
            PROFILE,
            SHOW_CODE,
            SHOW_KERNEL_MODEL,
            SHOW_COMPUTE_MODEL,
            SHOW_DEVICE_INFO,
            INFO,
            WARN,
            UNIT,
            TRACE_COPIES,
            TRACE_SKIPPED_COPIES,
            TRACE_ENQUEUES,
            TRACE_CALLS,
            SHOW_WHY,
            SHOW_STATE,
            PTX,
            INTERPRET,
            HEADLESS,
            SHOW_LOWERED_KERNEL_MODEL,
            SHOW_COMPILATION_PHASES,
            PROFILE_CUDA_KERNEL,
            SHOW_COMPUTE_MODEL_JAVA_CODE
    );

    private final int bits;

    /** Returns the raw config word. */
    public int bits() {
        return bits;
    }

    /** Creates a config from a raw word; the boolean parameter is ignored. */
    Config(int bits, boolean junk) {
        this.bits = bits;
    }

    /**
     * Builds a config from the {@code HAT} environment variable or, if that is unset, the
     * {@code HAT} system property; with neither set, returns the all-zero config.
     */
    public static Config fromEnvOrProperty() {
        if (System.getenv("HAT") instanceof String opts) {
            return fromSpec(opts);
        } else if (System.getProperty("HAT") instanceof String opts) {
            return fromSpec(opts);
        } else {
            return fromSpec("");
        }
    }

    /** Wraps a raw config word. */
    public static Config fromIntBits(int bits) {
        return new Config(bits, false);
    }

    record BitValue(Bit bit, int value) {}

    /**
     * Parses a configuration spec.
     *
     * <p>A spec is a comma-separated list of terms. Each term has one of two forms:
     * <ul>
     *   <li>A bit's name or alias (for example {@code SHOW_CODE} or {@code SC}). This sets
     *       every bit of that field.</li>
     *   <li>{@code NAME:value}. This stores {@code value} in the field. The value must be a
     *       non-negative integer that fits the field's width, for example 0-15 for
     *       {@code PLATFORM}.</li>
     * </ul>
     * The results of all terms are OR-ed together.
     *
     * <p>A {@code null} or empty spec yields the all-zero config. The spec {@code HELP}
     * prints every bit and exits the JVM with status 0. Any term that cannot be parsed is
     * reported on standard output, and the JVM then exits with status 1.
     */
    public static Config fromSpec(String spec) {
        if (spec == null || spec.isEmpty()) {
            return fromIntBits(0);
        }
        if (spec.equals("HELP")) {
            printHelp();
            System.exit(0);
            return fromIntBits(0); // System.exit does not return normally
        }
        if (spec.contains(",")) {
            // Each term goes through fromSpec on its own so a bad term is reported by itself.
            return Arrays.stream(spec.split(","))
                    .map(Config::fromSpec)
                    .reduce((lhs, rhs) -> fromIntBits(lhs.bits() | rhs.bits()))
                    .orElseGet(() -> unexpectedSpec(spec));
        }
        return parseTerm(spec).orElseGet(() -> unexpectedSpec(spec));
    }

    /**
     * Parses a single spec term (containing no commas) without printing or exiting.
     *
     * @param term a bit name/alias, or {@code NAME:value}
     * @return the config for {@code term}, or empty in any of these cases: the name is
     *         unknown, the term is malformed, or the value is not an integer in
     *         {@code 0 .. 2^size - 1}
     */
    static Optional<Config> parseTerm(String term) {
        if (!term.contains(":")) {
            return findBit(term).map(bit -> fromIntBits(bit.mask()));
        }
        // limit -1 keeps trailing empty strings, so "NAME:1:" is rejected instead of accepted.
        String[] split = term.split(":", -1);
        if (split.length != 2) {
            return Optional.empty();
        }
        Optional<Bit> optBit = findBit(split[0]);
        if (optBit.isEmpty()) {
            return Optional.empty();
        }
        int value;
        try {
            value = Integer.parseInt(split[1]);
        } catch (NumberFormatException ignored) {
            // Treated like any other bad term so the caller reports it consistently.
            return Optional.empty();
        }
        Bit bit = optBit.get();
        int maxValue = (1 << bit.size()) - 1;
        if (value < 0 || value > maxValue) {
            // Out-of-range values would spill into neighbouring fields (e.g. DEVICE:16 would set MINIMIZE_COPIES).
            return Optional.empty();
        }
        var bv = new BitValue(bit, value);
        return Optional.of(fromIntBits(bv.value() << bv.bit().index()));
    }

    /** Finds the first bit in {@link #bitList} whose name or alias equals {@code nameOrAlt}. */
    private static Optional<Bit> findBit(String nameOrAlt) {
        return bitList.stream()
                .filter(bit -> bit.name().equals(nameOrAlt) || bit.alt().equals(nameOrAlt))
                .findFirst();
    }

    /** Prints one line per bit: name/alias, bit range, and description. */
    private static void printHelp() {
        for (Bit bit : bitList) {
            // Clamp so an unusually long name/alias cannot make String.repeat throw.
            int pad = Math.max(0, 40 - (bit.name().length() + bit.alt().length()));
            String range = bit.size() > 1 ? "-" + (bit.index() + bit.size() - 1) : " ";
            System.out.println(bit.name() + "/" + bit.alt() + " ".repeat(pad)
                    + (bit.index() < 10 ? " " : "") + bit.index() + range + " " + bit.description());
        }
    }

    /** Reports an unparseable spec and terminates the JVM with status 1. */
    private static Config unexpectedSpec(String spec) {
        System.out.println("Unexpected spec '" + spec + "'");
        System.exit(1);
        return null; // System.exit does not return normally
    }

    /**
     * Renders the set options joined by {@code |}, in {@link #bitList} order.
     *
     * <p>A single-bit flag appears as its name. A multi-bit field appears as
     * {@code NAME:value} when its value is non-zero. The output can therefore be fed back
     * to {@link #fromSpec(String)} after replacing {@code |} with {@code ,}.
     */
    @Override
    public String toString() {
        StringBuilder builder = new StringBuilder();
        for (Bit bit : bitList) {
            String term = null;
            if (bit.size() > 1) {
                int value = bit.fieldValue(bits);
                if (value != 0) {
                    term = bit.name() + ":" + value;
                }
            } else if (bit.isBitSet(bits)) {
                term = bit.name();
            }
            if (term != null) {
                if (!builder.isEmpty()) {
                    builder.append("|");
                }
                builder.append(term);
            }
        }
        return builder.toString();
    }

    /** Prints each bit's name and mask (useful when syncing the native config.h). */
    public static void main(String[] args) {
        bitList.forEach(b -> System.out.printf("%30s MASK= %32s\n", b.name, b.maskString()));
    }

}