1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package java.lang.reflect.code.bytecode;
 26 
import java.lang.classfile.Attributes;
import java.lang.classfile.ClassTransform;
import java.lang.classfile.CodeModel;
import java.lang.classfile.Label;
import java.lang.classfile.MethodModel;
import java.lang.classfile.TypeKind;
import java.lang.classfile.attribute.StackMapFrameInfo;
import java.lang.classfile.instruction.BranchInstruction;
import java.lang.classfile.instruction.IncrementInstruction;
import java.lang.classfile.instruction.LabelTarget;
import java.lang.classfile.instruction.LoadInstruction;
import java.lang.classfile.instruction.LocalVariable;
import java.lang.classfile.instruction.LocalVariableType;
import java.lang.classfile.instruction.LookupSwitchInstruction;
import java.lang.classfile.instruction.StoreInstruction;
import java.lang.classfile.instruction.TableSwitchInstruction;
import java.lang.constant.ClassDesc;
import java.lang.reflect.AccessFlag;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import static java.lang.classfile.attribute.StackMapFrameInfo.SimpleVerificationTypeInfo.*;
import static java.lang.constant.ConstantDescs.CD_double;
import static java.lang.constant.ConstantDescs.CD_long;
 52 
 53 /**
 54  * LocalsCompactor transforms class to reduce allocation of local slots in the Code attribute (max_locals).
 55  * It collects slot maps, compacts them and transforms the Code attribute accordingly.
 56  * <p>
 57  * Example of maps before compaction (max_locals = 13):
 58  * <pre>
 59  *  slots:  0   1   2   3   4   5   6   7   8   9   10  11  12  13
 60  *  ---------------------------------------------------------------
 61  *  bci 0:  *   *
 62  *      8:      *   *   *
 63  *     10:      *   *   *
 64  *     15:      *   *   *   *   *
 65  *     17:      *   *   *   *   *
 66  *     18:      *           *   *
 67  *     25:      *                   *   *
 68  *     27:      *                   *   *
 69  *     32:      *                   *   *   *   *
 70  *     34:      *                   *   *   *   *
 71  *     36:      *                           *   *
 72  *     43:      *                                   *   *
 73  *     45:      *                                   *   *
 74  *     50:                                          *   *   *   *
 75  *     52:                                          *   *   *   *
 76  *     54:                                                  *   *
 77  * </pre>
 78  * Compact form of the same maps (max_locals = 5):
 79  * <pre>
 80  *  slots:   0   1   2   3   4   5
 81  *         +12 +13  +6  +7  +8  +9
 82  *                 +10 +11
 83  *  -------------------------------
 84  *  bci 0:  *   *
 85  *      8:      *   *   *
 86  *     10:      *   *   *
 87  *     15:      *   *   *   *   *
 88  *     17:      *   *   *   *   *
 89  *     18:      *           *   *
 90  *     25:      *   *   *
 91  *     27:      *   *   *
 92  *     32:      *   *   *   *   *
 93  *     34:      *   *   *   *   *
 94  *     36:      *           *   *
 95  *     43:      *   *   *
 96  *     45:      *   *   *
 97  *     50:  *   *   *   *
 98  *     52:  *   *   *   *
 99  *     54:  *   *
100  * </pre>
101  */
102 public final class LocalsCompactor {
103 
104     public static final ClassTransform INSTANCE = (clb,cle) -> {
105         if (cle instanceof MethodModel mm) {
106             clb.transformMethod(mm, (mb, me) -> {
107                 if (me instanceof CodeModel com) {
108                     int[] slotMap = new LocalsCompactor(com, countParamSlots(mm)).slotMap;
109                     mb.transformCode(com, (cob, coe) -> {
110                         switch (coe) {
111                             case LoadInstruction li ->
112                                 cob.loadLocal(li.typeKind(), slotMap[li.slot()]);
113                             case StoreInstruction si ->
114                                 cob.storeLocal(si.typeKind(), slotMap[si.slot()]);
115                             case IncrementInstruction ii ->
116                                 cob.iinc(slotMap[ii.slot()], ii.constant());
117                             default ->
118                                 cob.with(coe);
119                         }
120                     });
121                 } else {
122                     mb.with(me);
123                 }
124             });
125         } else {
126             clb.with(cle);
127         }
128     };
129 
130     private static int countParamSlots(MethodModel mm) {
131         int slots = mm.flags().has(AccessFlag.STATIC) ? 0 : 1;
132         for (ClassDesc p : mm.methodTypeSymbol().parameterList()) {
133             slots += p == CD_long || p == CD_double ? 2 : 1;
134         }
135         return slots;
136     }
137 
138     static final class Slot {
139         final BitSet map = new BitSet(); // Liveness map of the slot
140         int flags; // 0 - single slot, 1 - first of double slots, 2 - second of double slots, 3 - mixed
141     }
142 
143     private final List<Slot> maps; // Intermediate slots liveness maps
144     private final Map<Label, List<StackMapFrameInfo.VerificationTypeInfo>> frames;
145     private final int[] slotMap; // Output mapping of the slots
146 
147     private LocalsCompactor(CodeModel com, int fixedSlots) {
148         frames = com.findAttribute(Attributes.stackMapTable()).map(
149                 smta -> smta.entries().stream().collect(
150                         Collectors.toMap(StackMapFrameInfo::target, StackMapFrameInfo::locals)))
151                 .orElse(Map.of());
152         var exceptionHandlers = com.exceptionHandlers();
153         maps = new ArrayList<>();
154         int pc = 0;
155         // Initialization of fixed slots
156         for (int slot = 0; slot < fixedSlots; slot++) {
157             getMap(slot).map.set(0);
158         }
159         // Filling the slots liveness maps
160         for (var e : com) {
161             switch(e) {
162                 case LabelTarget lt -> {
163                     for (var eh : exceptionHandlers) {
164                         if (eh.tryStart() == lt.label()) {
165                             mergeFrom(pc, eh.handler());
166                         }
167                     }
168                 }
169                 case LoadInstruction li ->
170                     load(pc, li.slot(), li.typeKind());
171                 case StoreInstruction si ->
172                     store(pc, si.slot(), si.typeKind());
173                 case IncrementInstruction ii ->
174                     loadSingle(pc, ii.slot());
175                 case BranchInstruction bi ->
176                     mergeFrom(pc, bi.target());
177                 case LookupSwitchInstruction si -> {
178                     mergeFrom(pc, si.defaultTarget());
179                     for (var sc : si.cases()) {
180                         mergeFrom(pc, sc.target());
181                     }
182                 }
183                 case TableSwitchInstruction si -> {
184                     mergeFrom(pc, si.defaultTarget());
185                     for (var sc : si.cases()) {
186                         mergeFrom(pc, sc.target());
187                     }
188                 }
189                 default -> pc--;
190             }
191             pc++;
192         }
193         // Initialization of slots mapping
194         slotMap = new int[maps.size()];
195         for (int slot = 0; slot < slotMap.length; slot++) {
196             slotMap[slot] = slot;
197         }
198         // Iterative compaction of slots
199         for (int targetSlot = 0; targetSlot < maps.size() - 1; targetSlot++) {
200             for (int sourceSlot = Math.max(targetSlot + 1, fixedSlots); sourceSlot < maps.size(); sourceSlot++) {
201                 Slot source = maps.get(sourceSlot);
202                 // Re-mapping single slot
203                 if (source.flags == 0) {
204                     Slot target = maps.get(targetSlot);
205                     if (!target.map.intersects(source.map)) {
206                         // Single re-mapping, merge of the liveness maps and shift of the following slots by 1 left
207                         target.map.or(source.map);
208                         maps.remove(sourceSlot);
209                         for (int slot = 0; slot < slotMap.length; slot++) {
210                             if (slotMap[slot] == sourceSlot) {
211                                 slotMap[slot] = targetSlot;
212                             } else if (slotMap[slot] > sourceSlot) {
213                                 slotMap[slot]--;
214                             }
215                         }
216                     }
217                 } else if (source.flags == 1 && sourceSlot > targetSlot + 1) {
218                     Slot source2 = maps.get(sourceSlot + 1);
219                     // Re-mapping distinct double slot
220                     if (source2.flags == 2) {
221                         Slot target = maps.get(targetSlot);
222                         Slot target2 = maps.get(targetSlot + 1);
223                         if (!target.map.intersects(source.map) && !target2.map.intersects(source2.map)) {
224                             // Double re-mapping, merge of the liveness maps and shift of the following slots by 2 left
225                             target.map.or(source.map);
226                             target2.map.or(source2.map);
227                             maps.remove(sourceSlot + 1);
228                             maps.remove(sourceSlot);
229                             for (int slot = 0; slot < slotMap.length; slot++) {
230                                 if (slotMap[slot] == sourceSlot) {
231                                     slotMap[slot] = targetSlot;
232                                 } else if (slotMap[slot] == sourceSlot + 1) {
233                                     slotMap[slot] = targetSlot + 1;
234                                 } else if (slotMap[slot] > sourceSlot + 1) {
235                                     slotMap[slot] -= 2;
236                                 }
237                             }
238                         }
239                     }
240                 }
241             }
242         }
243     }
244 
245     private Slot getMap(int slot) {
246         while (slot >= maps.size()) {
247             maps.add(new Slot());
248         }
249         return maps.get(slot);
250     }
251 
252     private Slot loadSingle(int pc, int slot) {
253         Slot s =  getMap(slot);
254         int start = s.map.nextSetBit(0) + 1;
255         s.map.set(start, pc + 1);
256         return s;
257     }
258 
259     private void load(int pc, int slot, TypeKind tk) {
260         load(pc, slot, tk.slotSize() == 2);
261     }
262 
263     private void load(int pc, int slot, boolean dual) {
264         if (dual) {
265             loadSingle(pc, slot).flags |= 1;
266             loadSingle(pc, slot + 1).flags |= 2;
267         } else {
268             loadSingle(pc, slot);
269         }
270     }
271 
272     private void mergeFrom(int pc, Label target) {
273         int slot = 0;
274         for (var vti : frames.get(target)) {
275             if (vti != ITEM_TOP) {
276                 if (vti == ITEM_LONG || vti == ITEM_DOUBLE) {
277                     load(pc, slot++, true);
278                 } else {
279                     loadSingle(pc, slot);
280                 }
281             }
282             slot++;
283         }
284     }
285 
286     private Slot storeSingle(int pc, int slot) {
287         Slot s = getMap(slot);
288         s.map.set(pc);
289         return s;
290     }
291 
292     private void store(int pc, int slot, TypeKind tk) {
293         if (tk.slotSize() == 2) {
294             storeSingle(pc, slot).flags |= 1;
295             storeSingle(pc, slot + 1).flags |= 2;
296         } else {
297             storeSingle(pc, slot);
298         }
299     }
300 }