1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 
 26 package hat;
 27 
 28 import java.util.Arrays;
 29 import java.util.List;
 30 import java.util.Optional;
 31 
 32 public class Config {
 33 
 34     public record Bit(int index, int size, String name, String alt, String description) implements Comparable<Bit> {
 35         static Bit of(int index, int size, String name, String alt, String description){
 36             return new Bit(index,size,name, alt, description);
 37         }
 38         public static Bit of(int index, String name, String alt,  String description){
 39             return new Bit(index,1,name, alt,description);
 40         }
 41 
 42         public static Bit nextBit(Bit bit, int size, String name, String alt,String description){
 43             return new Bit(bit.index+bit.size,size,name, alt, description);
 44         }
 45         public static Bit nextBit(Bit bit, String name, String alt, String description){
 46             return nextBit(bit, 1,name,alt, description);
 47         }
 48         @Override
 49         public int compareTo(Bit bit) {
 50             return Integer.compare(index, bit.index);
 51         }
 52 
 53        public boolean isBitSet(int bits){
 54             return (mask()&bits) == mask();
 55         }
 56         public boolean isSet(Config config){
 57             return (mask()&config.bits) == mask();
 58         }
 59         public int mask(){
 60             return ((1<<size)-1) << index;
 61         }
 62 
 63         public String maskString(){
 64             return Integer.toBinaryString(mask());
 65         }
 66     }
 67 
 68     public static final Bit PLATFORM =  Bit.of(0,4, "PLATFORM","P", "FFI ONLY platform id (0-15)");
 69     public static final Bit DEVICE = Bit.nextBit(PLATFORM, 4, "DEVICE","D","FFI ONLY device id (0-15)");
 70     private static final Bit MINIMIZE_COPIES =  Bit.nextBit(DEVICE, "MINIMIZE_COPIES","MC","FFI ONLY Try to minimize copies");
 71     public boolean minimizeCopies() {
 72         return MINIMIZE_COPIES.isSet(this);
 73     }
 74     private static final Bit TRACE = Bit.nextBit(MINIMIZE_COPIES,"TRACE", "T","FFI ONLY trace code");
 75     private static final Bit PROFILE = Bit.nextBit(TRACE, "PROFILE", "P","FFI ONLY Turn on profiling");
 76     private static final Bit SHOW_CODE = Bit.nextBit(PROFILE,"SHOW_CODE","SC","Show generated code (PTX/OpenCL/CUDA)");
 77     public boolean showCode() {
 78         return SHOW_CODE.isSet(this);
 79     }
 80     private static final Bit SHOW_KERNEL_MODEL = Bit.nextBit(SHOW_CODE,"SHOW_KERNEL_MODEL", "SKM","Show (via OpWriter) Kernel Model");
 81     public boolean showKernelModel() {
 82         return SHOW_KERNEL_MODEL.isSet(this);
 83     }
 84     private static final Bit SHOW_COMPUTE_MODEL = Bit.nextBit(SHOW_KERNEL_MODEL,"SHOW_COMPUTE_MODEL", "SCM","Show (via OpWriter) Compute Model");
 85     public boolean showComputeModel() {
 86         return SHOW_COMPUTE_MODEL.isSet(this);
 87     }
 88     private static final Bit SHOW_DEVICE_INFO = Bit.nextBit(SHOW_COMPUTE_MODEL, "SHOW_DEVICE_INFO", "SDI","FFI show platform and device info");
 89     public static final Bit INFO = Bit.nextBit(SHOW_DEVICE_INFO, "INFO", "I","INFO level logging");
 90     public static final Bit WARN = Bit.nextBit(INFO, "WARN", "W","WARN(ing) level logging ");
 91     public static final Bit UNIT = Bit.nextBit(WARN, "UNIT", "U","UNIT test level logging  ");
 92     private static final Bit TRACE_COPIES = Bit.nextBit(UNIT, "TRACE_COPIES", "TC","FFI ONLY trace copies");
 93     private static final Bit TRACE_SKIPPED_COPIES = Bit.nextBit(TRACE_COPIES,"TRACE_SKIPPED_COPIES", "TSC",  "FFI ONLY Trace skipped copies (see MINIMIZE_COPIES) ");
 94     private static final Bit TRACE_ENQUEUES = Bit.nextBit(TRACE_SKIPPED_COPIES,"TRACE_ENQUEUES","TE", "FFI ONLY trace enqueued tasks");
 95     private static final Bit TRACE_CALLS= Bit.nextBit(TRACE_ENQUEUES, "TRACE_CALLS", "TCALLS","FFI ONLY trace calls (enter/leave)");
 96     private static final Bit SHOW_WHY = Bit.nextBit(TRACE_CALLS, "SHOW_WHY", "SW","FFI ONLY show why we decided to copy buffer (H to D)");
 97     private static final Bit SHOW_STATE = Bit.nextBit(SHOW_WHY, "SHOW_STATE", "SS","Show iface buffer state changes");
 98     public boolean showState(){return SHOW_STATE.isSet(this);}
 99     private static final Bit PTX = Bit.nextBit(SHOW_STATE, "PTX", "PTX","FFI (NVIDIA) ONLY pass PTX rather than C99 CUDA code");
100     public boolean ptx(){return PTX.isSet(this);}
101     private static final Bit INTERPRET = Bit.nextBit(PTX, "INTERPRET", "I","Interpret the code model rather than converting to bytecode");
102     public boolean interpret() {
103         return INTERPRET.isSet(this);
104     }
105     private static final Bit HEADLESS = Bit.nextBit(INTERPRET, "HEADLESS", "H","Don't show UI");
106 
107     public boolean headless() {
108         return HEADLESS.isSet(this) || Boolean.getBoolean("headless");
109     }
110 
111     public boolean headless(String arg) {
112         return headless() || "--headless".equals(arg);
113     }
114 
115     private static final Bit SHOW_LOWERED_KERNEL_MODEL = Bit.nextBit(HEADLESS,"SHOW_LOWERED_KERNEL_MODEL", "SLKM","Show (via OpWriter) Lowered Kernel Model");
116     public boolean showLoweredKernelModel() {
117         return SHOW_LOWERED_KERNEL_MODEL.isSet(this);
118     }
119     private static final Bit SHOW_COMPILATION_PHASES = Bit.nextBit(SHOW_LOWERED_KERNEL_MODEL, "SHOW_COMPILATION_PHASES","SCP", "Show HAT compilation phases");
120     public boolean showCompilationPhases() {
121         return SHOW_COMPILATION_PHASES.isSet(this);
122     }
123     private static final Bit PROFILE_CUDA_KERNEL = Bit.nextBit(SHOW_COMPILATION_PHASES, "PROFILE_CUDA_KERNEL","PCK", "Add -lineinfo to CUDA kernel compilation for profiling and debugging");
124 
125     public boolean profileCUDAKernel() {
126         return PROFILE_CUDA_KERNEL.isSet(this);
127     }
128     private static final Bit SHOW_COMPUTE_MODEL_JAVA_CODE = Bit.nextBit(PROFILE_CUDA_KERNEL, "SHOW_COMPUTE_MODEL_JAVA_CODE","SCMJC", "Show java code view of compute model");
129 
130     public boolean showComputeModelJavaCode() {
131         return SHOW_COMPUTE_MODEL_JAVA_CODE.isSet(this);
132     }
133 
134     public static final Bit CHECK_SSA_LOWERING = Bit.nextBit(SHOW_COMPUTE_MODEL_JAVA_CODE, "CHECK_SSA_LOWERING","CSSAL", "Verify that code model can be lowered to SSA");
135     public boolean checkSSALowering() {
136         return CHECK_SSA_LOWERING.isSet(this);
137     }
138 
139     public static final List<Bit> bitList = List.of(
140             PLATFORM,
141             DEVICE,
142             MINIMIZE_COPIES,
143             TRACE,
144             PROFILE,
145             SHOW_CODE,
146             SHOW_KERNEL_MODEL,
147             SHOW_COMPUTE_MODEL,
148             SHOW_DEVICE_INFO,
149             INFO,
150             WARN,
151             UNIT,
152             TRACE_COPIES,
153             TRACE_SKIPPED_COPIES,
154             TRACE_ENQUEUES,
155             TRACE_CALLS,
156             SHOW_WHY,
157             SHOW_STATE,
158             PTX,
159             INTERPRET,
160             HEADLESS,
161             SHOW_LOWERED_KERNEL_MODEL,
162             SHOW_COMPILATION_PHASES,
163             PROFILE_CUDA_KERNEL,
164             SHOW_COMPUTE_MODEL_JAVA_CODE,
165             CHECK_SSA_LOWERING
166     );
167 
168     private final int bits;
169 
170     public int bits() {
171         return bits;
172     }
173 
174     Config(int bits) {
175         this.bits = bits;
176     }
177 
178     // These must sync with hat/backends/ffi/shared/include/config.h
179     // We can create the above config by running main() below...
180 
181     public static Config fromEnvOrProperty() {
182         if (System.getenv("HAT") instanceof String opts) {
183             return fromSpec(opts);
184         } else if (System.getProperty("HAT") instanceof String opts) {
185             return fromSpec(opts);
186         } else {
187             return fromSpec("");
188         }
189     }
190 
191     public static Config fromIntBits(int bits) {
192         return new Config(bits);
193     }
194 
195     record BitValue(Bit bit, int value){}
196 
197     public static Config fromSpec(String spec) {
198         Optional<Config> returnValue = Optional.of(Config.fromIntBits(0));
199         if (spec == null || spec.isEmpty()) {
200             // default is good
201         } else if (spec.equals("HELP")) {
202             bitList.forEach(bit -> IO.println(bit.name + "/" + bit.alt + " ".repeat(40 - (bit.name.length() + bit.alt.length())) + (bit.index < 10 ? " " : "") + bit.index + ((bit.size > 1) ? "-" + (bit.index + bit.size - 1) : "  ") + " " + bit.description));
203             if (spec.equals("HELP")) {
204                 System.exit(0);
205             }
206         } else if (spec.contains(",")) {
207             returnValue = Arrays.stream(spec.split(",")).map(Config::fromSpec).reduce((lhs, rhs) -> Config.fromIntBits(lhs.bits() | rhs.bits()));
208         } else if (!spec.contains(":")) {
209             returnValue = bitList.stream().filter(bit -> bit.name().equals(spec) || bit.alt().equals(spec)).findFirst().map(b -> fromIntBits(b.mask()));
210         } else {
211             var split = spec.split(":");
212             if (split.length == 2) {
213                 var optBit = bitList.stream().filter(bit -> bit.name().equals(split[0]) || bit.alt().equals(split[0])).findFirst();
214                 if (optBit.isPresent()) {
215                     var bv = new BitValue(optBit.get(), Integer.parseInt(split[1]));
216                     var bitz = bv.value << bv.bit().index();
217                     returnValue = Optional.of(fromIntBits(bitz));
218                 }
219             }
220         }
221         if (returnValue.isPresent()) {
222             return returnValue.get();
223         } else {
224             IO.println("Unexpected spec '" + spec + "'");
225             System.exit(1);
226         }
227         return null;
228     }
229 
230     @Override
231     public String toString() {
232         StringBuilder builder = new StringBuilder();
233         for (Bit bit:bitList){
234             if (bit.isBitSet(bits)) {
235                 if (!builder.isEmpty()) {
236                     builder.append("|");
237                 }
238                 builder.append(bit.name());
239             }
240         }
241         return builder.toString();
242     }
243 
244     public static void main() {
245        bitList.forEach(b-> System.out.printf("%30s MASK= %32s\n",  b.name, b.maskString()));
246     }
247 }