1 /*
  2  * Copyright (c) 2024, 2026, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 import jdk.incubator.code.Block;
 25 import jdk.incubator.code.CodeElement;
 26 import jdk.incubator.code.Op;
 27 import jdk.incubator.code.Value;
 28 import jdk.incubator.code.dialect.core.CoreOp;
 29 
 30 import java.util.*;
 31 import java.util.function.BiFunction;
 32 import java.util.function.Predicate;
 33 import java.util.stream.Gatherers;
 34 
 35 /**
 36  * A simple and experimental pattern match mechanism on values and operations.
 37  * <p>
 38  * When the language has support for pattern matching with matcher methods we should be able to express
 39  * matching on values and operations more powerfully and concisely.
 40  */
 41 public final class Patterns {
 42 
 43     private Patterns() {
 44     }
 45 
 46 
 47     /**
 48      * Traverses this operation and its descendant operations and returns the set of operations that are unused
 49      * (have no uses) and are pure (are instances of {@code Op.Pure} and thus have no side effects).
 50      *
 51      * @param op the operation to traverse
 52      * @return the set of used and pure operations.
 53      */
 54     public static Set<Op> matchUnusedPureOps(Op op) {
 55         return matchUnusedPureOps(op, o -> o instanceof Op.Pure);
 56     }
 57 
 58     /**
 59      * Traverses this operation and its descendant operations and returns the set of operations that are unused
 60      * (have no uses) and are pure (according to the given predicate).
 61      *
 62      * @param op       the operation to traverse
 63      * @param testPure the predicate to test if an operation is pure
 64      * @return the set of used and pure operations.
 65      */
 66     public static Set<Op> matchUnusedPureOps(Op op, Predicate<Op> testPure) {
 67         return match(
 68                 new HashSet<>(),
 69                 op, opP(o -> isDeadOp(o, testPure)),
 70                 (ms, deadOps) -> {
 71                     deadOps.add(ms.op());
 72 
 73                     // Dependent dead ops
 74                     matchDependentDeadOps(ms.op(), deadOps, testPure);
 75                     // @@@ No means to control traversal and only go deeper when
 76                     // there is only one user
 77 //                    ms.op().traverseOperands(null, (_a, arg) -> {
 78 //                        if (arg.result().users().size() == 1) {
 79 //                            deadOps.add(arg);
 80 //                        }
 81 //
 82 //                        return null;
 83 //                    });
 84 
 85                     return deadOps;
 86                 });
 87     }
 88 
 89     static boolean isDeadOp(Op op, Predicate<Op> testPure) {
 90         if (op instanceof Op.Terminating) {
 91             return false;
 92         }
 93 
 94         return op.result() != null && op.result().uses().isEmpty() && testPure.test(op);
 95     }
 96 
 97     // @@@ this could be made generic with a method traversing up the model tree,
 98     // the challenge is controlling when to keep traversing or not, and it may make
 99     // it more complex that just writing it like below for specific cases
100     // A better option may be to provide a lazy stream of the values that can be filtered
101     // similar to CodeElement::elements
102     static void matchDependentDeadOps(Op op, Set<Op> deadOps, Predicate<Op> testPure) {
103         for (Value arg : op.operands()) {
104             if (arg instanceof Op.Result or) {
105                 if (arg.uses().size() == 1 && testPure.test(or.op())) {
106                     deadOps.add(or.op());
107 
108                     // Traverse only when a single user
109                     matchDependentDeadOps(or.op(), deadOps, testPure);
110                 }
111             }
112         }
113     }
114 
115 
116     // Matching of patterns
117 
118     /**
119      * The state of a successful match of an operation with matched operands (if any)
120      */
121     public static final class MatchState {
122         final Op op;
123         final List<Value> matchedOperands;
124 
125         MatchState(Op op, List<Value> matchedOperands) {
126             this.op = op;
127             this.matchedOperands = matchedOperands;
128         }
129 
130         /**
131          * {@return the matched operation}
132          */
133         public Op op() {
134             return op;
135         }
136 
137         /**
138          * {@return the matched operands}
139          */
140         public List<Value> matchedOperands() {
141             return matchedOperands;
142         }
143     }
144 
145     record PatternAndFunction<R>(OpPattern p, BiFunction<MatchState, R, R> f) {
146     }
147 
148     // Visiting op pattern matcher
149     static class OpPatternMatcher<R> implements BiFunction<R, Op, R> {
150         final List<PatternAndFunction<R>> patterns;
151         final PatternState state;
152         final Map<MatchState, BiFunction<MatchState, R, R>> matches;
153 
154         OpPatternMatcher(OpPattern p, BiFunction<MatchState, R, R> f) {
155             this(List.of(new PatternAndFunction<>(p, f)));
156         }
157 
158         OpPatternMatcher(List<PatternAndFunction<R>> patterns) {
159             this.patterns = patterns;
160             this.state = new PatternState();
161             this.matches = new HashMap<>();
162         }
163 
164         @Override
165         public R apply(R r, Op op) {
166             for (PatternAndFunction<R> pf : patterns) {
167                 if (pf.p.match(op, state)) {
168                     MatchState ms = new MatchState(op, state.resetOnMatch());
169 
170                     r = pf.f.apply(ms, r);
171                 } else {
172                     state.resetOnNoMatch();
173                 }
174             }
175 
176             return r;
177         }
178     }
179 
180     /**
181      * A match builder for declaring matching for one or more groups of operation patterns against a given traversable
182      * and descendant operations (in order).
183      * @param <R> the match result type
184      */
185     public static final class MultiMatchBuilder<R> {
186         final CodeElement<?, ?> o;
187         final R r;
188         List<PatternAndFunction<R>> patterns;
189 
190         MultiMatchBuilder(CodeElement<?, ?> o, R r) {
191             this.o = o;
192             this.r = r;
193             this.patterns = new ArrayList<>();
194         }
195 
196         /**
197          * Declares a first of possibly other operation patterns in a group.
198          *
199          * @param p the operation pattern
200          * @return a builder to declare further patterns in the group.
201          */
202         public MultiMatchCaseBuilder pattern(OpPattern p) {
203             return new MultiMatchCaseBuilder(p);
204         }
205 
206         public R matchThenApply() {
207             OpPatternMatcher<R> opm = new OpPatternMatcher<>(patterns);
208             return o.elements().gather(Gatherers.fold(
209                     () -> r,
210                     (r, e) -> e instanceof Op op ? opm.apply(r, op) : r)
211             ).findFirst().orElseThrow();
212         }
213 
214         /**
215          * A builder to declare further operation patterns in a group or to associate a
216          * target function to be applied if any of the patterns in the group match.
217          */
218         public final class MultiMatchCaseBuilder {
219             List<OpPattern> patterns;
220 
221             MultiMatchCaseBuilder(OpPattern p) {
222                 this.patterns = new ArrayList<>();
223                 patterns.add(p);
224             }
225 
226             /**
227              * Declares an operation pattern in the group.
228              *
229              * @param p the operation pattern
230              * @return this builder.
231              */
232             public MultiMatchCaseBuilder pattern(OpPattern p) {
233                 patterns.add(p);
234                 return this;
235             }
236 
237             /**
238              * Declares the target function to be applied if any of the operation patterns on the group match.
239              *
240              * @param f the target function.
241              * @return the match builder to build further groups.
242              */
243             public MultiMatchBuilder<R> target(BiFunction<MatchState, R, R> f) {
244                 patterns.stream().map(p -> new PatternAndFunction<>(p, f)).forEach(MultiMatchBuilder.this.patterns::add);
245                 return MultiMatchBuilder.this;
246             }
247         }
248     }
249 
250     /**
251      * Constructs a match builder from which to declare matching for one or more groups of operation patterns against a
252      * given traversable and descendant operations (in order).
253      *
254      * @param r   the initial match result
255      * @param t   the traversable
256      * @param <R> the match result type
257      * @return the match builder
258      */
259     public static <R> MultiMatchBuilder<R> multiMatch(R r, CodeElement<?, ?> t) {
260         return new MultiMatchBuilder<>(t, r);
261     }
262 
263     /**
264      * Matches an operation pattern against the given traversable and descendant operations (in order).
265      *
266      * @param init      the initial match result
267      * @param t         the traversable
268      * @param opPattern the operation pattern
269      * @param matcher   the function to be applied with a match state and the current match result when an
270      *                  encountered operation matches the operation pattern
271      * @param <R>       the match result type
272      * @return the match result
273      */
274     public static <R> R match(R init, CodeElement<?, ?> t, OpPattern opPattern,
275                               BiFunction<MatchState, R, R> matcher) {
276         OpPatternMatcher<R> opm = new OpPatternMatcher<>(opPattern, matcher);
277         return t.elements().gather(Gatherers.fold(
278                 () -> init,
279                 (r, e) -> e instanceof Op op ? opm.apply(r, op) : r)
280         ).findFirst().orElseThrow();
281     }
282 
283 
284     // Pattern classes
285 
286     static final class PatternState {
287         List<Value> matchedOperands;
288 
289         void addOperand(Value v) {
290             if (matchedOperands == null) {
291                 matchedOperands = new ArrayList<>();
292             }
293             matchedOperands.add(v);
294         }
295 
296         List<Value> resetOnMatch() {
297             if (matchedOperands != null) {
298                 List<Value> r = matchedOperands;
299                 matchedOperands = null;
300                 return r;
301             } else {
302                 return List.of();
303             }
304         }
305 
306         void resetOnNoMatch() {
307             if (matchedOperands != null) {
308                 matchedOperands.clear();
309             }
310         }
311     }
312 
313     /**
314      * A pattern matching against a value or operation.
315      */
316     public sealed static abstract class Pattern {
317         Pattern() {
318         }
319 
320         abstract boolean match(Value v, PatternState state);
321     }
322 
323     /**
324      * A pattern matching against an operation.
325      */
326     public static final class OpPattern extends Pattern {
327         final Predicate<Op> opTest;
328         final List<Pattern> operandPatterns;
329 
330         OpPattern(Predicate<Op> opTest, List<Pattern> operandPatterns) {
331             this.opTest = opTest;
332             this.operandPatterns = List.copyOf(operandPatterns);
333         }
334 
335         @Override
336         boolean match(Value v, PatternState state) {
337             if (v instanceof Op.Result or) {
338                 return match(or.op(), state);
339             } else {
340                 return false;
341             }
342         }
343 
344         boolean match(Op op, PatternState state) {
345             // Test does not match
346             if (!opTest.test(op)) {
347                 return false;
348             }
349 
350             if (!operandPatterns.isEmpty()) {
351                 // Arity does not match
352                 if (op.operands().size() != operandPatterns.size()) {
353                     return false;
354                 }
355 
356                 // Match all arguments
357                 for (int i = 0; i < operandPatterns.size(); i++) {
358                     Pattern p = operandPatterns.get(i);
359                     Value v = op.operands().get(i);
360 
361                     if (!p.match(v, state)) {
362                         return false;
363                     }
364                 }
365             }
366 
367             return true;
368         }
369     }
370 
371     /**
372      * A pattern that unconditionally matches a value which is captured. If the value is an operation result of an
373      * operation, then an operation pattern (if any) is further matched against the operation.
374      */
375     // @@@ type?
376     static final class ValuePattern extends Pattern {
377         final OpPattern opMatcher;
378 
379         ValuePattern() {
380             this(null);
381         }
382 
383         public ValuePattern(OpPattern opMatcher) {
384             this.opMatcher = opMatcher;
385         }
386 
387         @Override
388         boolean match(Value v, PatternState state) {
389             // Capture the operand
390             state.addOperand(v);
391 
392             // Match on operation on nested pattern, if any
393             return opMatcher == null || opMatcher.match(v, state);
394         }
395     }
396 
397     /**
398      * A pattern that conditionally matches an operation result which is captured,  then an operation pattern (if any)
399      * is further matched against the result's operation.
400      */
401     static final class OpResultPattern extends Pattern {
402         final OpPattern opMatcher;
403 
404         OpResultPattern() {
405             this(null);
406         }
407 
408         public OpResultPattern(OpPattern opMatcher) {
409             this.opMatcher = opMatcher;
410         }
411 
412         @Override
413         boolean match(Value v, PatternState state) {
414             if (!(v instanceof Op.Result)) {
415                 return false;
416             }
417 
418             // Capture the operand
419             state.addOperand(v);
420 
421             // Match on operation on nested pattern, if any
422             return opMatcher == null || opMatcher.match(v, state);
423         }
424     }
425 
426     /**
427      * A pattern that conditionally matches a block parameter which is captured.
428      */
429     static final class BlockParameterPattern extends Pattern {
430         BlockParameterPattern() {
431         }
432 
433         @Override
434         boolean match(Value v, PatternState state) {
435             if (!(v instanceof Block.Parameter)) {
436                 return false;
437             }
438 
439             // Capture the operand
440             state.addOperand(v);
441 
442             return true;
443         }
444     }
445 
446     /**
447      * A pattern matching any value or operation.
448      */
449     static final class AnyPattern extends Pattern {
450         AnyPattern() {
451         }
452 
453         @Override
454         boolean match(Value v, PatternState state) {
455             return true;
456         }
457     }
458 
459 
460     // Pattern factories
461 
462     /**
463      * Creates an operation pattern that tests against an operation by applying it to the predicate, and if
464      * {@code true}, matches operand patterns against the operation's operands (in order) .
465      * This operation pattern matches an operation if the test returns {@code true} and all operand patterns match
466      * against the operation's operands.
467      *
468      * @param opTest the predicate
469      * @param patterns the operand patterns
470      * @return the operation pattern
471      */
472     public static OpPattern opP(Predicate<Op> opTest, Pattern... patterns) {
473         return opP(opTest, List.of(patterns));
474     }
475 
476     /**
477      * Creates an operation pattern that tests against an operation by applying it to the predicate, and if
478      * {@code true}, matches operand patterns against the operation's operands (in order) .
479      * This operation pattern matches an operation if the test returns {@code true} and all operand patterns match
480      * against the operation's operands.
481      *
482      * @param opTest the predicate
483      * @param patterns the operand patterns
484      * @return the operation pattern
485      */
486     public static OpPattern opP(Predicate<Op> opTest, List<Pattern> patterns) {
487         return new OpPattern(opTest, patterns);
488     }
489 
490     /**
491      * Creates an operation pattern that tests if the operation is an instance of the class, and if
492      * {@code true}, matches operand patterns against the operation's operands (in order) .
493      * This operation pattern matches an operation if the test returns {@code true} and all operand patterns match
494      * against the operation's operands.
495      *
496      * @param opClass the operation class
497      * @param patterns the operand patterns
498      * @return the operation pattern
499      */
500     public static OpPattern opP(Class<?> opClass, Pattern... patterns) {
501         return opP(opClass::isInstance, patterns);
502     }
503 
504     /**
505      * Creates an operation pattern that tests if the operation is a {@link CoreOp.ConstantOp constant} operation
506      * and whose constant value is equal to the given value.
507      * This operation pattern matches an operation if the test returns {@code true}.
508      *
509      * @param value the value
510      * @return the operation pattern.
511      */
512     public static OpPattern constantP(Object value) {
513         return opP(op -> {
514             if (op instanceof CoreOp.ConstantOp cop) {
515                 return Objects.equals(value, cop.value());
516             }
517 
518             return false;
519         });
520     }
521 
522     /**
523      * Creates a value pattern that unconditionally matches any value and captures the value in match state.
524      *
525      * @return the value pattern.
526      */
527     public static Pattern valueP() {
528         return new ValuePattern();
529     }
530 
531     /**
532      * Creates a value pattern that unconditionally matches any value and captures the value in match state, and
533      * if the value is an operation result of an operation, then the operation pattern is matched against that
534      * operation.
535      * This value pattern matches value if value is not an operation result, or otherwise matches if the operation
536      * pattern matches.
537      *
538      * @param opMatcher the operation pattern
539      * @return the value pattern.
540      */
541     public static Pattern valueP(OpPattern opMatcher) {
542         return new ValuePattern(opMatcher);
543     }
544 
545     /**
546      * Creates an operation result pattern that conditionally matches an operation result and captures it in match state.
547      *
548      * @return the operation result.
549      */
550     public static Pattern opResultP() {
551         return new OpResultPattern();
552     }
553 
554     /**
555      * Creates an operation result pattern that conditionally matches an operation result and captures it in match state,
556      * then the operation pattern is matched against the result's operation.
557      *
558      * @param opMatcher the operation pattern
559      * @return the operation result.
560      */
561     public static Pattern opResultP(OpPattern opMatcher) {
562         return new OpResultPattern(opMatcher);
563     }
564 
565     /**
566      * Creates a block parameter result pattern that conditionally matches a block parameter and captures it in match state.
567      *
568      * @return the block parameter.
569      */
570     public static Pattern blockParameterP() {
571         return new BlockParameterPattern();
572     }
573 
574     /**
575      * Creates a pattern that unconditionally matches any value or operation.
576      *
577      * @return the value pattern.
578      */
579     public static Pattern _P() {
580         return new AnyPattern();
581     }
582 }