1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 import jdk.incubator.code.*;
 25 import jdk.incubator.code.analysis.Inliner;
 26 import jdk.incubator.code.dialect.core.CoreType;
 27 import jdk.incubator.code.dialect.java.JavaOp;
 28 import jdk.incubator.code.dialect.java.JavaOp.EnhancedForOp;
 29 import jdk.incubator.code.dialect.java.ClassType;
 30 import jdk.incubator.code.dialect.java.JavaType;
 31 import java.util.ArrayList;
 32 import java.util.List;
 33 import java.util.function.*;
 34 
 35 import static jdk.incubator.code.dialect.core.CoreOp.*;
 36 import static jdk.incubator.code.dialect.java.JavaOp.continue_;
 37 import static jdk.incubator.code.dialect.java.JavaOp.enhancedFor;
 38 import static jdk.incubator.code.dialect.java.JavaType.parameterized;
 39 import static jdk.incubator.code.dialect.java.JavaType.type;
 40 
 41 public final class StreamFuserUsingQuotable {
 42 
 43     // Quotable functional interfaces
 44 
 45     public interface QuotablePredicate<T> extends Quotable, Predicate<T> {
 46     }
 47 
 48     public interface QuotableFunction<T, R> extends Quotable, Function<T, R> {
 49     }
 50 
 51     public interface QuotableSupplier<T> extends Quotable, Supplier<T> {
 52     }
 53 
 54     public interface QuotableConsumer<T> extends Quotable, Consumer<T> {
 55     }
 56 
 57     public interface QuotableBiConsumer<T, U> extends Quotable, BiConsumer<T, U> {
 58     }
 59 
 60 
 61     StreamFuserUsingQuotable() {}
 62 
 63     public static <T> StreamExprBuilder<T> fromList(Class<T> elementClass) {
 64         JavaType elementType = type(elementClass);
 65         // java.util.List<E>
 66         JavaType listType = parameterized(type(List.class), elementType);
 67         return new StreamExprBuilder<>(listType, elementType,
 68                 (b, v) -> StreamExprBuilder.enhancedForLoop(b, elementType, v)::body);
 69     }
 70 
 71     public static class StreamExprBuilder<T> {
 72         static class StreamOp {
 73             final JavaOp.LambdaOp lambdaOp;
 74 
 75             StreamOp(Quotable quotedLambda) {
 76                 if (!(Op.ofQuotable(quotedLambda).get().op() instanceof JavaOp.LambdaOp lambdaOp)) {
 77                     throw new IllegalArgumentException("Quotable operation is not lambda operation");
 78                 }
 79                 if (!(Op.ofQuotable(quotedLambda).get().capturedValues().isEmpty())) {
 80                     throw new IllegalArgumentException("Quotable operation captures values");
 81                 }
 82                 this.lambdaOp = lambdaOp;
 83             }
 84 
 85             JavaOp.LambdaOp op() {
 86                 return lambdaOp;
 87             }
 88         }
 89 
 90         static class MapStreamOp extends StreamOp {
 91             public MapStreamOp(Quotable quotedLambda) {
 92                 super(quotedLambda);
 93             }
 94         }
 95 
 96         static class FlatMapStreamOp extends StreamOp {
 97             public FlatMapStreamOp(Quotable quotedLambda) {
 98                 super(quotedLambda);
 99             }
100         }
101 
102         static class FilterStreamOp extends StreamOp {
103             public FilterStreamOp(Quotable quotedLambda) {
104                 super(quotedLambda);
105             }
106         }
107 
108         final JavaType sourceType;
109         final JavaType sourceElementType;
110         final BiFunction<Body.Builder, Value, Function<Consumer<Block.Builder>, Op>> loopSupplier;
111         final List<StreamOp> streamOps;
112 
113         StreamExprBuilder(JavaType sourceType, JavaType sourceElementType,
114                           BiFunction<Body.Builder, Value, Function<Consumer<Block.Builder>, Op>> loopSupplier) {
115             this.sourceType = sourceType;
116             this.sourceElementType = sourceElementType;
117             this.loopSupplier = loopSupplier;
118             this.streamOps = new ArrayList<>();
119         }
120 
121         static EnhancedForOp.BodyBuilder enhancedForLoop(Body.Builder ancestorBody, JavaType elementType,
122                                                          Value iterable) {
123             return enhancedFor(ancestorBody, iterable.type(), elementType)
124                     .expression(b -> {
125                         b.op(core_yield(iterable));
126                     })
127                     .definition(b -> {
128                         b.op(core_yield(b.parameters().get(0)));
129                     });
130         }
131 
132         @SuppressWarnings("unchecked")
133         public <R> StreamExprBuilder<R> map(QuotableFunction<T, R> f) {
134             streamOps.add(new MapStreamOp(f));
135             return (StreamExprBuilder<R>) this;
136         }
137 
138         @SuppressWarnings("unchecked")
139         public <R> StreamExprBuilder<R> flatMap(QuotableFunction<T, Iterable<R>> f) {
140             streamOps.add(new FlatMapStreamOp(f));
141             return (StreamExprBuilder<R>) this;
142         }
143 
144         public StreamExprBuilder<T> filter(QuotablePredicate<T> f) {
145             streamOps.add(new FilterStreamOp(f));
146             return this;
147         }
148 
149         void fuseIntermediateOperations(Block.Builder body, BiConsumer<Block.Builder, Value> terminalConsumer) {
150             fuseIntermediateOperation(0, body, body.parameters().get(0), null, terminalConsumer);
151         }
152 
153         void fuseIntermediateOperation(int i, Block.Builder body, Value element, Block.Builder continueBlock,
154                                        BiConsumer<Block.Builder, Value> terminalConsumer) {
155             if (i == streamOps.size()) {
156                 terminalConsumer.accept(body, element);
157                 return;
158             }
159 
160             StreamOp sop = streamOps.get(i);
161             if (sop instanceof MapStreamOp) {
162                 Inliner.inline(body, sop.op(), List.of(element), (block, value) -> {
163                     fuseIntermediateOperation(i + 1, block, value, continueBlock, terminalConsumer);
164                 });
165             } else if (sop instanceof FilterStreamOp) {
166                 Inliner.inline(body, sop.op(), List.of(element), (block, p) -> {
167                     Block.Builder _if = block.block();
168                     Block.Builder _else = continueBlock;
169                     if (continueBlock == null) {
170                         _else = block.block();
171                         _else.op(JavaOp.continue_());
172                     }
173 
174                     block.op(conditionalBranch(p, _if.successor(), _else.successor()));
175 
176                     fuseIntermediateOperation(i + 1, _if, element, _else, terminalConsumer);
177                 });
178             } else if (sop instanceof FlatMapStreamOp) {
179                 Inliner.inline(body, sop.op(), List.of(element), (block, iterable) -> {
180                     EnhancedForOp forOp = enhancedFor(block.parentBody(),
181                             iterable.type(), ((ClassType) iterable.type()).typeArguments().get(0))
182                             .expression(b -> {
183                                 b.op(core_yield(iterable));
184                             })
185                             .definition(b -> {
186                                 b.op(core_yield(b.parameters().get(0)));
187                             })
188                             .body(b -> {
189                                 fuseIntermediateOperation(i + 1,
190                                         b,
191                                         b.parameters().get(0),
192                                         null, terminalConsumer);
193                             });
194 
195                     block.op(forOp);
196                     block.op(JavaOp.continue_());
197                 });
198             }
199         }
200 
201         public FuncOp forEach(QuotableConsumer<T> quotableConsumer) {
202             if (!(Op.ofQuotable(quotableConsumer).get().op() instanceof JavaOp.LambdaOp consumer)) {
203                 throw new IllegalArgumentException("Quotable consumer is not lambda operation");
204             }
205             if (!(Op.ofQuotable(quotableConsumer).get().capturedValues().isEmpty())) {
206                 throw new IllegalArgumentException("Quotable consumer captures values");
207             }
208 
209             return func("fused.forEach", CoreType.functionType(JavaType.VOID, sourceType))
210                     .body(b -> {
211                         Value source = b.parameters().get(0);
212 
213                         Op sourceLoop = loopSupplier.apply(b.parentBody(), source)
214                                 .apply(loopBlock -> {
215                                     fuseIntermediateOperations(loopBlock, (terminalBlock, resultValue) -> {
216                                         Inliner.inline(terminalBlock, consumer, List.of(resultValue),
217                                                 (_, _) -> {
218                                                 });
219                                         terminalBlock.op(JavaOp.continue_());
220                                     });
221 
222                                 });
223                         b.op(sourceLoop);
224                         b.op(return_());
225                     });
226         }
227 
228         public <C> FuncOp collect(QuotableSupplier<C> quotableSupplier, QuotableBiConsumer<C, T> quotableAccumulator) {
229             if (!(Op.ofQuotable(quotableSupplier).get().op() instanceof JavaOp.LambdaOp supplier)) {
230                 throw new IllegalArgumentException("Quotable supplier is not lambda operation");
231             }
232             if (!(Op.ofQuotable(quotableSupplier).get().capturedValues().isEmpty())) {
233                 throw new IllegalArgumentException("Quotable supplier captures values");
234             }
235             if (!(Op.ofQuotable(quotableAccumulator).get().op() instanceof JavaOp.LambdaOp accumulator)) {
236                 throw new IllegalArgumentException("Quotable accumulator is not lambda operation");
237             }
238             if (!(Op.ofQuotable(quotableAccumulator).get().capturedValues().isEmpty())) {
239                 throw new IllegalArgumentException("Quotable accumulator captures values");
240             }
241 
242             JavaType collectType = (JavaType) supplier.invokableType().returnType();
243             return func("fused.collect", CoreType.functionType(collectType, sourceType))
244                     .body(b -> {
245                         Value source = b.parameters().get(0);
246 
247                         Inliner.inline(b, supplier, List.of(), (block, collect) -> {
248                             Op sourceLoop = loopSupplier.apply(block.parentBody(), source)
249                                     .apply(loopBlock -> {
250                                         fuseIntermediateOperations(loopBlock, (terminalBlock, resultValue) -> {
251                                             Inliner.inline(terminalBlock, accumulator, List.of(collect, resultValue),
252                                                     (_, _) -> {
253                                                     });
254                                             terminalBlock.op(JavaOp.continue_());
255                                         });
256                                     });
257                             block.op(sourceLoop);
258                             block.op(return_(collect));
259                         });
260                     });
261         }
262 
263     }
264 }
265