1 /*
   2  * Copyright (c) 2017, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.  Oracle designates this
   8  * particular file as subject to the "Classpath" exception as provided
   9  * by Oracle in the LICENSE file that accompanied this code.
  10  *
  11  * This code is distributed in the hope that it will be useful, but WITHOUT
  12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  14  * version 2 for more details (a copy is included in the LICENSE file that
  15  * accompanied this code).
  16  *
  17  * You should have received a copy of the GNU General Public License version
  18  * 2 along with this work; if not, write to the Free Software Foundation,
  19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  20  *
  21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  22  * or visit www.oracle.com if you need additional information or have any
  23  * questions.
  24  */
  25 
  26 package com.sun.tools.javac.comp;
  27 
  28 import com.sun.tools.javac.code.Flags;
  29 import com.sun.tools.javac.code.Symbol;
  30 import com.sun.tools.javac.code.Symbol.VarSymbol;
  31 import com.sun.tools.javac.code.Type;
  32 import com.sun.tools.javac.code.Types;
  33 import com.sun.tools.javac.resources.CompilerProperties.Errors;
  34 import com.sun.tools.javac.tree.JCTree;
  35 import com.sun.tools.javac.tree.JCTree.JCBinary;
  36 import com.sun.tools.javac.tree.JCTree.JCConditional;
  37 import com.sun.tools.javac.tree.JCTree.JCUnary;
  38 import com.sun.tools.javac.tree.JCTree.JCBindingPattern;
  39 import com.sun.tools.javac.tree.TreeScanner;
  40 import com.sun.tools.javac.util.Context;
  41 import com.sun.tools.javac.util.List;
  42 import com.sun.tools.javac.util.Log;
  43 import com.sun.tools.javac.util.Name;
  44 
  45 
  46 public class MatchBindingsComputer extends TreeScanner {
  47     protected static final Context.Key<MatchBindingsComputer> matchBindingsComputerKey = new Context.Key<>();
  48 
  49     private final Log log;
  50     private final Types types;
  51     boolean whenTrue;
  52     List<BindingSymbol> bindings;
  53 
  54     public static MatchBindingsComputer instance(Context context) {
  55         MatchBindingsComputer instance = context.get(matchBindingsComputerKey);
  56         if (instance == null)
  57             instance = new MatchBindingsComputer(context);
  58         return instance;
  59     }
  60 
  61     protected MatchBindingsComputer(Context context) {
  62         this.log = Log.instance(context);
  63         this.types = Types.instance(context);
  64     }
  65 
  66     public List<BindingSymbol> getMatchBindings(JCTree expression, boolean whenTrue) {
  67         this.whenTrue = whenTrue;
  68         this.bindings = List.nil();
  69         scan(expression);
  70         return bindings;
  71     }
  72 
  73     @Override
  74     public void visitBindingPattern(JCBindingPattern tree) {
  75         bindings = whenTrue ? List.of(tree.symbol) : List.nil();
  76     }
  77 
  78     @Override
  79     public void visitBinary(JCBinary tree) {
  80         switch (tree.getTag()) {
  81             case AND:
  82                 // e.T = union(x.T, y.T)
  83                 // e.F = intersection(x.F, y.F)
  84                 scan(tree.lhs);
  85                 List<BindingSymbol> lhsBindings = bindings;
  86                 scan(tree.rhs);
  87                 List<BindingSymbol> rhsBindings = bindings;
  88                 bindings = whenTrue ? union(tree, lhsBindings, rhsBindings) : intersection(tree, lhsBindings, rhsBindings);
  89                 break;
  90             case OR:
  91                 // e.T = intersection(x.T, y.T)
  92                 // e.F = union(x.F, y.F)
  93                 scan(tree.lhs);
  94                 lhsBindings = bindings;
  95                 scan(tree.rhs);
  96                 rhsBindings = bindings;
  97                 bindings = whenTrue ? intersection(tree, lhsBindings, rhsBindings) : union(tree, lhsBindings, rhsBindings);
  98                 break;
  99             default:
 100                 super.visitBinary(tree);
 101                 break;
 102         }
 103     }
 104 
 105     @Override
 106     public void visitUnary(JCUnary tree) {
 107         switch (tree.getTag()) {
 108             case NOT:
 109                 // e.T = x.F  // flip 'em
 110                 // e.F = x.T
 111                 whenTrue = !whenTrue;
 112                 scan(tree.arg);
 113                 whenTrue = !whenTrue;
 114                 break;
 115             default:
 116                 super.visitUnary(tree);
 117                 break;
 118         }
 119     }
 120 
 121     @Override
 122     public void visitConditional(JCConditional tree) {
 123         /* if e = "x ? y : z", then:
 124                e.T = union(intersect(y.T, z.T), intersect(x.T, z.T), intersect(x.F, y.T))
 125                e.F = union(intersect(y.F, z.F), intersect(x.T, z.F), intersect(x.F, y.F))
 126         */
 127         if (whenTrue) {
 128             List<BindingSymbol> xT, yT, zT, xF;
 129             scan(tree.cond);
 130             xT = bindings;
 131             scan(tree.truepart);
 132             yT = bindings;
 133             scan(tree.falsepart);
 134             zT = bindings;
 135             whenTrue = false;
 136             scan(tree.cond);
 137             xF = bindings;
 138             whenTrue = true;
 139             bindings = union(tree, intersection(tree, yT, zT), intersection(tree, xT, zT), intersection(tree, xF, yT));
 140         } else {
 141             List<BindingSymbol> xF, yF, zF, xT;
 142             scan(tree.cond);
 143             xF = bindings;
 144             scan(tree.truepart);
 145             yF = bindings;
 146             scan(tree.falsepart);
 147             zF = bindings;
 148             whenTrue = true;
 149             scan(tree.cond);
 150             xT = bindings;
 151             whenTrue = false;
 152             bindings = union(tree, intersection(tree, yF, zF), intersection(tree, xT, zF), intersection(tree, xF, yF));
 153         }
 154     }
 155 
 156     private List<BindingSymbol> intersection(JCTree tree, List<BindingSymbol> lhsBindings, List<BindingSymbol> rhsBindings) {
 157         // It is an error if, for intersection(a,b), if a and b contain the same variable name but with different types.
 158         List<BindingSymbol> list = List.nil();
 159         for (BindingSymbol v1 : lhsBindings) {
 160             for (BindingSymbol v2 : rhsBindings) {
 161                 if (v1.name == v2.name) {
 162                     if (types.isSameType(v1.type, v2.type)) {
 163                         list = list.append(new IntersectionBindingSymbol(List.of(v1, v2)));
 164                     } else {
 165                         log.error(tree.pos(), Errors.MatchBindingExistsWithDifferentType);
 166                     }
 167                 }
 168             }
 169         }
 170         return list;
 171     }
 172 
 173     @SafeVarargs
 174     private final List<BindingSymbol> union(JCTree tree, List<BindingSymbol> lhsBindings, List<BindingSymbol> ... rhsBindings_s) {
 175         // It is an error if for union(a,b), a and b contain the same name (disjoint union).
 176         List<BindingSymbol> list = lhsBindings;
 177         for (List<BindingSymbol> rhsBindings : rhsBindings_s) {
 178             for (BindingSymbol v : rhsBindings) {
 179                 for (BindingSymbol ov : list) {
 180                     if (ov.name == v.name) {
 181                         log.error(tree.pos(), Errors.MatchBindingExists);
 182                     }
 183                 }
 184                 list = list.append(v);
 185             }
 186         }
 187         return list;
 188     }
 189 
 190     @Override
 191     public void scan(JCTree tree) {
 192         bindings = List.nil();
 193         super.scan(tree);
 194     }
 195 
 196     public static class BindingSymbol extends VarSymbol {
 197 
 198         public BindingSymbol(Name name, Type type, Symbol owner) {
 199             super(Flags.FINAL | Flags.HASINIT | Flags.MATCH_BINDING, name, type, owner);
 200         }
 201 
 202         public boolean isAliasFor(BindingSymbol b) {
 203             return aliases().containsAll(b.aliases());
 204         }
 205 
 206         List<BindingSymbol> aliases() {
 207             return List.of(this);
 208         }
 209 
 210         public void preserveBinding() {
 211             flags_field |= Flags.MATCH_BINDING_TO_OUTER;
 212         }
 213 
 214         public boolean isPreserved() {
 215             return (flags_field & Flags.MATCH_BINDING_TO_OUTER) != 0;
 216         }
 217     }
 218 
 219     public static class IntersectionBindingSymbol extends BindingSymbol {
 220 
 221         List<BindingSymbol> aliases = List.nil();
 222 
 223         public IntersectionBindingSymbol(List<BindingSymbol> aliases) {
 224             super(aliases.head.name, aliases.head.type, aliases.head.owner);
 225             this.aliases = aliases.stream()
 226                     .flatMap(b -> b.aliases().stream())
 227                     .collect(List.collector());
 228         }
 229 
 230         @Override
 231         List<BindingSymbol> aliases() {
 232             return aliases;
 233         }
 234 
 235         @Override
 236         public void preserveBinding() {
 237             aliases.stream().forEach(BindingSymbol::preserveBinding);
 238         }
 239 
 240         public boolean isPreserved() {
 241             return aliases.stream().allMatch(BindingSymbol::isPreserved);
 242         }
 243     }
 244 }