1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 
 26 package jdk.incubator.code.analysis;
 27 
 28 import jdk.incubator.code.Block;
 29 import jdk.incubator.code.Op;
 30 import jdk.incubator.code.CodeElement;
 31 import jdk.incubator.code.Value;
 32 import jdk.incubator.code.dialect.core.CoreOp;
 33 import java.util.*;
 34 import java.util.function.*;
 35 import java.util.stream.Gatherer;
 36 import java.util.stream.Gatherers;
 37 
 38 /**
 39  * A simple and experimental pattern match mechanism on values and operations.
 40  * <p>
 41  * When the language has support for pattern matching with matcher methods we should be able to express
 42  * matching on values and operations more powerfully and concisely.
 43  */
 44 public final class Patterns {
 45 
 46     private Patterns() {
 47     }
 48 
 49 
 50     /**
 51      * Traverses this operation and its descendant operations and returns the set of operations that are unused
 52      * (have no uses) and are pure (are instances of {@code Op.Pure} and thus have no side effects).
 53      *
 54      * @param op the operation to traverse
 55      * @return the set of used and pure operations.
 56      */
 57     public static Set<Op> matchUnusedPureOps(Op op) {
 58         return matchUnusedPureOps(op, o -> o instanceof Op.Pure);
 59     }
 60 
 61     /**
 62      * Traverses this operation and its descendant operations and returns the set of operations that are unused
 63      * (have no uses) and are pure (according to the given predicate).
 64      *
 65      * @param op       the operation to traverse
 66      * @param testPure the predicate to test if an operation is pure
 67      * @return the set of used and pure operations.
 68      */
 69     public static Set<Op> matchUnusedPureOps(Op op, Predicate<Op> testPure) {
 70         return match(
 71                 new HashSet<>(),
 72                 op, opP(o -> isDeadOp(o, testPure)),
 73                 (ms, deadOps) -> {
 74                     deadOps.add(ms.op());
 75 
 76                     // Dependent dead ops
 77                     matchDependentDeadOps(ms.op(), deadOps, testPure);
 78                     // @@@ No means to control traversal and only go deeper when
 79                     // there is only one user
 80 //                    ms.op().traverseOperands(null, (_a, arg) -> {
 81 //                        if (arg.result().users().size() == 1) {
 82 //                            deadOps.add(arg);
 83 //                        }
 84 //
 85 //                        return null;
 86 //                    });
 87 
 88                     return deadOps;
 89                 });
 90     }
 91 
 92     static boolean isDeadOp(Op op, Predicate<Op> testPure) {
 93         if (op instanceof Op.Terminating) {
 94             return false;
 95         }
 96 
 97         return op.result() != null && op.result().uses().isEmpty() && testPure.test(op);
 98     }
 99 
100     // @@@ this could be made generic with a method traversing up the model tree,
101     // the challenge is controlling when to keep traversing or not, and it may make
102     // it more complex that just writing it like below for specific cases
103     // A better option may be to provide a lazy stream of the values that can be filtered
104     // similar to CodeElement::elements
105     static void matchDependentDeadOps(Op op, Set<Op> deadOps, Predicate<Op> testPure) {
106         for (Value arg : op.operands()) {
107             if (arg instanceof Op.Result or) {
108                 if (arg.uses().size() == 1 && testPure.test(or.op())) {
109                     deadOps.add(or.op());
110 
111                     // Traverse only when a single user
112                     matchDependentDeadOps(or.op(), deadOps, testPure);
113                 }
114             }
115         }
116     }
117 
118 
119     // Matching of patterns
120 
121     /**
122      * The state of a successful match of an operation with matched operands (if any)
123      */
124     public static final class MatchState {
125         final Op op;
126         final List<Value> matchedOperands;
127 
128         MatchState(Op op, List<Value> matchedOperands) {
129             this.op = op;
130             this.matchedOperands = matchedOperands;
131         }
132 
133         /**
134          * {@return the matched operation}
135          */
136         public Op op() {
137             return op;
138         }
139 
140         /**
141          * {@return the matched operands}
142          */
143         public List<Value> matchedOperands() {
144             return matchedOperands;
145         }
146     }
147 
148     record PatternAndFunction<R>(OpPattern p, BiFunction<MatchState, R, R> f) {
149     }
150 
151     // Visiting op pattern matcher
152     static class OpPatternMatcher<R> implements BiFunction<R, Op, R> {
153         final List<PatternAndFunction<R>> patterns;
154         final PatternState state;
155         final Map<MatchState, BiFunction<MatchState, R, R>> matches;
156 
157         OpPatternMatcher(OpPattern p, BiFunction<MatchState, R, R> f) {
158             this(List.of(new PatternAndFunction<>(p, f)));
159         }
160 
161         OpPatternMatcher(List<PatternAndFunction<R>> patterns) {
162             this.patterns = patterns;
163             this.state = new PatternState();
164             this.matches = new HashMap<>();
165         }
166 
167         @Override
168         public R apply(R r, Op op) {
169             for (PatternAndFunction<R> pf : patterns) {
170                 if (pf.p.match(op, state)) {
171                     MatchState ms = new MatchState(op, state.resetOnMatch());
172 
173                     r = pf.f.apply(ms, r);
174                 } else {
175                     state.resetOnNoMatch();
176                 }
177             }
178 
179             return r;
180         }
181     }
182 
183     /**
184      * A match builder for declaring matching for one or more groups of operation patterns against a given traversable
185      * and descendant operations (in order).
186      * @param <R> the match result type
187      */
188     public static final class MultiMatchBuilder<R> {
189         final CodeElement<?, ?> o;
190         final R r;
191         List<PatternAndFunction<R>> patterns;
192 
193         MultiMatchBuilder(CodeElement<?, ?> o, R r) {
194             this.o = o;
195             this.r = r;
196             this.patterns = new ArrayList<>();
197         }
198 
199         /**
200          * Declares a first of possibly other operation patterns in a group.
201          *
202          * @param p the operation pattern
203          * @return a builder to declare further patterns in the group.
204          */
205         public MultiMatchCaseBuilder pattern(OpPattern p) {
206             return new MultiMatchCaseBuilder(p);
207         }
208 
209         public R matchThenApply() {
210             OpPatternMatcher<R> opm = new OpPatternMatcher<>(patterns);
211             return o.elements().gather(Gatherers.fold(
212                     () -> r,
213                     (r, e) -> e instanceof Op op ? opm.apply(r, op) : r)
214             ).findFirst().orElseThrow();
215         }
216 
217         /**
218          * A builder to declare further operation patterns in a group or to associate a
219          * target function to be applied if any of the patterns in the group match.
220          */
221         public final class MultiMatchCaseBuilder {
222             List<OpPattern> patterns;
223 
224             MultiMatchCaseBuilder(OpPattern p) {
225                 this.patterns = new ArrayList<>();
226                 patterns.add(p);
227             }
228 
229             /**
230              * Declares an operation pattern in the group.
231              *
232              * @param p the operation pattern
233              * @return this builder.
234              */
235             public MultiMatchCaseBuilder pattern(OpPattern p) {
236                 patterns.add(p);
237                 return this;
238             }
239 
240             /**
241              * Declares the target function to be applied if any of the operation patterns on the group match.
242              *
243              * @param f the target function.
244              * @return the match builder to build further groups.
245              */
246             public MultiMatchBuilder<R> target(BiFunction<MatchState, R, R> f) {
247                 patterns.stream().map(p -> new PatternAndFunction<>(p, f)).forEach(MultiMatchBuilder.this.patterns::add);
248                 return MultiMatchBuilder.this;
249             }
250         }
251     }
252 
253     /**
254      * Constructs a match builder from which to declare matching for one or more groups of operation patterns against a
255      * given traversable and descendant operations (in order).
256      *
257      * @param r   the initial match result
258      * @param t   the traversable
259      * @param <R> the match result type
260      * @return the match builder
261      */
262     public static <R> MultiMatchBuilder<R> multiMatch(R r, CodeElement<?, ?> t) {
263         return new MultiMatchBuilder<>(t, r);
264     }
265 
266     /**
267      * Matches an operation pattern against the given traversable and descendant operations (in order).
268      *
269      * @param init      the initial match result
270      * @param t         the traversable
271      * @param opPattern the operation pattern
272      * @param matcher   the function to be applied with a match state and the current match result when an
273      *                  encountered operation matches the operation pattern
274      * @param <R>       the match result type
275      * @return the match result
276      */
277     public static <R> R match(R init, CodeElement<?, ?> t, OpPattern opPattern,
278                               BiFunction<MatchState, R, R> matcher) {
279         OpPatternMatcher<R> opm = new OpPatternMatcher<>(opPattern, matcher);
280         return t.elements().gather(Gatherers.fold(
281                 () -> init,
282                 (r, e) -> e instanceof Op op ? opm.apply(r, op) : r)
283         ).findFirst().orElseThrow();
284     }
285 
286 
287     // Pattern classes
288 
289     static final class PatternState {
290         List<Value> matchedOperands;
291 
292         void addOperand(Value v) {
293             if (matchedOperands == null) {
294                 matchedOperands = new ArrayList<>();
295             }
296             matchedOperands.add(v);
297         }
298 
299         List<Value> resetOnMatch() {
300             if (matchedOperands != null) {
301                 List<Value> r = matchedOperands;
302                 matchedOperands = null;
303                 return r;
304             } else {
305                 return List.of();
306             }
307         }
308 
309         void resetOnNoMatch() {
310             if (matchedOperands != null) {
311                 matchedOperands.clear();
312             }
313         }
314     }
315 
316     /**
317      * A pattern matching against a value or operation.
318      */
319     public sealed static abstract class Pattern {
320         Pattern() {
321         }
322 
323         abstract boolean match(Value v, PatternState state);
324     }
325 
326     /**
327      * A pattern matching against an operation.
328      */
329     public static final class OpPattern extends Pattern {
330         final Predicate<Op> opTest;
331         final List<Pattern> operandPatterns;
332 
333         OpPattern(Predicate<Op> opTest, List<Pattern> operandPatterns) {
334             this.opTest = opTest;
335             this.operandPatterns = List.copyOf(operandPatterns);
336         }
337 
338         @Override
339         boolean match(Value v, PatternState state) {
340             if (v instanceof Op.Result or) {
341                 return match(or.op(), state);
342             } else {
343                 return false;
344             }
345         }
346 
347         boolean match(Op op, PatternState state) {
348             // Test does not match
349             if (!opTest.test(op)) {
350                 return false;
351             }
352 
353             if (!operandPatterns.isEmpty()) {
354                 // Arity does not match
355                 if (op.operands().size() != operandPatterns.size()) {
356                     return false;
357                 }
358 
359                 // Match all arguments
360                 for (int i = 0; i < operandPatterns.size(); i++) {
361                     Pattern p = operandPatterns.get(i);
362                     Value v = op.operands().get(i);
363 
364                     if (!p.match(v, state)) {
365                         return false;
366                     }
367                 }
368             }
369 
370             return true;
371         }
372     }
373 
374     /**
375      * A pattern that unconditionally matches a value which is captured. If the value is an operation result of an
376      * operation, then an operation pattern (if any) is further matched against the operation.
377      */
378     // @@@ type?
379     static final class ValuePattern extends Pattern {
380         final OpPattern opMatcher;
381 
382         ValuePattern() {
383             this(null);
384         }
385 
386         public ValuePattern(OpPattern opMatcher) {
387             this.opMatcher = opMatcher;
388         }
389 
390         @Override
391         boolean match(Value v, PatternState state) {
392             // Capture the operand
393             state.addOperand(v);
394 
395             // Match on operation on nested pattern, if any
396             return opMatcher == null || opMatcher.match(v, state);
397         }
398     }
399 
400     /**
401      * A pattern that conditionally matches an operation result which is captured,  then an operation pattern (if any)
402      * is further matched against the result's operation.
403      */
404     static final class OpResultPattern extends Pattern {
405         final OpPattern opMatcher;
406 
407         OpResultPattern() {
408             this(null);
409         }
410 
411         public OpResultPattern(OpPattern opMatcher) {
412             this.opMatcher = opMatcher;
413         }
414 
415         @Override
416         boolean match(Value v, PatternState state) {
417             if (!(v instanceof Op.Result)) {
418                 return false;
419             }
420 
421             // Capture the operand
422             state.addOperand(v);
423 
424             // Match on operation on nested pattern, if any
425             return opMatcher == null || opMatcher.match(v, state);
426         }
427     }
428 
429     /**
430      * A pattern that conditionally matches a block parameter which is captured.
431      */
432     static final class BlockParameterPattern extends Pattern {
433         BlockParameterPattern() {
434         }
435 
436         @Override
437         boolean match(Value v, PatternState state) {
438             if (!(v instanceof Block.Parameter)) {
439                 return false;
440             }
441 
442             // Capture the operand
443             state.addOperand(v);
444 
445             return true;
446         }
447     }
448 
449     /**
450      * A pattern matching any value or operation.
451      */
452     static final class AnyPattern extends Pattern {
453         AnyPattern() {
454         }
455 
456         @Override
457         boolean match(Value v, PatternState state) {
458             return true;
459         }
460     }
461 
462 
463     // Pattern factories
464 
465     /**
466      * Creates an operation pattern that tests against an operation by applying it to the predicate, and if
467      * {@code true}, matches operand patterns against the operation's operands (in order) .
468      * This operation pattern matches an operation if the test returns {@code true} and all operand patterns match
469      * against the operation's operands.
470      *
471      * @param opTest the predicate
472      * @param patterns the operand patterns
473      * @return the operation pattern
474      */
475     public static OpPattern opP(Predicate<Op> opTest, Pattern... patterns) {
476         return opP(opTest, List.of(patterns));
477     }
478 
479     /**
480      * Creates an operation pattern that tests against an operation by applying it to the predicate, and if
481      * {@code true}, matches operand patterns against the operation's operands (in order) .
482      * This operation pattern matches an operation if the test returns {@code true} and all operand patterns match
483      * against the operation's operands.
484      *
485      * @param opTest the predicate
486      * @param patterns the operand patterns
487      * @return the operation pattern
488      */
489     public static OpPattern opP(Predicate<Op> opTest, List<Pattern> patterns) {
490         return new OpPattern(opTest, patterns);
491     }
492 
493     /**
494      * Creates an operation pattern that tests if the operation is an instance of the class, and if
495      * {@code true}, matches operand patterns against the operation's operands (in order) .
496      * This operation pattern matches an operation if the test returns {@code true} and all operand patterns match
497      * against the operation's operands.
498      *
499      * @param opClass the operation class
500      * @param patterns the operand patterns
501      * @return the operation pattern
502      */
503     public static OpPattern opP(Class<?> opClass, Pattern... patterns) {
504         return opP(opClass::isInstance, patterns);
505     }
506 
507     /**
508      * Creates an operation pattern that tests if the operation is a {@link CoreOp.ConstantOp constant} operation
509      * and whose constant value is equal to the given value.
510      * This operation pattern matches an operation if the test returns {@code true}.
511      *
512      * @param value the value
513      * @return the operation pattern.
514      */
515     public static OpPattern constantP(Object value) {
516         return opP(op -> {
517             if (op instanceof CoreOp.ConstantOp cop) {
518                 return Objects.equals(value, cop.value());
519             }
520 
521             return false;
522         });
523     }
524 
525     /**
526      * Creates a value pattern that unconditionally matches any value and captures the value in match state.
527      *
528      * @return the value pattern.
529      */
530     public static Pattern valueP() {
531         return new ValuePattern();
532     }
533 
534     /**
535      * Creates a value pattern that unconditionally matches any value and captures the value in match state, and
536      * if the value is an operation result of an operation, then the operation pattern is matched against that
537      * operation.
538      * This value pattern matches value if value is not an operation result, or otherwise matches if the operation
539      * pattern matches.
540      *
541      * @param opMatcher the operation pattern
542      * @return the value pattern.
543      */
544     public static Pattern valueP(OpPattern opMatcher) {
545         return new ValuePattern(opMatcher);
546     }
547 
548     /**
549      * Creates an operation result pattern that conditionally matches an operation result and captures it in match state.
550      *
551      * @return the operation result.
552      */
553     public static Pattern opResultP() {
554         return new OpResultPattern();
555     }
556 
557     /**
558      * Creates an operation result pattern that conditionally matches an operation result and captures it in match state,
559      * then the operation pattern is matched against the result's operation.
560      *
561      * @param opMatcher the operation pattern
562      * @return the operation result.
563      */
564     public static Pattern opResultP(OpPattern opMatcher) {
565         return new OpResultPattern(opMatcher);
566     }
567 
568     /**
569      * Creates a block parameter result pattern that conditionally matches a block parameter and captures it in match state.
570      *
571      * @return the block parameter.
572      */
573     public static Pattern blockParameterP() {
574         return new BlockParameterPattern();
575     }
576 
577     /**
578      * Creates a pattern that unconditionally matches any value or operation.
579      *
580      * @return the value pattern.
581      */
582     public static Pattern _P() {
583         return new AnyPattern();
584     }
585 }