1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 
 26 package jdk.incubator.code.analysis;
 27 
 28 import jdk.incubator.code.Block;
 29 import jdk.incubator.code.Op;
 30 import jdk.incubator.code.CodeElement;
 31 import jdk.incubator.code.Value;
 32 import jdk.incubator.code.dialect.core.CoreOp;
 33 import java.util.*;
 34 import java.util.function.BiFunction;
 35 import java.util.function.Predicate;
 36 
 37 /**
 38  * A simple and experimental pattern match mechanism on values and operations.
 39  * <p>
 40  * When the language has support for pattern matching with matcher methods we should be able to express
 41  * matching on values and operations more powerfully and concisely.
 42  */
 43 public final class Patterns {
 44 
    // Not instantiable: this class only exposes static methods and nested types.
    private Patterns() {
    }
 47 
 48 
 49     /**
 50      * Traverses this operation and its descendant operations and returns the set of operations that are unused
 51      * (have no uses) and are pure (are instances of {@code Op.Pure} and thus have no side effects).
 52      *
 53      * @param op the operation to traverse
 54      * @return the set of used and pure operations.
 55      */
 56     public static Set<Op> matchUnusedPureOps(Op op) {
 57         return matchUnusedPureOps(op, o -> o instanceof Op.Pure);
 58     }
 59 
 60     /**
 61      * Traverses this operation and its descendant operations and returns the set of operations that are unused
 62      * (have no uses) and are pure (according to the given predicate).
 63      *
 64      * @param op       the operation to traverse
 65      * @param testPure the predicate to test if an operation is pure
 66      * @return the set of used and pure operations.
 67      */
 68     public static Set<Op> matchUnusedPureOps(Op op, Predicate<Op> testPure) {
 69         return match(
 70                 new HashSet<>(),
 71                 op, opP(o -> isDeadOp(o, testPure)),
 72                 (ms, deadOps) -> {
 73                     deadOps.add(ms.op());
 74 
 75                     // Dependent dead ops
 76                     matchDependentDeadOps(ms.op(), deadOps, testPure);
 77                     // @@@ No means to control traversal and only go deeper when
 78                     // there is only one user
 79 //                    ms.op().traverseOperands(null, (_a, arg) -> {
 80 //                        if (arg.result().users().size() == 1) {
 81 //                            deadOps.add(arg);
 82 //                        }
 83 //
 84 //                        return null;
 85 //                    });
 86 
 87                     return deadOps;
 88                 });
 89     }
 90 
 91     static boolean isDeadOp(Op op, Predicate<Op> testPure) {
 92         if (op instanceof Op.Terminating) {
 93             return false;
 94         }
 95 
 96         return op.result() != null && op.result().uses().isEmpty() && testPure.test(op);
 97     }
 98 
 99     // @@@ this could be made generic with a method traversing up the model tree,
100     // the challenge is controlling when to keep traversing or not, and it may make
101     // it more complex that just writing it like below for specific cases
102     // A better option may be to provide a lazy stream of the values that can be filtered
103     // similar to CodeElement::elements
104     static void matchDependentDeadOps(Op op, Set<Op> deadOps, Predicate<Op> testPure) {
105         for (Value arg : op.operands()) {
106             if (arg instanceof Op.Result or) {
107                 if (arg.uses().size() == 1 && testPure.test(or.op())) {
108                     deadOps.add(or.op());
109 
110                     // Traverse only when a single user
111                     matchDependentDeadOps(or.op(), deadOps, testPure);
112                 }
113             }
114         }
115     }
116 
117 
118     // Matching of patterns
119 
120     /**
121      * The state of a successful match of an operation with matched operands (if any)
122      */
123     public static final class MatchState {
124         final Op op;
125         final List<Value> matchedOperands;
126 
127         MatchState(Op op, List<Value> matchedOperands) {
128             this.op = op;
129             this.matchedOperands = matchedOperands;
130         }
131 
132         /**
133          * {@return the matched operation}
134          */
135         public Op op() {
136             return op;
137         }
138 
139         /**
140          * {@return the matched operands}
141          */
142         public List<Value> matchedOperands() {
143             return matchedOperands;
144         }
145     }
146 
147     record PatternAndFunction<R>(OpPattern p, BiFunction<MatchState, R, R> f) {
148     }
149 
150     // Visiting op pattern matcher
151     static class OpPatternMatcher<R> implements BiFunction<R, Op, R> {
152         final List<PatternAndFunction<R>> patterns;
153         final PatternState state;
154         final Map<MatchState, BiFunction<MatchState, R, R>> matches;
155 
156         OpPatternMatcher(OpPattern p, BiFunction<MatchState, R, R> f) {
157             this(List.of(new PatternAndFunction<>(p, f)));
158         }
159 
160         OpPatternMatcher(List<PatternAndFunction<R>> patterns) {
161             this.patterns = patterns;
162             this.state = new PatternState();
163             this.matches = new HashMap<>();
164         }
165 
166         @Override
167         public R apply(R r, Op op) {
168             for (PatternAndFunction<R> pf : patterns) {
169                 if (pf.p.match(op, state)) {
170                     MatchState ms = new MatchState(op, state.resetOnMatch());
171 
172                     r = pf.f.apply(ms, r);
173                 } else {
174                     state.resetOnNoMatch();
175                 }
176             }
177 
178             return r;
179         }
180     }
181 
182     /**
183      * A match builder for declaring matching for one or more groups of operation patterns against a given traversable
184      * and descendant operations (in order).
185      * @param <R> the match result type
186      */
187     public static final class MultiMatchBuilder<R> {
188         final CodeElement<?, ?> o;
189         final R r;
190         List<PatternAndFunction<R>> patterns;
191 
192         MultiMatchBuilder(CodeElement<?, ?> o, R r) {
193             this.o = o;
194             this.r = r;
195             this.patterns = new ArrayList<>();
196         }
197 
198         /**
199          * Declares a first of possibly other operation patterns in a group.
200          *
201          * @param p the operation pattern
202          * @return a builder to declare further patterns in the group.
203          */
204         public MultiMatchCaseBuilder pattern(OpPattern p) {
205             return new MultiMatchCaseBuilder(p);
206         }
207 
208         public R matchThenApply() {
209             OpPatternMatcher<R> opm = new OpPatternMatcher<>(patterns);
210             return o.traverse(r, CodeElement.opVisitor(opm));
211         }
212 
213         /**
214          * A builder to declare further operation patterns in a group or to associate a
215          * target function to be applied if any of the patterns in the group match.
216          */
217         public final class MultiMatchCaseBuilder {
218             List<OpPattern> patterns;
219 
220             MultiMatchCaseBuilder(OpPattern p) {
221                 this.patterns = new ArrayList<>();
222                 patterns.add(p);
223             }
224 
225             /**
226              * Declares an operation pattern in the group.
227              *
228              * @param p the operation pattern
229              * @return this builder.
230              */
231             public MultiMatchCaseBuilder pattern(OpPattern p) {
232                 patterns.add(p);
233                 return this;
234             }
235 
236             /**
237              * Declares the target function to be applied if any of the operation patterns on the group match.
238              *
239              * @param f the target function.
240              * @return the match builder to build further groups.
241              */
242             public MultiMatchBuilder<R> target(BiFunction<MatchState, R, R> f) {
243                 patterns.stream().map(p -> new PatternAndFunction<>(p, f)).forEach(MultiMatchBuilder.this.patterns::add);
244                 return MultiMatchBuilder.this;
245             }
246         }
247     }
248 
249     /**
250      * Constructs a match builder from which to declare matching for one or more groups of operation patterns against a
251      * given traversable and descendant operations (in order).
252      *
253      * @param r   the initial match result
254      * @param t   the traversable
255      * @param <R> the match result type
256      * @return the match builder
257      */
258     public static <R> MultiMatchBuilder<R> multiMatch(R r, CodeElement<?, ?> t) {
259         return new MultiMatchBuilder<>(t, r);
260     }
261 
262     /**
263      * Matches an operation pattern against the given traversable and descendant operations (in order).
264      *
265      * @param r         the initial match result
266      * @param t         the traversable
267      * @param opPattern the operation pattern
268      * @param matcher   the function to be applied with a match state and the current match result when an
269      *                  encountered operation matches the operation pattern
270      * @param <R>       the match result type
271      * @return the match result
272      */
273     public static <R> R match(R r, CodeElement<?, ?> t, OpPattern opPattern,
274                               BiFunction<MatchState, R, R> matcher) {
275         OpPatternMatcher<R> opm = new OpPatternMatcher<>(opPattern, matcher);
276         return t.traverse(r, CodeElement.opVisitor(opm));
277     }
278 
279 
280     // Pattern classes
281 
282     static final class PatternState {
283         List<Value> matchedOperands;
284 
285         void addOperand(Value v) {
286             if (matchedOperands == null) {
287                 matchedOperands = new ArrayList<>();
288             }
289             matchedOperands.add(v);
290         }
291 
292         List<Value> resetOnMatch() {
293             if (matchedOperands != null) {
294                 List<Value> r = matchedOperands;
295                 matchedOperands = null;
296                 return r;
297             } else {
298                 return List.of();
299             }
300         }
301 
302         void resetOnNoMatch() {
303             if (matchedOperands != null) {
304                 matchedOperands.clear();
305             }
306         }
307     }
308 
309     /**
310      * A pattern matching against a value or operation.
311      */
312     public sealed static abstract class Pattern {
313         Pattern() {
314         }
315 
316         abstract boolean match(Value v, PatternState state);
317     }
318 
319     /**
320      * A pattern matching against an operation.
321      */
322     public static final class OpPattern extends Pattern {
323         final Predicate<Op> opTest;
324         final List<Pattern> operandPatterns;
325 
326         OpPattern(Predicate<Op> opTest, List<Pattern> operandPatterns) {
327             this.opTest = opTest;
328             this.operandPatterns = List.copyOf(operandPatterns);
329         }
330 
331         @Override
332         boolean match(Value v, PatternState state) {
333             if (v instanceof Op.Result or) {
334                 return match(or.op(), state);
335             } else {
336                 return false;
337             }
338         }
339 
340         boolean match(Op op, PatternState state) {
341             // Test does not match
342             if (!opTest.test(op)) {
343                 return false;
344             }
345 
346             if (!operandPatterns.isEmpty()) {
347                 // Arity does not match
348                 if (op.operands().size() != operandPatterns.size()) {
349                     return false;
350                 }
351 
352                 // Match all arguments
353                 for (int i = 0; i < operandPatterns.size(); i++) {
354                     Pattern p = operandPatterns.get(i);
355                     Value v = op.operands().get(i);
356 
357                     if (!p.match(v, state)) {
358                         return false;
359                     }
360                 }
361             }
362 
363             return true;
364         }
365     }
366 
367     /**
368      * A pattern that unconditionally matches a value which is captured. If the value is an operation result of an
369      * operation, then an operation pattern (if any) is further matched against the operation.
370      */
371     // @@@ type?
372     static final class ValuePattern extends Pattern {
373         final OpPattern opMatcher;
374 
375         ValuePattern() {
376             this(null);
377         }
378 
379         public ValuePattern(OpPattern opMatcher) {
380             this.opMatcher = opMatcher;
381         }
382 
383         @Override
384         boolean match(Value v, PatternState state) {
385             // Capture the operand
386             state.addOperand(v);
387 
388             // Match on operation on nested pattern, if any
389             return opMatcher == null || opMatcher.match(v, state);
390         }
391     }
392 
393     /**
394      * A pattern that conditionally matches an operation result which is captured,  then an operation pattern (if any)
395      * is further matched against the result's operation.
396      */
397     static final class OpResultPattern extends Pattern {
398         final OpPattern opMatcher;
399 
400         OpResultPattern() {
401             this(null);
402         }
403 
404         public OpResultPattern(OpPattern opMatcher) {
405             this.opMatcher = opMatcher;
406         }
407 
408         @Override
409         boolean match(Value v, PatternState state) {
410             if (!(v instanceof Op.Result)) {
411                 return false;
412             }
413 
414             // Capture the operand
415             state.addOperand(v);
416 
417             // Match on operation on nested pattern, if any
418             return opMatcher == null || opMatcher.match(v, state);
419         }
420     }
421 
422     /**
423      * A pattern that conditionally matches a block parameter which is captured.
424      */
425     static final class BlockParameterPattern extends Pattern {
426         BlockParameterPattern() {
427         }
428 
429         @Override
430         boolean match(Value v, PatternState state) {
431             if (!(v instanceof Block.Parameter)) {
432                 return false;
433             }
434 
435             // Capture the operand
436             state.addOperand(v);
437 
438             return true;
439         }
440     }
441 
442     /**
443      * A pattern matching any value or operation.
444      */
445     static final class AnyPattern extends Pattern {
446         AnyPattern() {
447         }
448 
449         @Override
450         boolean match(Value v, PatternState state) {
451             return true;
452         }
453     }
454 
455 
456     // Pattern factories
457 
458     /**
459      * Creates an operation pattern that tests against an operation by applying it to the predicate, and if
460      * {@code true}, matches operand patterns against the operation's operands (in order) .
461      * This operation pattern matches an operation if the test returns {@code true} and all operand patterns match
462      * against the operation's operands.
463      *
464      * @param opTest the predicate
465      * @param patterns the operand patterns
466      * @return the operation pattern
467      */
468     public static OpPattern opP(Predicate<Op> opTest, Pattern... patterns) {
469         return opP(opTest, List.of(patterns));
470     }
471 
472     /**
473      * Creates an operation pattern that tests against an operation by applying it to the predicate, and if
474      * {@code true}, matches operand patterns against the operation's operands (in order) .
475      * This operation pattern matches an operation if the test returns {@code true} and all operand patterns match
476      * against the operation's operands.
477      *
478      * @param opTest the predicate
479      * @param patterns the operand patterns
480      * @return the operation pattern
481      */
482     public static OpPattern opP(Predicate<Op> opTest, List<Pattern> patterns) {
483         return new OpPattern(opTest, patterns);
484     }
485 
486     /**
487      * Creates an operation pattern that tests if the operation is an instance of the class, and if
488      * {@code true}, matches operand patterns against the operation's operands (in order) .
489      * This operation pattern matches an operation if the test returns {@code true} and all operand patterns match
490      * against the operation's operands.
491      *
492      * @param opClass the operation class
493      * @param patterns the operand patterns
494      * @return the operation pattern
495      */
496     public static OpPattern opP(Class<?> opClass, Pattern... patterns) {
497         return opP(opClass::isInstance, patterns);
498     }
499 
500     /**
501      * Creates an operation pattern that tests if the operation is a {@link CoreOp.ConstantOp constant} operation
502      * and whose constant value is equal to the given value.
503      * This operation pattern matches an operation if the test returns {@code true}.
504      *
505      * @param value the value
506      * @return the operation pattern.
507      */
508     public static OpPattern constantP(Object value) {
509         return opP(op -> {
510             if (op instanceof CoreOp.ConstantOp cop) {
511                 return Objects.equals(value, cop.value());
512             }
513 
514             return false;
515         });
516     }
517 
518     /**
519      * Creates a value pattern that unconditionally matches any value and captures the value in match state.
520      *
521      * @return the value pattern.
522      */
523     public static Pattern valueP() {
524         return new ValuePattern();
525     }
526 
527     /**
528      * Creates a value pattern that unconditionally matches any value and captures the value in match state, and
529      * if the value is an operation result of an operation, then the operation pattern is matched against that
530      * operation.
531      * This value pattern matches value if value is not an operation result, or otherwise matches if the operation
532      * pattern matches.
533      *
534      * @param opMatcher the operation pattern
535      * @return the value pattern.
536      */
537     public static Pattern valueP(OpPattern opMatcher) {
538         return new ValuePattern(opMatcher);
539     }
540 
541     /**
542      * Creates an operation result pattern that conditionally matches an operation result and captures it in match state.
543      *
544      * @return the operation result.
545      */
546     public static Pattern opResultP() {
547         return new OpResultPattern();
548     }
549 
550     /**
551      * Creates an operation result pattern that conditionally matches an operation result and captures it in match state,
552      * then the operation pattern is matched against the result's operation.
553      *
554      * @param opMatcher the operation pattern
555      * @return the operation result.
556      */
557     public static Pattern opResultP(OpPattern opMatcher) {
558         return new OpResultPattern(opMatcher);
559     }
560 
561     /**
562      * Creates a block parameter result pattern that conditionally matches a block parameter and captures it in match state.
563      *
564      * @return the block parameter.
565      */
566     public static Pattern blockParameterP() {
567         return new BlockParameterPattern();
568     }
569 
570     /**
571      * Creates a pattern that unconditionally matches any value or operation.
572      *
573      * @return the value pattern.
574      */
575     public static Pattern _P() {
576         return new AnyPattern();
577     }
578 }