1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 
 26 package hat;
 27 
 28 import java.util.List;
 29 
 30 public class Config {
 31 
 32     public record Bit(int index, int size, String name, String description) implements Comparable<Bit> {
 33         static Bit of(int index, int size, String name, String description){
 34             return new Bit(index,size,name,description);
 35         }
 36         public static Bit of(int index, String name, String description){
 37             return new Bit(index,1,name,description);
 38         }
 39 
 40         public static Bit nextBit(Bit bit, int size, String name, String description){
 41             return new Bit(bit.index+bit.size,size,name, description);
 42         }
 43         public static Bit nextBit(Bit bit, String name, String description){
 44             return nextBit(bit, 1,name,description);
 45         }
 46         @Override
 47         public int compareTo(Bit bit) {
 48             return Integer.compare(index, bit.index);
 49         }
 50 
 51        public boolean isBitSet(int bits){
 52             return (mask()&bits) == mask();
 53         }
 54         public boolean isSet(Config config){
 55             return (mask()&config.bits) == mask();
 56         }
 57         public int mask(){
 58             return ((1<<size)-1) << index;
 59         }
 60 
 61         public String maskString(){
 62             return Integer.toBinaryString(mask());
 63         }
 64     }
 65 
 66     public static final Bit PLATFORM =  Bit.of(0,4, "PLATFORM", "FFI ONLY platform id (0-15)");
 67     public static final Bit DEVICE = Bit.nextBit(PLATFORM, 4, "DEVICE","FFI ONLY device id (0-15)");
 68     private static final Bit MINIMIZE_COPIES =  Bit.nextBit(DEVICE, "MINIMIZE_COPIES","FFI ONLY Try to minimize copies");
 69     public boolean minimizeCopies() {
 70         return MINIMIZE_COPIES.isSet(this);
 71     }
 72     private static final Bit TRACE = Bit.nextBit(MINIMIZE_COPIES,"TRACE", "FFI ONLY trace code");
 73     private static final Bit PROFILE = Bit.nextBit(TRACE, "PROFILE", "FFI ONLY Turn on profiling");
 74     private static final Bit SHOW_CODE = Bit.nextBit(PROFILE,"SHOW_CODE","Show generated code (PTX/OpenCL/CUDA)");
 75     public boolean showCode() {
 76         return SHOW_CODE.isSet(this);
 77     }
 78     private static final Bit SHOW_KERNEL_MODEL = Bit.nextBit(SHOW_CODE,"SHOW_KERNEL_MODEL", "Show (via OpWriter) Kernel Model");
 79     public boolean showKernelModel() {
 80         return SHOW_COMPUTE_MODEL.isSet(this);
 81     }
 82     private static final Bit SHOW_COMPUTE_MODEL = Bit.nextBit(SHOW_KERNEL_MODEL,"SHOW_COMPUTE_MODEL", "Show (via OpWriter) Compute Model");
 83     public boolean showComputeModel() {
 84         return SHOW_COMPUTE_MODEL.isSet(this);
 85     }
 86     private static final Bit SHOW_DEVICE_INFO = Bit.nextBit(SHOW_COMPUTE_MODEL, "SHOW_DEVICE_INFO", "FFI show platform and device info");
 87     public static final Bit INFO = Bit.nextBit(SHOW_DEVICE_INFO, "INFO", "INFO level logging");
 88     public static final Bit WARN = Bit.nextBit(INFO, "WARN", "WARN(ing) level logging ");
 89     public static final Bit UNIT = Bit.nextBit(WARN, "UNIT", "UNIT test level logging  ");
 90     private static final Bit TRACE_COPIES = Bit.nextBit(UNIT, "TRACE_COPIES", "FFI ONLY trace copies");
 91     private static final Bit TRACE_SKIPPED_COPIES = Bit.nextBit(TRACE_COPIES, "TRACE_SKIPPED_COPIES", "FFI ONLY Trace skipped copies (see MINIMIZE_COPIES) ");
 92     private static final Bit TRACE_ENQUEUES = Bit.nextBit(TRACE_SKIPPED_COPIES,"TRACE_ENQUEUES", "FFI ONLY trace enqueued tasks");
 93     private static final Bit TRACE_CALLS= Bit.nextBit(TRACE_ENQUEUES, "TRACE_CALLS", "FFI ONLY trace calls (enter/leave)");
 94     private static final Bit SHOW_WHY = Bit.nextBit(TRACE_CALLS, "SHOW_WHY", "FFI ONLY show why we decided to copy buffer (H to D)");
 95     private static final Bit SHOW_STATE = Bit.nextBit(SHOW_WHY, "SHOW_STATE", "Show iface buffer state changes");
 96     public boolean showState(){return SHOW_STATE.isSet(this);}
 97     private static final Bit PTX = Bit.nextBit(SHOW_STATE, "PTX", "FFI (NVIDIA) ONLY pass PTX rather than C99 CUDA code");
 98     public boolean ptx(){return PTX.isSet(this);}
 99     private static final Bit INTERPRET = Bit.nextBit(PTX, "INTERPRET", "Interpret the code model rather than converting to bytecode");
100     public boolean interpret() {
101         return INTERPRET.isSet(this);
102     }
103     private static final Bit HEADLESS = Bit.nextBit(INTERPRET, "HEADLESS", "Don't show UI");
104     public boolean headless() {
105         return HEADLESS.isSet(this)|| Boolean.getBoolean("headless");
106     }
107     public boolean headless(String arg) {
108         return headless()|"--headless".equals(arg);
109     }
110     private static final Bit SHOW_LOWERED_KERNEL_MODEL = Bit.nextBit(HEADLESS,"SHOW_LOWERED_KERNEL_MODEL", "Show (via OpWriter) Lowered Kernel Model");
111     public boolean showLoweredKernelModel() {
112         return SHOW_LOWERED_KERNEL_MODEL.isSet(this);
113     }
114     private static final Bit SHOW_COMPILATION_PHASES = Bit.nextBit(SHOW_LOWERED_KERNEL_MODEL, "SHOW_COMPILATION_PHASES", "Show HAT compilation phases");
115     public boolean showCompilationPhases() {
116         return SHOW_COMPILATION_PHASES.isSet(this);
117     }
118     private static final Bit PROFILE_CUDA_KERNEL = Bit.nextBit(SHOW_COMPILATION_PHASES, "PROFILE_CUDA_KERNEL", "Add -lineinfo to CUDA kernel compilation for profiling and debugging");
119 
120     public boolean isProfileCUDAKernel() {
121         return PROFILE_CUDA_KERNEL.isSet(this);
122     }
123     public static final List<Bit> bitList = List.of(
124             PLATFORM,
125             DEVICE,
126             MINIMIZE_COPIES,
127             TRACE,
128             PROFILE,
129             SHOW_CODE,
130             SHOW_KERNEL_MODEL,
131             SHOW_COMPUTE_MODEL,
132             SHOW_DEVICE_INFO,
133             INFO,
134             WARN,
135             UNIT,
136             TRACE_COPIES,
137             TRACE_SKIPPED_COPIES,
138             TRACE_ENQUEUES,
139             TRACE_CALLS,
140             SHOW_WHY,
141             SHOW_STATE,
142             PTX,
143             INTERPRET,
144             HEADLESS,
145             SHOW_LOWERED_KERNEL_MODEL,
146             SHOW_COMPILATION_PHASES,
147             PROFILE_CUDA_KERNEL
148     );
149 
150     private int bits;
151 
152 
153     public int bits(){
154         return bits;
155     }
156     public void bits(int bits){
157         this.bits = bits;
158     }
159 
160     Config(int bits){
161         bits(bits);
162     }
163 
164     // These must sync with hat/backends/ffi/shared/include/config.h
165     // We can create the above config by running main() below...
166 
167     public static Config fromEnvOrProperty() {
168         if (System.getenv("HAT") instanceof String opts) {
169             System.out.println("From env " + opts);
170             return fromSpec(opts);
171         }
172         if (System.getProperty("HAT") instanceof String opts) {
173             System.out.println("From prop " + opts);
174             return fromSpec(opts);
175         }
176         return fromSpec("");
177     }
178 
179     public static Config fromIntBits(int bits) {
180         return new Config(bits);
181     }
182 
183     public static Config fromBits(List<Bit> configBits) {
184         int allBits = 0;
185         for (Bit configBit : configBits) {
186             allBits |= configBit.mask();
187         }
188         return new Config(allBits);
189     }
190 
191     public static Config fromBits(Bit... configBits) {
192         return fromBits(List.of(configBits));
193     }
194 
195     public Config and(Bit... configBits) {
196         return Config.fromIntBits(Config.fromBits(List.of(configBits)).bits & bits);
197     }
198 
199     public Config or(Bit... configBits) {
200         return Config.fromIntBits(Config.fromBits(List.of(configBits)).bits | bits);
201     }
202 
203     public record BitValue(Bit bit, int value){}
204 
205     public static Config fromSpec(String spec) {
206         if (spec == null || spec.equals("")) {
207             return Config.fromIntBits(0);
208         }
209         for (Bit bit:bitList) {
210             if (bit.name().equals(spec)) {
211                 return new Config(bit.mask());
212             }
213         }
214         if (spec.contains(",")) {
215             var bits = 0;
216             for (var opt: spec.split(",")) {
217                 var split = opt.split(":");
218                 var valName=split[0];
219                 var value=split.length==1?1:Integer.parseInt(split[1]);
220                 var bitValue = Config.bitList.stream()
221                         .filter(bit ->bit.name().equals(valName))
222                         .map(bit -> new BitValue(bit, value))
223                         .findFirst()
224                         .orElseThrow();
225                 bits |= bitValue.value << bitValue.bit.index();
226             }
227             return fromIntBits(bits);
228         } else {
229             System.out.println("Unexpected spec '" + spec + "'");
230             System.exit(1);
231             return Config.fromIntBits(0);
232         }
233     }
234 
235     @Override
236     public String toString() {
237         StringBuilder builder = new StringBuilder();
238         for (Bit bit:bitList){
239             if (bit.isBitSet(bits)) {
240                 if (!builder.isEmpty()) {
241                     builder.append("|");
242                 }
243                 builder.append(bit.name());
244 
245             }
246         }
247         return builder.toString();
248     }
249 
250     public static void main(String[] args){
251        bitList.forEach(b-> {
252            System.out.printf("%30s MASK= %32s\n",  b.name,b.maskString());
253        });
254 
255     }
256 
257 }