1 /*
  2  * Copyright (c) 2024, 2025, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package jdk.incubator.code.bytecode.impl;
 26 
import java.lang.classfile.Attributes;
import java.lang.classfile.ClassFile;
import java.lang.classfile.CodeModel;
import java.lang.classfile.Label;
import java.lang.classfile.MethodModel;
import java.lang.classfile.TypeKind;
import java.lang.classfile.attribute.StackMapFrameInfo;
import java.lang.classfile.instruction.BranchInstruction;
import java.lang.classfile.instruction.IncrementInstruction;
import java.lang.classfile.instruction.LabelTarget;
import java.lang.classfile.instruction.LoadInstruction;
import java.lang.classfile.instruction.LocalVariable;
import java.lang.classfile.instruction.LocalVariableType;
import java.lang.classfile.instruction.LookupSwitchInstruction;
import java.lang.classfile.instruction.StoreInstruction;
import java.lang.classfile.instruction.TableSwitchInstruction;
import java.lang.constant.ClassDesc;
import java.lang.reflect.AccessFlag;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import static java.lang.classfile.attribute.StackMapFrameInfo.SimpleVerificationTypeInfo.*;
import static java.lang.constant.ConstantDescs.CD_double;
import static java.lang.constant.ConstantDescs.CD_long;
 52 
 53 /**
 54  * LocalsCompactor transforms class to reduce allocation of local slots in the Code attribute (max_locals).
 55  * It collects slot maps, compacts them and transforms the Code attribute accordingly.
 56  * <p>
 57  * Example of maps before compaction (max_locals = 13):
 58  * <pre>
 59  *  slots:  0   1   2   3   4   5   6   7   8   9   10  11  12  13
 60  *  ---------------------------------------------------------------
 61  *  bci 0:  *   *
 62  *      8:      *   *   *
 63  *     10:      *   *   *
 64  *     15:      *   *   *   *   *
 65  *     17:      *   *   *   *   *
 66  *     18:      *           *   *
 67  *     25:      *                   *   *
 68  *     27:      *                   *   *
 69  *     32:      *                   *   *   *   *
 70  *     34:      *                   *   *   *   *
 71  *     36:      *                           *   *
 72  *     43:      *                                   *   *
 73  *     45:      *                                   *   *
 74  *     50:                                          *   *   *   *
 75  *     52:                                          *   *   *   *
 76  *     54:                                                  *   *
 77  * </pre>
 78  * Compact form of the same maps (max_locals = 5):
 79  * <pre>
 80  *  slots:   0   1   2   3   4   5
 81  *         +12 +13  +6  +7  +8  +9
 82  *                 +10 +11
 83  *  -------------------------------
 84  *  bci 0:  *   *
 85  *      8:      *   *   *
 86  *     10:      *   *   *
 87  *     15:      *   *   *   *   *
 88  *     17:      *   *   *   *   *
 89  *     18:      *           *   *
 90  *     25:      *   *   *
 91  *     27:      *   *   *
 92  *     32:      *   *   *   *   *
 93  *     34:      *   *   *   *   *
 94  *     36:      *           *   *
 95  *     43:      *   *   *
 96  *     45:      *   *   *
 97  *     50:  *   *   *   *
 98  *     52:  *   *   *   *
 99  *     54:  *   *
100  * </pre>
101  */
102 public final class LocalsCompactor {
103 
104     /**
105      * LocalsCompactor transformation requires complete class file binary
106      * @param classBytes class file binary to transform
107      * @return transformed class file binary
108      */
109     public static byte[] transform(byte[] classBytes) {
110         return ClassFile.of().transformClass(ClassFile.of().parse(classBytes), (clb,cle) -> {
111             if (cle instanceof MethodModel mm) {
112                 clb.transformMethod(mm, (mb, me) -> {
113                     if (me instanceof CodeModel com) {
114                         int[] slotMap = new LocalsCompactor(com, countParamSlots(mm)).slotMap;
115                         mb.transformCode(com, (cob, coe) -> {
116                             switch (coe) {
117                                 case LoadInstruction li ->
118                                     cob.loadLocal(li.typeKind(), slotMap[li.slot()]);
119                                 case StoreInstruction si ->
120                                     cob.storeLocal(si.typeKind(), slotMap[si.slot()]);
121                                 case IncrementInstruction ii ->
122                                     cob.iinc(slotMap[ii.slot()], ii.constant());
123                                 default ->
124                                     cob.with(coe);
125                             }
126                         });
127                     } else {
128                         mb.with(me);
129                     }
130                 });
131             } else {
132                 clb.with(cle);
133             }
134         });
135     }
136 
137     private static int countParamSlots(MethodModel mm) {
138         int slots = mm.flags().has(AccessFlag.STATIC) ? 0 : 1;
139         for (ClassDesc p : mm.methodTypeSymbol().parameterList()) {
140             slots += p == CD_long || p == CD_double ? 2 : 1;
141         }
142         return slots;
143     }
144 
145     static final class Slot {
146         final BitSet map = new BitSet(); // Liveness map of the slot
147         int flags; // 0 - single slot, 1 - first of double slots, 2 - second of double slots, 3 - mixed
148     }
149 
150     private final List<Slot> maps; // Intermediate slots liveness maps
151     private final Map<Label, List<StackMapFrameInfo.VerificationTypeInfo>> frames;
152     private final int[] slotMap; // Output mapping of the slots
153 
154     private LocalsCompactor(CodeModel com, int fixedSlots) {
155         frames = com.findAttribute(Attributes.stackMapTable()).map(
156                 smta -> smta.entries().stream().collect(
157                         Collectors.toMap(StackMapFrameInfo::target, StackMapFrameInfo::locals)))
158                 .orElse(Map.of());
159         var exceptionHandlers = com.exceptionHandlers();
160         maps = new ArrayList<>();
161         int pc = 0;
162         // Initialization of fixed slots
163         for (int slot = 0; slot < fixedSlots; slot++) {
164             getMap(slot).map.set(0);
165         }
166         // Filling the slots liveness maps
167         for (var e : com) {
168             switch(e) {
169                 case LabelTarget lt -> {
170                     for (var eh : exceptionHandlers) {
171                         if (eh.tryStart() == lt.label()) {
172                             mergeFrom(pc, eh.handler());
173                         }
174                     }
175                 }
176                 case LoadInstruction li ->
177                     load(pc, li.slot(), li.typeKind());
178                 case StoreInstruction si ->
179                     store(pc, si.slot(), si.typeKind());
180                 case IncrementInstruction ii ->
181                     loadSingle(pc, ii.slot());
182                 case BranchInstruction bi ->
183                     mergeFrom(pc, bi.target());
184                 case LookupSwitchInstruction si -> {
185                     mergeFrom(pc, si.defaultTarget());
186                     for (var sc : si.cases()) {
187                         mergeFrom(pc, sc.target());
188                     }
189                 }
190                 case TableSwitchInstruction si -> {
191                     mergeFrom(pc, si.defaultTarget());
192                     for (var sc : si.cases()) {
193                         mergeFrom(pc, sc.target());
194                     }
195                 }
196                 default -> pc--;
197             }
198             pc++;
199         }
200         // Initialization of slots mapping
201         slotMap = new int[maps.size()];
202         for (int slot = 0; slot < slotMap.length; slot++) {
203             slotMap[slot] = slot;
204         }
205         // Iterative compaction of slots
206         for (int targetSlot = 0; targetSlot < maps.size() - 1; targetSlot++) {
207             for (int sourceSlot = Math.max(targetSlot + 1, fixedSlots); sourceSlot < maps.size(); sourceSlot++) {
208                 Slot source = maps.get(sourceSlot);
209                 // Re-mapping single slot
210                 if (source.flags == 0) {
211                     Slot target = maps.get(targetSlot);
212                     if (!target.map.intersects(source.map)) {
213                         // Single re-mapping, merge of the liveness maps and shift of the following slots by 1 left
214                         target.map.or(source.map);
215                         maps.remove(sourceSlot);
216                         for (int slot = 0; slot < slotMap.length; slot++) {
217                             if (slotMap[slot] == sourceSlot) {
218                                 slotMap[slot] = targetSlot;
219                             } else if (slotMap[slot] > sourceSlot) {
220                                 slotMap[slot]--;
221                             }
222                         }
223                     }
224                 } else if (source.flags == 1 && sourceSlot > targetSlot + 1) {
225                     Slot source2 = maps.get(sourceSlot + 1);
226                     // Re-mapping distinct double slot
227                     if (source2.flags == 2) {
228                         Slot target = maps.get(targetSlot);
229                         Slot target2 = maps.get(targetSlot + 1);
230                         if (!target.map.intersects(source.map) && !target2.map.intersects(source2.map)) {
231                             // Double re-mapping, merge of the liveness maps and shift of the following slots by 2 left
232                             target.map.or(source.map);
233                             target2.map.or(source2.map);
234                             maps.remove(sourceSlot + 1);
235                             maps.remove(sourceSlot);
236                             for (int slot = 0; slot < slotMap.length; slot++) {
237                                 if (slotMap[slot] == sourceSlot) {
238                                     slotMap[slot] = targetSlot;
239                                 } else if (slotMap[slot] == sourceSlot + 1) {
240                                     slotMap[slot] = targetSlot + 1;
241                                 } else if (slotMap[slot] > sourceSlot + 1) {
242                                     slotMap[slot] -= 2;
243                                 }
244                             }
245                         }
246                     }
247                 }
248             }
249         }
250     }
251 
252     private Slot getMap(int slot) {
253         while (slot >= maps.size()) {
254             maps.add(new Slot());
255         }
256         return maps.get(slot);
257     }
258 
259     private Slot loadSingle(int pc, int slot) {
260         Slot s =  getMap(slot);
261         int start = s.map.nextSetBit(0) + 1;
262         s.map.set(start, pc + 1);
263         return s;
264     }
265 
266     private void load(int pc, int slot, TypeKind tk) {
267         load(pc, slot, tk.slotSize() == 2);
268     }
269 
270     private void load(int pc, int slot, boolean dual) {
271         if (dual) {
272             loadSingle(pc, slot).flags |= 1;
273             loadSingle(pc, slot + 1).flags |= 2;
274         } else {
275             loadSingle(pc, slot);
276         }
277     }
278 
279     private void mergeFrom(int pc, Label target) {
280         int slot = 0;
281         for (var vti : frames.get(target)) {
282             if (vti != TOP) {
283                 if (vti == LONG || vti == DOUBLE) {
284                     load(pc, slot++, true);
285                 } else {
286                     loadSingle(pc, slot);
287                 }
288             }
289             slot++;
290         }
291     }
292 
293     private Slot storeSingle(int pc, int slot) {
294         Slot s = getMap(slot);
295         s.map.set(pc);
296         return s;
297     }
298 
299     private void store(int pc, int slot, TypeKind tk) {
300         if (tk.slotSize() == 2) {
301             storeSingle(pc, slot).flags |= 1;
302             storeSingle(pc, slot + 1).flags |= 2;
303         } else {
304             storeSingle(pc, slot);
305         }
306     }
307 }