1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package jdk.incubator.code.bytecode;
 26 
import java.lang.classfile.Attributes;
import java.lang.classfile.ClassTransform;
import java.lang.classfile.CodeBuilder;
import java.lang.classfile.CodeElement;
import java.lang.classfile.CodeModel;
import java.lang.classfile.CodeTransform;
import java.lang.classfile.Label;
import java.lang.classfile.MethodModel;
import java.lang.classfile.TypeKind;
import java.lang.classfile.attribute.StackMapFrameInfo;
import java.lang.classfile.instruction.BranchInstruction;
import java.lang.classfile.instruction.ExceptionCatch;
import java.lang.classfile.instruction.IncrementInstruction;
import java.lang.classfile.instruction.LabelTarget;
import java.lang.classfile.instruction.LoadInstruction;
import java.lang.classfile.instruction.LookupSwitchInstruction;
import java.lang.classfile.instruction.StoreInstruction;
import java.lang.classfile.instruction.TableSwitchInstruction;
import java.lang.constant.ClassDesc;
import java.lang.reflect.AccessFlag;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

import static java.lang.classfile.attribute.StackMapFrameInfo.SimpleVerificationTypeInfo.*;
import static java.lang.constant.ConstantDescs.CD_double;
import static java.lang.constant.ConstantDescs.CD_long;
 56 
 57 /**
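 * LocalsCompactor transforms a class to reduce the number of local variable slots (max_locals)
 * allocated by each Code attribute.
 * It collects a liveness map for every local slot, merges slots whose liveness maps do not overlap,
 * and rewrites the Code attribute to use the compacted slot numbers.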
 60  * <p>
 * Example of maps before compaction (max_locals = 14):
 62  * <pre>
 63  *  slots:  0   1   2   3   4   5   6   7   8   9   10  11  12  13
 64  *  ---------------------------------------------------------------
 65  *  bci 0:  *   *
 66  *      8:      *   *   *
 67  *     10:      *   *   *
 68  *     15:      *   *   *   *   *
 69  *     17:      *   *   *   *   *
 70  *     18:      *           *   *
 71  *     25:      *                   *   *
 72  *     27:      *                   *   *
 73  *     32:      *                   *   *   *   *
 74  *     34:      *                   *   *   *   *
 75  *     36:      *                           *   *
 76  *     43:      *                                   *   *
 77  *     45:      *                                   *   *
 78  *     50:                                          *   *   *   *
 79  *     52:                                          *   *   *   *
 80  *     54:                                                  *   *
 81  * </pre>
 * Compact form of the same maps (max_locals = 6); the numbers below each slot list the original
 * slots merged into it:
 83  * <pre>
 84  *  slots:   0   1   2   3   4   5
 85  *         +12 +13  +6  +7  +8  +9
 86  *                 +10 +11
 87  *  -------------------------------
 88  *  bci 0:  *   *
 89  *      8:      *   *   *
 90  *     10:      *   *   *
 91  *     15:      *   *   *   *   *
 92  *     17:      *   *   *   *   *
 93  *     18:      *           *   *
 94  *     25:      *   *   *
 95  *     27:      *   *   *
 96  *     32:      *   *   *   *   *
 97  *     34:      *   *   *   *   *
 98  *     36:      *           *   *
 99  *     43:      *   *   *
100  *     45:      *   *   *
101  *     50:  *   *   *   *
102  *     52:  *   *   *   *
103  *     54:  *   *
104  * </pre>
105  */
106 public final class LocalsCompactor {
107 
108     static class ExceptionTableCompactor implements CodeTransform {
109         ExceptionCatch last = null;
110 
111         @Override
112         public void accept(CodeBuilder cob, CodeElement coe) {
113             if (coe instanceof ExceptionCatch ec) {
114                 if (ec.tryStart() != ec.tryEnd()) {
115                     if (last != null) {
116                         if (last.handler() == ec.handler() && last.catchType().equals(ec.catchType())) {
117                             if (last.tryStart() == ec.tryEnd()) {
118                                 last = ExceptionCatch.of(last.handler(), ec.tryStart(), last.tryEnd(), last.catchType());
119                                 return;
120                             } else if (last.tryEnd() == ec.tryStart()) {
121                                 last = ExceptionCatch.of(last.handler(), last.tryStart(), ec.tryEnd(), last.catchType());
122                                 return;
123                             }
124                         }
125                         cob.with(last);
126                     }
127                     last = ec;
128                 }
129             } else {
130                 cob.with(coe);
131             }
132         }
133 
134         @Override
135         public void atEnd(CodeBuilder cob) {
136             if (last != null) {
137                 cob.with(last);
138                 last = null;
139             }
140         }
141     }
142 
143     public static final ClassTransform INSTANCE = (clb,cle) -> {
144         if (cle instanceof MethodModel mm) {
145             clb.transformMethod(mm, (mb, me) -> {
146                 if (me instanceof CodeModel com) {
147                     int[] slotMap = new LocalsCompactor(com, countParamSlots(mm)).slotMap;
148                     // @@@ ExceptionTableCompactor can be chained on ClassTransform level when the recent Class-File API is merged into code-reflection
149                     mb.transformCode(com, new ExceptionTableCompactor().andThen((cob, coe) -> {
150                         switch (coe) {
151                             case LoadInstruction li ->
152                                 cob.loadLocal(li.typeKind(), slotMap[li.slot()]);
153                             case StoreInstruction si ->
154                                 cob.storeLocal(si.typeKind(), slotMap[si.slot()]);
155                             case IncrementInstruction ii ->
156                                 cob.iinc(slotMap[ii.slot()], ii.constant());
157                             default ->
158                                 cob.with(coe);
159                         }
160                     }));
161                 } else {
162                     mb.with(me);
163                 }
164             });
165         } else {
166             clb.with(cle);
167         }
168     };
169 
170     private static int countParamSlots(MethodModel mm) {
171         int slots = mm.flags().has(AccessFlag.STATIC) ? 0 : 1;
172         for (ClassDesc p : mm.methodTypeSymbol().parameterList()) {
173             slots += p == CD_long || p == CD_double ? 2 : 1;
174         }
175         return slots;
176     }
177 
178     static final class Slot {
179         final BitSet map = new BitSet(); // Liveness map of the slot
180         int flags; // 0 - single slot, 1 - first of double slots, 2 - second of double slots, 3 - mixed
181     }
182 
183     private final List<Slot> maps; // Intermediate slots liveness maps
184     private final Map<Label, List<StackMapFrameInfo.VerificationTypeInfo>> frames;
185     private final int[] slotMap; // Output mapping of the slots
186 
187     private LocalsCompactor(CodeModel com, int fixedSlots) {
188         frames = com.findAttribute(Attributes.stackMapTable()).map(
189                 smta -> smta.entries().stream().collect(
190                         Collectors.toMap(StackMapFrameInfo::target, StackMapFrameInfo::locals)))
191                 .orElse(Map.of());
192         var exceptionHandlers = com.exceptionHandlers();
193         maps = new ArrayList<>();
194         int pc = 0;
195         // Initialization of fixed slots
196         for (int slot = 0; slot < fixedSlots; slot++) {
197             getMap(slot).map.set(0);
198         }
199         // Filling the slots liveness maps
200         for (var e : com) {
201             switch(e) {
202                 case LabelTarget lt -> {
203                     for (var eh : exceptionHandlers) {
204                         if (eh.tryStart() == lt.label()) {
205                             mergeFrom(pc, eh.handler());
206                         }
207                     }
208                 }
209                 case LoadInstruction li ->
210                     load(pc, li.slot(), li.typeKind());
211                 case StoreInstruction si ->
212                     store(pc, si.slot(), si.typeKind());
213                 case IncrementInstruction ii ->
214                     loadSingle(pc, ii.slot());
215                 case BranchInstruction bi ->
216                     mergeFrom(pc, bi.target());
217                 case LookupSwitchInstruction si -> {
218                     mergeFrom(pc, si.defaultTarget());
219                     for (var sc : si.cases()) {
220                         mergeFrom(pc, sc.target());
221                     }
222                 }
223                 case TableSwitchInstruction si -> {
224                     mergeFrom(pc, si.defaultTarget());
225                     for (var sc : si.cases()) {
226                         mergeFrom(pc, sc.target());
227                     }
228                 }
229                 default -> pc--;
230             }
231             pc++;
232         }
233         // Initialization of slots mapping
234         slotMap = new int[maps.size()];
235         for (int slot = 0; slot < slotMap.length; slot++) {
236             slotMap[slot] = slot;
237         }
238         // Iterative compaction of slots
239         for (int targetSlot = 0; targetSlot < maps.size() - 1; targetSlot++) {
240             for (int sourceSlot = Math.max(targetSlot + 1, fixedSlots); sourceSlot < maps.size(); sourceSlot++) {
241                 Slot source = maps.get(sourceSlot);
242                 // Re-mapping single slot
243                 if (source.flags == 0) {
244                     Slot target = maps.get(targetSlot);
245                     if (!target.map.intersects(source.map)) {
246                         // Single re-mapping, merge of the liveness maps and shift of the following slots by 1 left
247                         target.map.or(source.map);
248                         maps.remove(sourceSlot);
249                         for (int slot = 0; slot < slotMap.length; slot++) {
250                             if (slotMap[slot] == sourceSlot) {
251                                 slotMap[slot] = targetSlot;
252                             } else if (slotMap[slot] > sourceSlot) {
253                                 slotMap[slot]--;
254                             }
255                         }
256                     }
257                 } else if (source.flags == 1 && sourceSlot > targetSlot + 1) {
258                     Slot source2 = maps.get(sourceSlot + 1);
259                     // Re-mapping distinct double slot
260                     if (source2.flags == 2) {
261                         Slot target = maps.get(targetSlot);
262                         Slot target2 = maps.get(targetSlot + 1);
263                         if (!target.map.intersects(source.map) && !target2.map.intersects(source2.map)) {
264                             // Double re-mapping, merge of the liveness maps and shift of the following slots by 2 left
265                             target.map.or(source.map);
266                             target2.map.or(source2.map);
267                             maps.remove(sourceSlot + 1);
268                             maps.remove(sourceSlot);
269                             for (int slot = 0; slot < slotMap.length; slot++) {
270                                 if (slotMap[slot] == sourceSlot) {
271                                     slotMap[slot] = targetSlot;
272                                 } else if (slotMap[slot] == sourceSlot + 1) {
273                                     slotMap[slot] = targetSlot + 1;
274                                 } else if (slotMap[slot] > sourceSlot + 1) {
275                                     slotMap[slot] -= 2;
276                                 }
277                             }
278                         }
279                     }
280                 }
281             }
282         }
283     }
284 
285     private Slot getMap(int slot) {
286         while (slot >= maps.size()) {
287             maps.add(new Slot());
288         }
289         return maps.get(slot);
290     }
291 
292     private Slot loadSingle(int pc, int slot) {
293         Slot s =  getMap(slot);
294         int start = s.map.nextSetBit(0) + 1;
295         s.map.set(start, pc + 1);
296         return s;
297     }
298 
299     private void load(int pc, int slot, TypeKind tk) {
300         load(pc, slot, tk.slotSize() == 2);
301     }
302 
303     private void load(int pc, int slot, boolean dual) {
304         if (dual) {
305             loadSingle(pc, slot).flags |= 1;
306             loadSingle(pc, slot + 1).flags |= 2;
307         } else {
308             loadSingle(pc, slot);
309         }
310     }
311 
312     private void mergeFrom(int pc, Label target) {
313         int slot = 0;
314         for (var vti : frames.get(target)) {
315             if (vti != TOP) {
316                 if (vti == LONG || vti == DOUBLE) {
317                     load(pc, slot++, true);
318                 } else {
319                     loadSingle(pc, slot);
320                 }
321             }
322             slot++;
323         }
324     }
325 
326     private Slot storeSingle(int pc, int slot) {
327         Slot s = getMap(slot);
328         s.map.set(pc);
329         return s;
330     }
331 
332     private void store(int pc, int slot, TypeKind tk) {
333         if (tk.slotSize() == 2) {
334             storeSingle(pc, slot).flags |= 1;
335             storeSingle(pc, slot + 1).flags |= 2;
336         } else {
337             storeSingle(pc, slot);
338         }
339     }
340 }