1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 
 26 package jdk.incubator.code.analysis;
 27 
 28 import jdk.incubator.code.Block;
 29 import jdk.incubator.code.Op;
 30 import jdk.incubator.code.CodeElement;
 31 import jdk.incubator.code.Value;
 32 import jdk.incubator.code.op.CoreOp;
 33 import java.util.*;
 34 import java.util.function.BiFunction;
 35 import java.util.function.Predicate;
 36 
 37 /**
 38  * A simple and experimental pattern match mechanism on values and operations.
 39  * <p>
 40  * When the language has support for pattern matching with matcher methods we should be able to express
 41  * matching on values and operations more powerfully and concisely.
 42  */
 43 public final class Patterns {
 44 
 45     private Patterns() {
 46     }
 47 
 48 
 49     /**
 50      * Traverses this operation and its descendant operations and returns the set of operations that are unused
 51      * (have no uses) and are pure (are instances of {@code Op.Pure} and thus have no side effects).
 52      *
 53      * @param op the operation to traverse
 54      * @return the set of used and pure operations.
 55      */
 56     public static Set<Op> matchUnusedPureOps(Op op) {
 57         return matchUnusedPureOps(op, o -> o instanceof Op.Pure);
 58     }
 59 
 60     /**
 61      * Traverses this operation and its descendant operations and returns the set of operations that are unused
 62      * (have no uses) and are pure (according to the given predicate).
 63      *
 64      * @param op       the operation to traverse
 65      * @param testPure the predicate to test if an operation is pure
 66      * @return the set of used and pure operations.
 67      */
 68     public static Set<Op> matchUnusedPureOps(Op op, Predicate<Op> testPure) {
 69         return match(
 70                 new HashSet<>(),
 71                 op, opP(o -> isDeadOp(o, testPure)),
 72                 (ms, deadOps) -> {
 73                     deadOps.add(ms.op());
 74 
 75                     // Dependent dead ops
 76                     matchDependentDeadOps(ms.op(), deadOps, testPure);
 77                     // @@@ No means to control traversal and only go deeper when
 78                     // there is only one user
 79 //                    ms.op().traverseOperands(null, (_a, arg) -> {
 80 //                        if (arg.result().users().size() == 1) {
 81 //                            deadOps.add(arg);
 82 //                        }
 83 //
 84 //                        return null;
 85 //                    });
 86 
 87                     return deadOps;
 88                 });
 89     }
 90 
 91     static boolean isDeadOp(Op op, Predicate<Op> testPure) {
 92         if (op instanceof Op.Terminating) {
 93             return false;
 94         }
 95 
 96         return op.result() != null && op.result().uses().isEmpty() && testPure.test(op);
 97     }
 98 
 99     // @@@ this could be made generic with a method traversing backwards
100     static void matchDependentDeadOps(Op op, Set<Op> deadOps, Predicate<Op> testPure) {
101         for (Value arg : op.operands()) {
102             if (arg instanceof Op.Result or) {
103                 if (arg.uses().size() == 1 && testPure.test(or.op())) {
104                     deadOps.add(or.op());
105 
106                     // Traverse only when a single user
107                     matchDependentDeadOps(or.op(), deadOps, testPure);
108                 }
109             }
110         }
111     }
112 
113 
114     // Matching of patterns
115 
116     /**
117      * The state of a successful match of an operation with matched operands (if any)
118      */
119     public static final class MatchState {
120         final Op op;
121         final List<Value> matchedOperands;
122 
123         MatchState(Op op, List<Value> matchedOperands) {
124             this.op = op;
125             this.matchedOperands = matchedOperands;
126         }
127 
128         /**
129          * {@return the matched operation}
130          */
131         public Op op() {
132             return op;
133         }
134 
135         /**
136          * {@return the matched operands}
137          */
138         public List<Value> matchedOperands() {
139             return matchedOperands;
140         }
141     }
142 
143     record PatternAndFunction<R>(OpPattern p, BiFunction<MatchState, R, R> f) {
144     }
145 
146     // Visiting op pattern matcher
147     static class OpPatternMatcher<R> implements BiFunction<R, Op, R> {
148         final List<PatternAndFunction<R>> patterns;
149         final PatternState state;
150         final Map<MatchState, BiFunction<MatchState, R, R>> matches;
151 
152         OpPatternMatcher(OpPattern p, BiFunction<MatchState, R, R> f) {
153             this(List.of(new PatternAndFunction<>(p, f)));
154         }
155 
156         OpPatternMatcher(List<PatternAndFunction<R>> patterns) {
157             this.patterns = patterns;
158             this.state = new PatternState();
159             this.matches = new HashMap<>();
160         }
161 
162         @Override
163         public R apply(R r, Op op) {
164             for (PatternAndFunction<R> pf : patterns) {
165                 if (pf.p.match(op, state)) {
166                     MatchState ms = new MatchState(op, state.resetOnMatch());
167 
168                     r = pf.f.apply(ms, r);
169                 } else {
170                     state.resetOnNoMatch();
171                 }
172             }
173 
174             return r;
175         }
176     }
177 
178     /**
179      * A match builder for declaring matching for one or more groups of operation patterns against a given traversable
180      * and descendant operations (in order).
181      * @param <R> the match result type
182      */
183     public static final class MultiMatchBuilder<R> {
184         final CodeElement<?, ?> o;
185         final R r;
186         List<PatternAndFunction<R>> patterns;
187 
188         MultiMatchBuilder(CodeElement<?, ?> o, R r) {
189             this.o = o;
190             this.r = r;
191             this.patterns = new ArrayList<>();
192         }
193 
194         /**
195          * Declares a first of possibly other operation patterns in a group.
196          *
197          * @param p the operation pattern
198          * @return a builder to declare further patterns in the group.
199          */
200         public MultiMatchCaseBuilder pattern(OpPattern p) {
201             return new MultiMatchCaseBuilder(p);
202         }
203 
204         public R matchThenApply() {
205             OpPatternMatcher<R> opm = new OpPatternMatcher<>(patterns);
206             return o.traverse(r, CodeElement.opVisitor(opm));
207         }
208 
209         /**
210          * A builder to declare further operation patterns in a group or to associate a
211          * target function to be applied if any of the patterns in the group match.
212          */
213         public final class MultiMatchCaseBuilder {
214             List<OpPattern> patterns;
215 
216             MultiMatchCaseBuilder(OpPattern p) {
217                 this.patterns = new ArrayList<>();
218                 patterns.add(p);
219             }
220 
221             /**
222              * Declares an operation pattern in the group.
223              *
224              * @param p the operation pattern
225              * @return this builder.
226              */
227             public MultiMatchCaseBuilder pattern(OpPattern p) {
228                 patterns.add(p);
229                 return this;
230             }
231 
232             /**
233              * Declares the target function to be applied if any of the operation patterns on the group match.
234              *
235              * @param f the target function.
236              * @return the match builder to build further groups.
237              */
238             public MultiMatchBuilder<R> target(BiFunction<MatchState, R, R> f) {
239                 patterns.stream().map(p -> new PatternAndFunction<>(p, f)).forEach(MultiMatchBuilder.this.patterns::add);
240                 return MultiMatchBuilder.this;
241             }
242         }
243     }
244 
245     /**
246      * Constructs a match builder from which to declare matching for one or more groups of operation patterns against a
247      * given traversable and descendant operations (in order).
248      *
249      * @param r   the initial match result
250      * @param t   the traversable
251      * @param <R> the match result type
252      * @return the match builder
253      */
254     public static <R> MultiMatchBuilder<R> multiMatch(R r, CodeElement<?, ?> t) {
255         return new MultiMatchBuilder<>(t, r);
256     }
257 
258     /**
259      * Matches an operation pattern against the given traversable and descendant operations (in order).
260      *
261      * @param r         the initial match result
262      * @param t         the traversable
263      * @param opPattern the operation pattern
264      * @param matcher   the function to be applied with a match state and the current match result when an
265      *                  encountered operation matches the operation pattern
266      * @param <R>       the match result type
267      * @return the match result
268      */
269     public static <R> R match(R r, CodeElement<?, ?> t, OpPattern opPattern,
270                               BiFunction<MatchState, R, R> matcher) {
271         OpPatternMatcher<R> opm = new OpPatternMatcher<>(opPattern, matcher);
272         return t.traverse(r, CodeElement.opVisitor(opm));
273     }
274 
275 
276     // Pattern classes
277 
278     static final class PatternState {
279         List<Value> matchedOperands;
280 
281         void addOperand(Value v) {
282             if (matchedOperands == null) {
283                 matchedOperands = new ArrayList<>();
284             }
285             matchedOperands.add(v);
286         }
287 
288         List<Value> resetOnMatch() {
289             if (matchedOperands != null) {
290                 List<Value> r = matchedOperands;
291                 matchedOperands = null;
292                 return r;
293             } else {
294                 return List.of();
295             }
296         }
297 
298         void resetOnNoMatch() {
299             if (matchedOperands != null) {
300                 matchedOperands.clear();
301             }
302         }
303     }
304 
305     /**
306      * A pattern matching against a value or operation.
307      */
308     public sealed static abstract class Pattern {
309         Pattern() {
310         }
311 
312         abstract boolean match(Value v, PatternState state);
313     }
314 
315     /**
316      * A pattern matching against an operation.
317      */
318     public static final class OpPattern extends Pattern {
319         final Predicate<Op> opTest;
320         final List<Pattern> operandPatterns;
321 
322         OpPattern(Predicate<Op> opTest, List<Pattern> operandPatterns) {
323             this.opTest = opTest;
324             this.operandPatterns = List.copyOf(operandPatterns);
325         }
326 
327         @Override
328         boolean match(Value v, PatternState state) {
329             if (v instanceof Op.Result or) {
330                 return match(or.op(), state);
331             } else {
332                 return false;
333             }
334         }
335 
336         boolean match(Op op, PatternState state) {
337             // Test does not match
338             if (!opTest.test(op)) {
339                 return false;
340             }
341 
342             if (!operandPatterns.isEmpty()) {
343                 // Arity does not match
344                 if (op.operands().size() != operandPatterns.size()) {
345                     return false;
346                 }
347 
348                 // Match all arguments
349                 for (int i = 0; i < operandPatterns.size(); i++) {
350                     Pattern p = operandPatterns.get(i);
351                     Value v = op.operands().get(i);
352 
353                     if (!p.match(v, state)) {
354                         return false;
355                     }
356                 }
357             }
358 
359             return true;
360         }
361     }
362 
363     /**
364      * A pattern that unconditionally matches a value which is captured. If the value is an operation result of an
365      * operation, then an operation pattern (if any) is further matched against the operation.
366      */
367     // @@@ type?
368     static final class ValuePattern extends Pattern {
369         final OpPattern opMatcher;
370 
371         ValuePattern() {
372             this(null);
373         }
374 
375         public ValuePattern(OpPattern opMatcher) {
376             this.opMatcher = opMatcher;
377         }
378 
379         @Override
380         boolean match(Value v, PatternState state) {
381             // Capture the operand
382             state.addOperand(v);
383 
384             // Match on operation on nested pattern, if any
385             return opMatcher == null || opMatcher.match(v, state);
386         }
387     }
388 
389     /**
390      * A pattern that conditionally matches an operation result which is captured,  then an operation pattern (if any)
391      * is further matched against the result's operation.
392      */
393     static final class OpResultPattern extends Pattern {
394         final OpPattern opMatcher;
395 
396         OpResultPattern() {
397             this(null);
398         }
399 
400         public OpResultPattern(OpPattern opMatcher) {
401             this.opMatcher = opMatcher;
402         }
403 
404         @Override
405         boolean match(Value v, PatternState state) {
406             if (!(v instanceof Op.Result)) {
407                 return false;
408             }
409 
410             // Capture the operand
411             state.addOperand(v);
412 
413             // Match on operation on nested pattern, if any
414             return opMatcher == null || opMatcher.match(v, state);
415         }
416     }
417 
418     /**
419      * A pattern that conditionally matches a block parameter which is captured.
420      */
421     static final class BlockParameterPattern extends Pattern {
422         BlockParameterPattern() {
423         }
424 
425         @Override
426         boolean match(Value v, PatternState state) {
427             if (!(v instanceof Block.Parameter)) {
428                 return false;
429             }
430 
431             // Capture the operand
432             state.addOperand(v);
433 
434             return true;
435         }
436     }
437 
438     /**
439      * A pattern matching any value or operation.
440      */
441     static final class AnyPattern extends Pattern {
442         AnyPattern() {
443         }
444 
445         @Override
446         boolean match(Value v, PatternState state) {
447             return true;
448         }
449     }
450 
451 
452     // Pattern factories
453 
454     /**
455      * Creates an operation pattern that tests against an operation by applying it to the predicate, and if
456      * {@code true}, matches operand patterns against the operation's operands (in order) .
457      * This operation pattern matches an operation if the test returns {@code true} and all operand patterns match
458      * against the operation's operands.
459      *
460      * @param opTest the predicate
461      * @param patterns the operand patterns
462      * @return the operation pattern
463      */
464     public static OpPattern opP(Predicate<Op> opTest, Pattern... patterns) {
465         return opP(opTest, List.of(patterns));
466     }
467 
468     /**
469      * Creates an operation pattern that tests against an operation by applying it to the predicate, and if
470      * {@code true}, matches operand patterns against the operation's operands (in order) .
471      * This operation pattern matches an operation if the test returns {@code true} and all operand patterns match
472      * against the operation's operands.
473      *
474      * @param opTest the predicate
475      * @param patterns the operand patterns
476      * @return the operation pattern
477      */
478     public static OpPattern opP(Predicate<Op> opTest, List<Pattern> patterns) {
479         return new OpPattern(opTest, patterns);
480     }
481 
482     /**
483      * Creates an operation pattern that tests if the operation is an instance of the class, and if
484      * {@code true}, matches operand patterns against the operation's operands (in order) .
485      * This operation pattern matches an operation if the test returns {@code true} and all operand patterns match
486      * against the operation's operands.
487      *
488      * @param opClass the operation class
489      * @param patterns the operand patterns
490      * @return the operation pattern
491      */
492     public static OpPattern opP(Class<?> opClass, Pattern... patterns) {
493         return opP(opClass::isInstance, patterns);
494     }
495 
496     /**
497      * Creates an operation pattern that tests if the operation is a {@link CoreOp.ConstantOp constant} operation
498      * and whose constant value is equal to the given value.
499      * This operation pattern matches an operation if the test returns {@code true}.
500      *
501      * @param value the value
502      * @return the operation pattern.
503      */
504     public static OpPattern constantP(Object value) {
505         return opP(op -> {
506             if (op instanceof CoreOp.ConstantOp cop) {
507                 return Objects.equals(value, cop.value());
508             }
509 
510             return false;
511         });
512     }
513 
514     /**
515      * Creates a value pattern that unconditionally matches any value and captures the value in match state.
516      *
517      * @return the value pattern.
518      */
519     public static Pattern valueP() {
520         return new ValuePattern();
521     }
522 
523     /**
524      * Creates a value pattern that unconditionally matches any value and captures the value in match state, and
525      * if the value is an operation result of an operation, then the operation pattern is matched against that
526      * operation.
527      * This value pattern matches value if value is not an operation result, or otherwise matches if the operation
528      * pattern matches.
529      *
530      * @param opMatcher the operation pattern
531      * @return the value pattern.
532      */
533     public static Pattern valueP(OpPattern opMatcher) {
534         return new ValuePattern(opMatcher);
535     }
536 
537     /**
538      * Creates an operation result pattern that conditionally matches an operation result and captures it in match state.
539      *
540      * @return the operation result.
541      */
542     public static Pattern opResultP() {
543         return new OpResultPattern();
544     }
545 
546     /**
547      * Creates an operation result pattern that conditionally matches an operation result and captures it in match state,
548      * then the operation pattern is matched against the result's operation.
549      *
550      * @param opMatcher the operation pattern
551      * @return the operation result.
552      */
553     public static Pattern opResultP(OpPattern opMatcher) {
554         return new OpResultPattern(opMatcher);
555     }
556 
557     /**
558      * Creates a block parameter result pattern that conditionally matches a block parameter and captures it in match state.
559      *
560      * @return the block parameter.
561      */
562     public static Pattern blockParameterP() {
563         return new BlockParameterPattern();
564     }
565 
566     /**
567      * Creates a pattern that unconditionally matches any value or operation.
568      *
569      * @return the value pattern.
570      */
571     public static Pattern _P() {
572         return new AnyPattern();
573     }
574 }