1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 
 26 package hat;
 27 
 28 import java.util.Arrays;
 29 import java.util.List;
 30 import java.util.Optional;
 31 
 32 public class Config {
 33 
 34     public record Bit(int index, int size, String name, String alt, String description) implements Comparable<Bit> {
 35         static Bit of(int index, int size, String name, String alt, String description){
 36             return new Bit(index,size,name, alt, description);
 37         }
 38         public static Bit of(int index, String name, String alt,  String description){
 39             return new Bit(index,1,name, alt,description);
 40         }
 41 
 42         public static Bit nextBit(Bit bit, int size, String name, String alt,String description){
 43             return new Bit(bit.index+bit.size,size,name, alt, description);
 44         }
 45         public static Bit nextBit(Bit bit, String name, String alt, String description){
 46             return nextBit(bit, 1,name,alt, description);
 47         }
 48         @Override
 49         public int compareTo(Bit bit) {
 50             return Integer.compare(index, bit.index);
 51         }
 52 
 53        public boolean isBitSet(int bits){
 54             return (mask()&bits) == mask();
 55         }
 56         public boolean isSet(Config config){
 57             return (mask()&config.bits) == mask();
 58         }
 59         public int mask(){
 60             return ((1<<size)-1) << index;
 61         }
 62 
 63         public String maskString(){
 64             return Integer.toBinaryString(mask());
 65         }
 66     }
 67 
 68     public static final Bit PLATFORM =  Bit.of(0,4, "PLATFORM","P", "FFI ONLY platform id (0-15)");
 69     public static final Bit DEVICE = Bit.nextBit(PLATFORM, 4, "DEVICE","D","FFI ONLY device id (0-15)");
 70     private static final Bit MINIMIZE_COPIES =  Bit.nextBit(DEVICE, "MINIMIZE_COPIES","MC","FFI ONLY Try to minimize copies");
 71     public boolean minimizeCopies() {
 72         return MINIMIZE_COPIES.isSet(this);
 73     }
 74     private static final Bit TRACE = Bit.nextBit(MINIMIZE_COPIES,"TRACE", "T","FFI ONLY trace code");
 75     private static final Bit PROFILE = Bit.nextBit(TRACE, "PROFILE", "P","FFI ONLY Turn on profiling");
 76     private static final Bit SHOW_CODE = Bit.nextBit(PROFILE,"SHOW_CODE","SC","Show generated code (PTX/OpenCL/CUDA)");
 77     public boolean showCode() {
 78         return SHOW_CODE.isSet(this);
 79     }
 80     private static final Bit SHOW_KERNEL_MODEL = Bit.nextBit(SHOW_CODE,"SHOW_KERNEL_MODEL", "SKM","Show (via OpWriter) Kernel Model");
 81     public boolean showKernelModel() {
 82         return SHOW_COMPUTE_MODEL.isSet(this);
 83     }
 84     private static final Bit SHOW_COMPUTE_MODEL = Bit.nextBit(SHOW_KERNEL_MODEL,"SHOW_COMPUTE_MODEL", "SCM","Show (via OpWriter) Compute Model");
 85     public boolean showComputeModel() {
 86         return SHOW_COMPUTE_MODEL.isSet(this);
 87     }
 88     private static final Bit SHOW_DEVICE_INFO = Bit.nextBit(SHOW_COMPUTE_MODEL, "SHOW_DEVICE_INFO", "SDI","FFI show platform and device info");
 89     public static final Bit INFO = Bit.nextBit(SHOW_DEVICE_INFO, "INFO", "I","INFO level logging");
 90     public static final Bit WARN = Bit.nextBit(INFO, "WARN", "W","WARN(ing) level logging ");
 91     public static final Bit UNIT = Bit.nextBit(WARN, "UNIT", "U","UNIT test level logging  ");
 92     private static final Bit TRACE_COPIES = Bit.nextBit(UNIT, "TRACE_COPIES", "TC","FFI ONLY trace copies");
 93     private static final Bit TRACE_SKIPPED_COPIES = Bit.nextBit(TRACE_COPIES,"TRACE_SKIPPED_COPIES", "TSC",  "FFI ONLY Trace skipped copies (see MINIMIZE_COPIES) ");
 94     private static final Bit TRACE_ENQUEUES = Bit.nextBit(TRACE_SKIPPED_COPIES,"TRACE_ENQUEUES","TE", "FFI ONLY trace enqueued tasks");
 95     private static final Bit TRACE_CALLS= Bit.nextBit(TRACE_ENQUEUES, "TRACE_CALLS", "TCALLS","FFI ONLY trace calls (enter/leave)");
 96     private static final Bit SHOW_WHY = Bit.nextBit(TRACE_CALLS, "SHOW_WHY", "SW","FFI ONLY show why we decided to copy buffer (H to D)");
 97     private static final Bit SHOW_STATE = Bit.nextBit(SHOW_WHY, "SHOW_STATE", "SS","Show iface buffer state changes");
 98     public boolean showState(){return SHOW_STATE.isSet(this);}
 99     private static final Bit PTX = Bit.nextBit(SHOW_STATE, "PTX", "PTX","FFI (NVIDIA) ONLY pass PTX rather than C99 CUDA code");
100     public boolean ptx(){return PTX.isSet(this);}
101     private static final Bit INTERPRET = Bit.nextBit(PTX, "INTERPRET", "I","Interpret the code model rather than converting to bytecode");
102     public boolean interpret() {
103         return INTERPRET.isSet(this);
104     }
105     private static final Bit HEADLESS = Bit.nextBit(INTERPRET, "HEADLESS", "H","Don't show UI");
106     public boolean headless() {
107         return HEADLESS.isSet(this)|| Boolean.getBoolean("headless");
108     }
109     public boolean headless(String arg) {
110         return headless()|"--headless".equals(arg);
111     }
112     private static final Bit SHOW_LOWERED_KERNEL_MODEL = Bit.nextBit(HEADLESS,"SHOW_LOWERED_KERNEL_MODEL", "SLKM","Show (via OpWriter) Lowered Kernel Model");
113     public boolean showLoweredKernelModel() {
114         return SHOW_LOWERED_KERNEL_MODEL.isSet(this);
115     }
116     private static final Bit SHOW_COMPILATION_PHASES = Bit.nextBit(SHOW_LOWERED_KERNEL_MODEL, "SHOW_COMPILATION_PHASES","SCP", "Show HAT compilation phases");
117     public boolean showCompilationPhases() {
118         return SHOW_COMPILATION_PHASES.isSet(this);
119     }
120     private static final Bit PROFILE_CUDA_KERNEL = Bit.nextBit(SHOW_COMPILATION_PHASES, "PROFILE_CUDA_KERNEL","PCK", "Add -lineinfo to CUDA kernel compilation for profiling and debugging");
121 
122     public boolean profileCUDAKernel() {
123         return PROFILE_CUDA_KERNEL.isSet(this);
124     }
125     private static final Bit SHOW_COMPUTE_MODEL_JAVA_CODE = Bit.nextBit(PROFILE_CUDA_KERNEL, "SHOW_COMPUTE_MODEL_JAVA_CODE","SCMJC", "Show java code view of compute model");
126 
127     public boolean showComputeModelJavaCode() {
128         return SHOW_COMPUTE_MODEL_JAVA_CODE.isSet(this);
129     }
130     public static final List<Bit> bitList = List.of(
131             PLATFORM,
132             DEVICE,
133             MINIMIZE_COPIES,
134             TRACE,
135             PROFILE,
136             SHOW_CODE,
137             SHOW_KERNEL_MODEL,
138             SHOW_COMPUTE_MODEL,
139             SHOW_DEVICE_INFO,
140             INFO,
141             WARN,
142             UNIT,
143             TRACE_COPIES,
144             TRACE_SKIPPED_COPIES,
145             TRACE_ENQUEUES,
146             TRACE_CALLS,
147             SHOW_WHY,
148             SHOW_STATE,
149             PTX,
150             INTERPRET,
151             HEADLESS,
152             SHOW_LOWERED_KERNEL_MODEL,
153             SHOW_COMPILATION_PHASES,
154             PROFILE_CUDA_KERNEL,
155             SHOW_COMPUTE_MODEL_JAVA_CODE
156     );
157 
158     final private int bits;
159 
160     public int bits(){
161         return bits;
162     }
163 
164     Config(int bits, boolean junk){
165         this.bits = bits;
166     }
167 
168     // These must sync with hat/backends/ffi/shared/include/config.h
169     // We can create the above config by running main() below...
170 
171     public static Config fromEnvOrProperty() {
172         if (System.getenv("HAT") instanceof String opts) {
173            // System.out.println("From env " + opts);
174             return fromSpec(opts);
175         }else if (System.getProperty("HAT") instanceof String opts) {
176            // System.out.println("From prop " + opts);
177             return fromSpec(opts);
178         }else {
179             return fromSpec("");
180         }
181     }
182 
183     public static Config fromIntBits(int bits) {
184         return new Config(bits,false);
185     }
186 
187      record BitValue(Bit bit, int value){}
188 
189     // recursive!!!
190     public static Config fromSpec(String spec) {
191         Optional<Config> returnValue=Optional.of(Config.fromIntBits(0));
192         if (spec == null || spec.isEmpty()) {
193            // default is good
194         } else if (spec.equals("HELP")) {
195           bitList.forEach(bit -> System.out.println(bit.name+ "/"+bit.alt+" ".repeat(40-(bit.name.length()+bit.alt.length()))+(bit.index<10?" ":"")+bit.index+((bit.size>1)?"-"+(bit.index+bit.size-1):"  ")+ " "+bit.description));
196           if(spec.equals("HELP")){
197               System.exit(0);
198           }
199         }else if (spec.contains(",")) {
200             returnValue =  Arrays.stream(spec.split(",")).map(Config::fromSpec).reduce((lhs, rhs) -> Config.fromIntBits(lhs.bits() | rhs.bits()));
201         } else if (!spec.contains(":")) {
202             returnValue=  bitList.stream().filter(bit->bit.name().equals(spec)||bit.alt().equals(spec)).findFirst().map(b->fromIntBits(b.mask()));
203         }else{
204             var split = spec.split(":");
205             if (split.length==2) {
206                 var optBit = bitList.stream().filter(bit -> bit.name().equals(split[0])||bit.alt().equals(split[0])).findFirst();
207                 if (optBit.isPresent()) {
208                     var bv = new BitValue(optBit.get(), Integer.parseInt(split[1]));
209                     var bitz = bv.value << bv.bit().index();
210                     returnValue= Optional.of(fromIntBits(bitz));
211                 }
212             }
213         }
214         if (returnValue.isPresent()) {
215             return returnValue.get();
216         }else {
217             System.out.println("Unexpected spec '" + spec + "'");
218             System.exit(1);
219         }
220         return null;
221     }
222 
223     @Override
224     public String toString() {
225         StringBuilder builder = new StringBuilder();
226         for (Bit bit:bitList){
227             if (bit.isBitSet(bits)) {
228                 if (!builder.isEmpty()) {
229                     builder.append("|");
230                 }
231                 builder.append(bit.name());
232 
233             }
234         }
235         return builder.toString();
236     }
237 
238     public static void main(String[] args){
239        bitList.forEach(b-> {
240            System.out.printf("%30s MASK= %32s\n",  b.name,b.maskString());
241        });
242 
243     }
244 
245 }