1 /*
  2  * Copyright (c) 2024, 2026, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package jdk.incubator.code.bytecode.impl;
 26 
 27 import java.lang.classfile.Attributes;
 28 import java.lang.classfile.ClassTransform;
 29 import java.lang.classfile.CodeModel;
 30 import java.lang.classfile.Label;
 31 import java.lang.classfile.MethodModel;
 32 import java.lang.classfile.TypeKind;
 33 import java.lang.classfile.attribute.StackMapFrameInfo;
 34 import java.lang.classfile.instruction.BranchInstruction;
 35 import java.lang.classfile.instruction.IncrementInstruction;
 36 import java.lang.classfile.instruction.LabelTarget;
 37 import java.lang.classfile.instruction.LoadInstruction;
 38 import java.lang.classfile.instruction.LookupSwitchInstruction;
 39 import java.lang.classfile.instruction.TableSwitchInstruction;
 40 import java.lang.classfile.instruction.StoreInstruction;
 41 import java.lang.constant.ClassDesc;
 42 import java.lang.reflect.AccessFlag;
 43 import java.util.ArrayList;
 44 import java.util.BitSet;
 45 import java.util.IdentityHashMap;
 46 import java.util.List;
 47 import java.util.Map;
 48 import java.util.stream.Collectors;
 49 
 50 import static java.lang.classfile.attribute.StackMapFrameInfo.SimpleVerificationTypeInfo.*;
 51 import static java.lang.constant.ConstantDescs.CD_double;
 52 import static java.lang.constant.ConstantDescs.CD_long;
 53 
 54 /**
 55  * LocalsCompactor transforms class to reduce allocation of local slots in the Code attribute (max_locals).
 56  * It collects slot maps, compacts them and transforms the Code attribute accordingly.
 57  * <p>
 58  * Example of maps before compaction (max_locals = 13):
 59  * <pre>
 60  *  slots:  0   1   2   3   4   5   6   7   8   9   10  11  12  13
 61  *  ---------------------------------------------------------------
 62  *  bci 0:  *   *
 63  *      8:      *   *   *
 64  *     10:      *   *   *
 65  *     15:      *   *   *   *   *
 66  *     17:      *   *   *   *   *
 67  *     18:      *           *   *
 68  *     25:      *                   *   *
 69  *     27:      *                   *   *
 70  *     32:      *                   *   *   *   *
 71  *     34:      *                   *   *   *   *
 72  *     36:      *                           *   *
 73  *     43:      *                                   *   *
 74  *     45:      *                                   *   *
 75  *     50:                                          *   *   *   *
 76  *     52:                                          *   *   *   *
 77  *     54:                                                  *   *
 78  * </pre>
 79  * Compact form of the same maps (max_locals = 5):
 80  * <pre>
 81  *  slots:   0   1   2   3   4   5
 82  *         +12 +13  +6  +7  +8  +9
 83  *                 +10 +11
 84  *  -------------------------------
 85  *  bci 0:  *   *
 86  *      8:      *   *   *
 87  *     10:      *   *   *
 88  *     15:      *   *   *   *   *
 89  *     17:      *   *   *   *   *
 90  *     18:      *           *   *
 91  *     25:      *   *   *
 92  *     27:      *   *   *
 93  *     32:      *   *   *   *   *
 94  *     34:      *   *   *   *   *
 95  *     36:      *           *   *
 96  *     43:      *   *   *
 97  *     45:      *   *   *
 98  *     50:  *   *   *   *
 99  *     52:  *   *   *   *
100  *     54:  *   *
101  * </pre>
102  */
103 public final class LocalsCompactor {
104 
105     /**
106      * LocalsCompactor transformation requires complete class file binary
107      * @return LocalsCompactor class transform instance
108      */
109     public static ClassTransform instance() {
110         return (clb,cle) -> {
111             if (cle instanceof MethodModel mm) {
112                 clb.transformMethod(mm, (mb, me) -> {
113                     if (me instanceof CodeModel com) {
114                         int[] slotMap = new LocalsCompactor(com, countParamSlots(mm)).slotMap;
115                         mb.transformCode(com, (cob, coe) -> {
116                             switch (coe) {
117                                 case LoadInstruction li ->
118                                     cob.loadLocal(li.typeKind(), slotMap[li.slot()]);
119                                 case StoreInstruction si ->
120                                     cob.storeLocal(si.typeKind(), slotMap[si.slot()]);
121                                 case IncrementInstruction ii ->
122                                     cob.iinc(slotMap[ii.slot()], ii.constant());
123                                 default ->
124                                     cob.with(coe);
125                             }
126                         });
127                     } else {
128                         mb.with(me);
129                     }
130                 });
131             } else {
132                 clb.with(cle);
133             }
134         };
135     }
136 
137     private static int countParamSlots(MethodModel mm) {
138         int slots = mm.flags().has(AccessFlag.STATIC) ? 0 : 1;
139         for (ClassDesc p : mm.methodTypeSymbol().parameterList()) {
140             slots += p == CD_long || p == CD_double ? 2 : 1;
141         }
142         return slots;
143     }
144 
145     static final class Slot {
146         final BitSet map = new BitSet(); // Liveness map of the slot
147         int flags; // 0 - single slot, 1 - first of double slots, 2 - second of double slots, 3 - mixed
148     }
149 
150     private final List<Slot> maps; // Intermediate slots liveness maps
151     private final Map<Label, List<StackMapFrameInfo.VerificationTypeInfo>> frames;
152     private final int[] slotMap; // Output mapping of the slots
153     private final Map<Label, Integer> labelPc; // Labels map
154 
155     private LocalsCompactor(CodeModel com, int fixedSlots) {
156         frames = com.findAttribute(Attributes.stackMapTable()).map(
157                 smta -> smta.entries().stream().collect(
158                         Collectors.toMap(StackMapFrameInfo::target, StackMapFrameInfo::locals)))
159                 .orElse(Map.of());
160         var exceptionHandlers = com.exceptionHandlers();
161         maps = new ArrayList<>();
162         labelPc = new IdentityHashMap<>();
163         int pc = 0;
164         // Initialization of fixed slots
165         for (int slot = 0; slot < fixedSlots; slot++) {
166             getMap(slot).map.set(0);
167         }
168         // Filling the slots liveness maps
169         for (var e : com) {
170             switch(e) {
171                 case LabelTarget lt -> {
172                     labelPc.put(lt.label(), pc);
173                     for (var eh : exceptionHandlers) {
174                         if (eh.tryStart() == lt.label()) {
175                             mergeFrom(pc, eh.handler());
176                         }
177                     }
178                 }
179                 case LoadInstruction li ->
180                     load(pc, li.slot(), li.typeKind());
181                 case StoreInstruction si ->
182                     store(pc, si.slot(), si.typeKind());
183                 case IncrementInstruction ii ->
184                     loadSingle(pc, ii.slot());
185                 case BranchInstruction bi ->
186                     mergeFrom(pc, bi.target());
187                 case LookupSwitchInstruction si -> {
188                     mergeFrom(pc, si.defaultTarget());
189                     for (var sc : si.cases()) {
190                         mergeFrom(pc, sc.target());
191                     }
192                 }
193                 case TableSwitchInstruction si -> {
194                     mergeFrom(pc, si.defaultTarget());
195                     for (var sc : si.cases()) {
196                         mergeFrom(pc, sc.target());
197                     }
198                 }
199                 default -> pc--;
200             }
201             pc++;
202         }
203         // Merge locals in exception handlers
204         for (var eh : exceptionHandlers) {
205             Integer start = labelPc.get(eh.tryStart());
206             Integer end = labelPc.get(eh.tryEnd());
207             if (start != null && end != null) {
208                 var locals = frames.get(eh.handler());
209                 if (locals != null) {
210                     int slot = 0;
211                     for (var vti : locals) {
212                         if (vti == LONG || vti == DOUBLE) {
213                             markRange(slot, start, end + 1).flags |= 1;
214                             markRange(slot + 1, start, end + 1).flags |= 2;
215                             slot += 2;
216                         } else {
217                             if (vti != TOP) {
218                                 markRange(slot, start, end + 1);
219                             }
220                             slot++;
221                         }
222                     }
223                 }
224             }
225         }
226         // Initialization of slots mapping
227         slotMap = new int[maps.size()];
228         for (int slot = 0; slot < slotMap.length; slot++) {
229             slotMap[slot] = slot;
230         }
231         // Iterative compaction of slots
232         for (int targetSlot = 0; targetSlot < maps.size() - 1; targetSlot++) {
233             for (int sourceSlot = Math.max(targetSlot + 1, fixedSlots); sourceSlot < maps.size(); sourceSlot++) {
234                 Slot source = maps.get(sourceSlot);
235                 // Re-mapping single slot
236                 if (source.flags == 0) {
237                     Slot target = maps.get(targetSlot);
238                     if (!target.map.intersects(source.map)) {
239                         // Single re-mapping, merge of the liveness maps and shift of the following slots by 1 left
240                         target.map.or(source.map);
241                         maps.remove(sourceSlot);
242                         for (int slot = 0; slot < slotMap.length; slot++) {
243                             if (slotMap[slot] == sourceSlot) {
244                                 slotMap[slot] = targetSlot;
245                             } else if (slotMap[slot] > sourceSlot) {
246                                 slotMap[slot]--;
247                             }
248                         }
249                     }
250                 } else if (source.flags == 1 && sourceSlot > targetSlot + 1) {
251                     Slot source2 = maps.get(sourceSlot + 1);
252                     // Re-mapping distinct double slot
253                     if (source2.flags == 2) {
254                         Slot target = maps.get(targetSlot);
255                         Slot target2 = maps.get(targetSlot + 1);
256                         if (!target.map.intersects(source.map) && !target2.map.intersects(source2.map)) {
257                             // Double re-mapping, merge of the liveness maps and shift of the following slots by 2 left
258                             target.map.or(source.map);
259                             target2.map.or(source2.map);
260                             maps.remove(sourceSlot + 1);
261                             maps.remove(sourceSlot);
262                             for (int slot = 0; slot < slotMap.length; slot++) {
263                                 if (slotMap[slot] == sourceSlot) {
264                                     slotMap[slot] = targetSlot;
265                                 } else if (slotMap[slot] == sourceSlot + 1) {
266                                     slotMap[slot] = targetSlot + 1;
267                                 } else if (slotMap[slot] > sourceSlot + 1) {
268                                     slotMap[slot] -= 2;
269                                 }
270                             }
271                         }
272                     }
273                 }
274             }
275         }
276     }
277 
278     private Slot markRange(int slot, int from, int to) {
279         Slot s = getMap(slot);
280         s.map.set(from, Math.max(from + 1, to));
281         return s;
282     }
283 
284     private Slot getMap(int slot) {
285         while (slot >= maps.size()) {
286             maps.add(new Slot());
287         }
288         return maps.get(slot);
289     }
290 
291     private Slot loadSingle(int pc, int slot) {
292         Slot s =  getMap(slot);
293         int start = s.map.nextSetBit(0) + 1;
294         s.map.set(start, pc + 1);
295         return s;
296     }
297 
298     private void load(int pc, int slot, TypeKind tk) {
299         load(pc, slot, tk.slotSize() == 2);
300     }
301 
302     private void load(int pc, int slot, boolean dual) {
303         if (dual) {
304             loadSingle(pc, slot).flags |= 1;
305             loadSingle(pc, slot + 1).flags |= 2;
306         } else {
307             loadSingle(pc, slot);
308         }
309     }
310 
311     private void mergeFrom(int pc, Label target) {
312         int slot = 0;
313         for (var vti : frames.get(target)) {
314             if (vti != TOP) {
315                 if (vti == LONG || vti == DOUBLE) {
316                     load(pc, slot++, true);
317                 } else {
318                     loadSingle(pc, slot);
319                 }
320             }
321             slot++;
322         }
323     }
324 
325     private Slot storeSingle(int pc, int slot) {
326         Slot s = getMap(slot);
327         s.map.set(pc);
328         return s;
329     }
330 
331     private void store(int pc, int slot, TypeKind tk) {
332         if (tk.slotSize() == 2) {
333             storeSingle(pc, slot).flags |= 1;
334             storeSingle(pc, slot + 1).flags |= 2;
335         } else {
336             storeSingle(pc, slot);
337         }
338     }
339 }