1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package java.lang.reflect.code.bytecode;
 26 
 27 import java.lang.classfile.CodeElement;
 28 import java.lang.classfile.Instruction;
 29 import java.lang.classfile.Label;
 30 import java.lang.classfile.Opcode;
 31 import java.lang.classfile.TypeKind;
 32 import java.lang.classfile.attribute.StackMapFrameInfo;
 33 import java.lang.classfile.attribute.StackMapFrameInfo.*;
 34 import java.lang.classfile.attribute.StackMapTableAttribute;
 35 import java.lang.classfile.instruction.*;
 36 import java.lang.constant.ClassDesc;
 37 import java.lang.constant.ConstantDescs;
 38 import java.lang.constant.DirectMethodHandleDesc;
 39 import java.lang.constant.DynamicConstantDesc;
 40 import java.lang.constant.MethodTypeDesc;
 41 
 42 import java.util.ArrayList;
 43 import java.util.HashMap;
 44 import java.util.List;
 45 import java.util.Map;
 46 import java.util.Optional;
 47 import java.util.stream.Collectors;
 48 
 49 import static java.lang.classfile.attribute.StackMapFrameInfo.SimpleVerificationTypeInfo.*;
 50 import static java.lang.constant.ConstantDescs.*;
 51 import java.lang.reflect.code.Value;
 52 import java.lang.reflect.code.type.JavaType;
 53 import java.util.ArrayDeque;
 54 import java.util.HashSet;
 55 import java.util.LinkedHashSet;
 56 import java.util.Set;
 57 
 58 final class LocalsTypeMapper {
 59 
 60     static class Variable {
 61         private ClassDesc type;
 62         boolean isSingleValue;
 63         Value value;
 64 
 65         JavaType type() {
 66             return JavaType.type(type);
 67         }
 68 
 69         Object defaultValue() {
 70             return switch (TypeKind.from(type)) {
 71                 case BooleanType -> false;
 72                 case ByteType -> (byte)0;
 73                 case CharType -> (char)0;
 74                 case DoubleType -> 0d;
 75                 case FloatType -> 0f;
 76                 case IntType -> 0;
 77                 case LongType -> 0l;
 78                 case ReferenceType -> null;
 79                 case ShortType -> (short)0;
 80                 default -> throw new IllegalStateException("Invalid type " + type.displayName());
 81             };
 82         }
 83     }
 84 
 85     static class Slot {
 86 
 87         record Link(Slot slot, Link other) {}
 88 
 89         ClassDesc type;
 90         Link up, down;
 91         Variable var;
 92         boolean newValue;
 93         Slot previous; // Previous Slot, not necessary of the same variable
 94 
 95         void link(Slot target) {
 96             if (this != target) {
 97                 target.up = new Link(this, target.up);
 98                 this.down = new Link(target, this.down);
 99             }
100         }
101     }
102 
103     record Frame(List<ClassDesc> stack, List<Slot> locals) {}
104 
105     private static final ClassDesc NULL_TYPE = ClassDesc.ofDescriptor(CD_Object.descriptorString());
106     private final Map<Integer, Slot> insMap;
107     private final LinkedHashSet<Slot> allSlots;
108     private final ClassDesc thisClass;
109     private final List<ExceptionCatch> exceptionHandlers;
110     private final List<ClassDesc> stack;
111     private final List<Slot> locals;
112     private final Map<Label, Frame> stackMap;
113     private final Map<Label, ClassDesc> newMap;
114     private boolean frameDirty;
115     final List<Slot> slotsToInitialize;
116 
117     LocalsTypeMapper(ClassDesc thisClass,
118                          List<ClassDesc> initFrameLocals,
119                          List<ExceptionCatch> exceptionHandlers,
120                          Optional<StackMapTableAttribute> stackMapTableAttribute,
121                          List<CodeElement> codeElements) {
122         this.insMap = new HashMap<>();
123         this.thisClass = thisClass;
124         this.exceptionHandlers = exceptionHandlers;
125         this.stack = new ArrayList<>();
126         this.locals = new ArrayList<>();
127         this.allSlots = new LinkedHashSet<>();
128         this.newMap = computeNewMap(codeElements);
129         this.slotsToInitialize = new ArrayList<>();
130         this.stackMap = stackMapTableAttribute.map(a -> a.entries().stream().collect(Collectors.toMap(
131                 StackMapFrameInfo::target,
132                 this::toFrame))).orElse(Map.of());
133         for (ClassDesc cd : initFrameLocals) {
134             slotsToInitialize.add(cd == null ? null : newSlot(cd, true));
135         }
136         int initSize = allSlots.size();
137         do {
138             // Slot states reset if running additional rounds with adjusted frames
139             if (allSlots.size() > initSize) {
140                 while (allSlots.size() > initSize) allSlots.removeLast();
141                 allSlots.forEach(sl -> {
142                     sl.up = null;
143                     sl.down = null;
144                     sl.previous = null;
145                     sl.var = null;
146                 });
147             }
148             for (int i = 0; i < initFrameLocals.size(); i++) {
149                 store(i, slotsToInitialize.get(i), locals);
150             }
151             this.frameDirty = false;
152             for (int i = 0; i < codeElements.size(); i++) {
153                 accept(i, codeElements.get(i));
154             }
155             endOfFlow();
156         } while (this.frameDirty);
157 
158         // Assign variable to slots, calculate var type, detect single value variables and dominant slot
159         ArrayDeque<Slot> q = new ArrayDeque<>();
160         Set<Slot> initialSlots = new HashSet<>();
161         for (Slot slot : allSlots) {
162             if (slot.var == null) {
163                 Variable var = new Variable();
164                 q.add(slot);
165                 int sources = 0;
166                 var.type = slot.type;
167                 while (!q.isEmpty()) {
168                     Slot sl = q.pop();
169                     if (sl.var == null) {
170                         if (sl.newValue) {
171                             sources++;
172                             if (sl.up == null) {
173                                 initialSlots.add(sl);
174                             }
175                         }
176                         sl.var = var;
177                         Slot.Link l = sl.up;
178                         while (l != null) {
179                             if (var.type == NULL_TYPE) var.type = l.slot.type;
180                             if (l.slot.var == null) q.add(l.slot);
181                             l = l.other;
182                         }
183                         l = sl.down;
184                         while (l != null) {
185                             if (var.type == NULL_TYPE) var.type = l.slot.type;
186                             if (l.slot.var == null) q.add(l.slot);
187                             l = l.other;
188                         }
189                     }
190                 }
191                 var.isSingleValue = sources < 2;
192 
193                 // Filter out slots, which are not initial (store into the same variable)
194                 for (var tsit = initialSlots.iterator(); tsit.hasNext();) {
195                     Slot sl = tsit.next();
196                     if (sl.previous != null && sl.previous.var == sl.var) {
197                         tsit.remove();
198                     }
199                 }
200                 if (initialSlots.size() > 1) {
201                     // Add synthetic dominant slot, which needs to be initialized with a default value
202                     Slot initialSlot = new Slot();
203                     initialSlot.var = var;
204                     slotsToInitialize.add(initialSlot);
205                     if (var.type == CD_long || var.type == CD_double) {
206                         slotsToInitialize.add(null);
207                     }
208                 }
209                 initialSlots.clear();
210             }
211         }
212     }
213 
214     private Frame toFrame(StackMapFrameInfo smfi) {
215         List<ClassDesc> fstack = new ArrayList<>(smfi.stack().size());
216         List<Slot> flocals = new ArrayList<>(smfi.locals().size() * 2);
217         for (var vti : smfi.stack()) {
218             fstack.add(vtiToStackType(vti));
219         }
220         int i = 0;
221         for (var vti : smfi.locals()) {
222             store(i, vtiToStackType(vti), flocals, false);
223             i += vti == ITEM_DOUBLE || vti == ITEM_LONG ? 2 : 1;
224         }
225         return new Frame(fstack, flocals);
226     }
227 
228     private static Map<Label, ClassDesc> computeNewMap(List<CodeElement> codeElements) {
229         Map<Label, ClassDesc> newMap = new HashMap<>();
230         Label lastLabel = null;
231         for (int i = 0; i < codeElements.size(); i++) {
232             switch (codeElements.get(i)) {
233                 case LabelTarget lt -> lastLabel = lt.label();
234                 case NewObjectInstruction newI -> {
235                     if (lastLabel != null) {
236                         newMap.put(lastLabel, newI.className().asSymbol());
237                     }
238                 }
239                 case Instruction _ -> lastLabel = null; //invalidate label
240                 default -> {} //skip
241             }
242         }
243         return newMap;
244     }
245 
246     Variable getVarOf(int li) {
247         return insMap.get(li).var;
248     }
249 
250     private Slot newSlot(ClassDesc type, boolean newValue) {
251         Slot s = new Slot();
252         s.type = type;
253         s.newValue = newValue;
254         allSlots.add(s);
255         return s;
256     }
257 
258     private ClassDesc vtiToStackType(StackMapFrameInfo.VerificationTypeInfo vti) {
259         return switch (vti) {
260             case ITEM_INTEGER -> CD_int;
261             case ITEM_FLOAT -> CD_float;
262             case ITEM_DOUBLE -> CD_double;
263             case ITEM_LONG -> CD_long;
264             case ITEM_UNINITIALIZED_THIS -> thisClass;
265             case ITEM_NULL -> NULL_TYPE;
266             case ObjectVerificationTypeInfo ovti -> ovti.classSymbol();
267             case UninitializedVerificationTypeInfo uvti ->
268                 newMap.computeIfAbsent(uvti.newTarget(), l -> {
269                     throw new IllegalArgumentException("Unitialized type does not point to a new instruction");
270                 });
271             case ITEM_TOP -> null;
272         };
273     }
274 
275     private void push(ClassDesc type) {
276         if (!ConstantDescs.CD_void.equals(type)) stack.add(type);
277     }
278 
279     private void pushAt(int pos, ClassDesc... types) {
280         for (var t : types)
281             if (!ConstantDescs.CD_void.equals(t))
282                 stack.add(stack.size() + pos, t);
283     }
284 
285     private boolean doubleAt(int pos) {
286         var t  = stack.get(stack.size() + pos);
287         return t.equals(CD_long) || t.equals(CD_double);
288     }
289 
290     private ClassDesc pop() {
291         return stack.removeLast();
292     }
293 
294     private ClassDesc get(int pos) {
295         return stack.get(stack.size() + pos);
296     }
297 
298     private ClassDesc top() {
299         return stack.getLast();
300     }
301 
302     private ClassDesc[] top2() {
303         return new ClassDesc[] {stack.get(stack.size() - 2), stack.getLast()};
304     }
305 
306     private LocalsTypeMapper pop(int i) {
307         while (i-- > 0) pop();
308         return this;
309     }
310 
311     private void store(int slot, ClassDesc type) {
312         store(slot, type, locals, true);
313     }
314 
315     private void store(int slot, ClassDesc type, List<Slot> where, boolean newValue) {
316         store(slot, type == null ? null : newSlot(type, newValue), where);
317     }
318 
319     private void store(int slot, Slot s, List<Slot> where) {
320         if (s != null) {
321             for (int i = where.size(); i <= slot; i++) where.add(null);
322             s.previous = where.set(slot, s);
323         }
324     }
325 
326     private ClassDesc load(int slot) {
327         return locals.get(slot).type;
328     }
329 
330     private void accept(int elIndex, CodeElement el) {
331         switch (el) {
332             case ArrayLoadInstruction _ ->
333                 pop(1).push(pop().componentType());
334             case ArrayStoreInstruction _ ->
335                 pop(3);
336             case BranchInstruction i -> {
337                 switch (i.opcode()) {
338                     case IFEQ, IFGE, IFGT, IFLE, IFLT, IFNE, IFNONNULL, IFNULL -> {
339                         pop();
340                         mergeToTargetFrame(i.target());
341                     }
342                     case IF_ACMPEQ, IF_ACMPNE, IF_ICMPEQ, IF_ICMPGE, IF_ICMPGT, IF_ICMPLE, IF_ICMPLT, IF_ICMPNE -> {
343                         pop(2);
344                         mergeToTargetFrame(i.target());
345                     }
346                     case GOTO, GOTO_W -> {
347                         mergeToTargetFrame(i.target());
348                         endOfFlow();
349                     }
350                 }
351             }
352             case ConstantInstruction i ->
353                 push(switch (i.constantValue()) {
354                     case null -> NULL_TYPE;
355                     case ClassDesc _ -> CD_Class;
356                     case Double _ -> CD_double;
357                     case Float _ -> CD_float;
358                     case Integer _ -> CD_int;
359                     case Long _ -> CD_long;
360                     case String _ -> CD_String;
361                     case DynamicConstantDesc<?> cd when cd.equals(NULL) -> NULL_TYPE;
362                     case DynamicConstantDesc<?> cd -> cd.constantType();
363                     case DirectMethodHandleDesc _ -> CD_MethodHandle;
364                     case MethodTypeDesc _ -> CD_MethodType;
365                 });
366             case ConvertInstruction i ->
367                 pop(1).push(ClassDesc.ofDescriptor(i.toType().descriptor()));
368             case FieldInstruction i -> {
369                 switch (i.opcode()) {
370                     case GETSTATIC ->
371                         push(i.typeSymbol());
372                     case GETFIELD ->
373                         pop(1).push(i.typeSymbol());
374                     case PUTSTATIC ->
375                         pop(1);
376                     case PUTFIELD ->
377                         pop(2);
378                 }
379             }
380             case IncrementInstruction i -> {
381                 Slot v = locals.get(i.slot());
382                 store(i.slot(), load(i.slot()));
383                 v.link(locals.get(i.slot()));
384                 insMap.put(elIndex, v);
385             }
386             case InvokeDynamicInstruction i ->
387                 pop(i.typeSymbol().parameterCount()).push(i.typeSymbol().returnType());
388             case InvokeInstruction i ->
389                 pop(i.typeSymbol().parameterCount() + (i.opcode() == Opcode.INVOKESTATIC ? 0 : 1))
390                         .push(i.typeSymbol().returnType());
391             case LoadInstruction i -> {
392                 push(load(i.slot()));
393                 insMap.put(elIndex, locals.get(i.slot()));
394             }
395             case StoreInstruction i -> {
396                 store(i.slot(), pop());
397                 insMap.put(elIndex, locals.get(i.slot()));
398             }
399             case MonitorInstruction _ ->
400                 pop(1);
401             case NewMultiArrayInstruction i ->
402                 pop(i.dimensions()).push(i.arrayType().asSymbol());
403             case NewObjectInstruction i ->
404                 push(i.className().asSymbol());
405             case NewPrimitiveArrayInstruction i ->
406                 pop(1).push(ClassDesc.ofDescriptor(i.typeKind().descriptor()).arrayType());
407             case NewReferenceArrayInstruction i ->
408                 pop(1).push(i.componentType().asSymbol().arrayType());
409             case OperatorInstruction i ->
410                 pop(switch (i.opcode()) {
411                     case ARRAYLENGTH, INEG, LNEG, FNEG, DNEG -> 1;
412                     default -> 2;
413                 }).push(ClassDesc.ofDescriptor(i.typeKind().descriptor()));
414             case StackInstruction i -> {
415                 switch (i.opcode()) {
416                     case POP -> pop(1);
417                     case POP2 -> pop(doubleAt(-1) ? 1 : 2);
418                     case DUP -> push(top());
419                     case DUP2 -> {
420                         if (doubleAt(-1)) {
421                             push(top());
422                         } else {
423                             pushAt(-2, top2());
424                         }
425                     }
426                     case DUP_X1 -> pushAt(-2, top());
427                     case DUP_X2 -> pushAt(doubleAt(-2) ? -2 : -3, top());
428                     case DUP2_X1 -> {
429                         if (doubleAt(-1)) {
430                             pushAt(-2, top());
431                         } else {
432                             pushAt(-3, top2());
433                         }
434                     }
435                     case DUP2_X2 -> {
436                         if (doubleAt(-1)) {
437                             pushAt(doubleAt(-2) ? -2 : -3, top());
438                         } else {
439                             pushAt(doubleAt(-3) ? -3 : -4, top2());
440                         }
441                     }
442                     case SWAP -> pushAt(-1, pop());
443                 }
444             }
445             case TypeCheckInstruction i ->
446                 pop(1).push(i.opcode() == Opcode.CHECKCAST ? i.type().asSymbol() : ConstantDescs.CD_int);
447             case LabelTarget lt -> {
448                 var frame = stackMap.get(lt.label());
449                 if (frame != null) {
450                     if (!stack.isEmpty() || !locals.isEmpty()) {
451                         mergeToTargetFrame(lt.label());
452                         endOfFlow();
453                     }
454                     stack.addAll(frame.stack());
455                     locals.addAll(frame.locals());
456                 }
457                 for (ExceptionCatch ec : exceptionHandlers) {
458                     if (lt.label() == ec.tryStart()) {
459                         mergeLocalsToTargetFrame(stackMap.get(ec.handler()));
460                     }
461                 }
462             }
463             case ReturnInstruction _ , ThrowInstruction _ -> {
464                 endOfFlow();
465             }
466             case TableSwitchInstruction tsi -> {
467                 pop();
468                 mergeToTargetFrame(tsi.defaultTarget());
469                 for (var c : tsi.cases()) {
470                     mergeToTargetFrame(c.target());
471                 }
472                 endOfFlow();
473             }
474             case LookupSwitchInstruction lsi -> {
475                 pop();
476                 mergeToTargetFrame(lsi.defaultTarget());
477                 for (var c : lsi.cases()) {
478                     mergeToTargetFrame(c.target());
479                 }
480                 endOfFlow();
481             }
482             default -> {}
483         }
484     }
485 
486     private void endOfFlow() {
487         stack.clear();
488         locals.clear();
489     }
490 
491     private void mergeToTargetFrame(Label target) {
492         Frame targetFrame = stackMap.get(target);
493         // Merge stack
494         assert stack.size() == targetFrame.stack.size();
495         for (int i = 0; i < targetFrame.stack.size(); i++) {
496             ClassDesc se = stack.get(i);
497             ClassDesc fe = targetFrame.stack.get(i);
498             if (!se.equals(fe)) {
499                 if (se.isPrimitive() && CD_int.equals(fe)) {
500                     targetFrame.stack.set(i, se); // Override int target frame type with more specific int sub-type
501                     this.frameDirty = true;
502                 } else {
503                     stack.set(i, fe); // Override stack type with target frame type
504                 }
505             }
506         }
507         mergeLocalsToTargetFrame(targetFrame);
508     }
509 
510     private void mergeLocalsToTargetFrame(Frame targetFrame) {
511         // Merge locals
512         int lSize = Math.min(locals.size(), targetFrame.locals.size());
513         for (int i = 0; i < lSize; i++) {
514             Slot le = locals.get(i);
515             Slot fe = targetFrame.locals.get(i);
516             if (le != null && fe != null) {
517                 le.link(fe); // Link target frame var with its source
518                 if (!le.type.equals(fe.type)) {
519                     if (le.type.isPrimitive() && CD_int.equals(fe.type) ) {
520                         fe.type = le.type; // Override int target frame type with more specific int sub-type
521                         this.frameDirty = true;
522                     } else {
523                         le.type = fe.type; // Override var type with target frame type
524                     }
525                 }
526             }
527         }
528     }
529 }