1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 import jdk.incubator.code.*;
 25 import jdk.incubator.code.dialect.core.CoreType;
 26 import jdk.incubator.code.dialect.java.JavaOp;
 27 import jdk.incubator.code.dialect.java.JavaOp.JavaEnhancedForOp;
 28 import jdk.incubator.code.dialect.java.ClassType;
 29 import jdk.incubator.code.dialect.java.JavaType;
 30 import java.util.ArrayList;
 31 import java.util.List;
 32 import java.util.function.*;
 33 
 34 import static jdk.incubator.code.dialect.core.CoreOp.*;
 35 import static jdk.incubator.code.dialect.java.JavaOp._continue;
 36 import static jdk.incubator.code.dialect.java.JavaOp.enhancedFor;
 37 import static jdk.incubator.code.dialect.java.JavaType.parameterized;
 38 import static jdk.incubator.code.dialect.java.JavaType.type;
 39 
 40 public final class StreamFuserUsingQuotable {
 41 
 42     // Quotable functional interfaces
 43 
 44     public interface QuotablePredicate<T> extends Quotable, Predicate<T> {
 45     }
 46 
 47     public interface QuotableFunction<T, R> extends Quotable, Function<T, R> {
 48     }
 49 
 50     public interface QuotableSupplier<T> extends Quotable, Supplier<T> {
 51     }
 52 
 53     public interface QuotableConsumer<T> extends Quotable, Consumer<T> {
 54     }
 55 
 56     public interface QuotableBiConsumer<T, U> extends Quotable, BiConsumer<T, U> {
 57     }
 58 
 59 
 60     StreamFuserUsingQuotable() {}
 61 
 62     public static <T> StreamExprBuilder<T> fromList(Class<T> elementClass) {
 63         JavaType elementType = type(elementClass);
 64         // java.util.List<E>
 65         JavaType listType = parameterized(type(List.class), elementType);
 66         return new StreamExprBuilder<>(listType, elementType,
 67                 (b, v) -> StreamExprBuilder.enhancedForLoop(b, elementType, v)::body);
 68     }
 69 
 70     public static class StreamExprBuilder<T> {
 71         static class StreamOp {
 72             final JavaOp.LambdaOp lambdaOp;
 73 
 74             StreamOp(Quotable quotedLambda) {
 75                 if (!(Op.ofQuotable(quotedLambda).get().op() instanceof JavaOp.LambdaOp lambdaOp)) {
 76                     throw new IllegalArgumentException("Quotable operation is not lambda operation");
 77                 }
 78                 if (!(Op.ofQuotable(quotedLambda).get().capturedValues().isEmpty())) {
 79                     throw new IllegalArgumentException("Quotable operation captures values");
 80                 }
 81                 this.lambdaOp = lambdaOp;
 82             }
 83 
 84             JavaOp.LambdaOp op() {
 85                 return lambdaOp;
 86             }
 87         }
 88 
 89         static class MapStreamOp extends StreamOp {
 90             public MapStreamOp(Quotable quotedLambda) {
 91                 super(quotedLambda);
 92             }
 93         }
 94 
 95         static class FlatMapStreamOp extends StreamOp {
 96             public FlatMapStreamOp(Quotable quotedLambda) {
 97                 super(quotedLambda);
 98             }
 99         }
100 
101         static class FilterStreamOp extends StreamOp {
102             public FilterStreamOp(Quotable quotedLambda) {
103                 super(quotedLambda);
104             }
105         }
106 
107         final JavaType sourceType;
108         final JavaType sourceElementType;
109         final BiFunction<Body.Builder, Value, Function<Consumer<Block.Builder>, Op>> loopSupplier;
110         final List<StreamOp> streamOps;
111 
112         StreamExprBuilder(JavaType sourceType, JavaType sourceElementType,
113                           BiFunction<Body.Builder, Value, Function<Consumer<Block.Builder>, Op>> loopSupplier) {
114             this.sourceType = sourceType;
115             this.sourceElementType = sourceElementType;
116             this.loopSupplier = loopSupplier;
117             this.streamOps = new ArrayList<>();
118         }
119 
120         static JavaEnhancedForOp.BodyBuilder enhancedForLoop(Body.Builder ancestorBody, JavaType elementType,
121                                                              Value iterable) {
122             return enhancedFor(ancestorBody, iterable.type(), elementType)
123                     .expression(b -> {
124                         b.op(_yield(iterable));
125                     })
126                     .definition(b -> {
127                         b.op(_yield(b.parameters().get(0)));
128                     });
129         }
130 
131         @SuppressWarnings("unchecked")
132         public <R> StreamExprBuilder<R> map(QuotableFunction<T, R> f) {
133             streamOps.add(new MapStreamOp(f));
134             return (StreamExprBuilder<R>) this;
135         }
136 
137         @SuppressWarnings("unchecked")
138         public <R> StreamExprBuilder<R> flatMap(QuotableFunction<T, Iterable<R>> f) {
139             streamOps.add(new FlatMapStreamOp(f));
140             return (StreamExprBuilder<R>) this;
141         }
142 
143         public StreamExprBuilder<T> filter(QuotablePredicate<T> f) {
144             streamOps.add(new FilterStreamOp(f));
145             return this;
146         }
147 
148         void fuseIntermediateOperations(Block.Builder body, BiConsumer<Block.Builder, Value> terminalConsumer) {
149             fuseIntermediateOperation(0, body, body.parameters().get(0), null, terminalConsumer);
150         }
151 
152         void fuseIntermediateOperation(int i, Block.Builder body, Value element, Block.Builder continueBlock,
153                                        BiConsumer<Block.Builder, Value> terminalConsumer) {
154             if (i == streamOps.size()) {
155                 terminalConsumer.accept(body, element);
156                 return;
157             }
158 
159             StreamOp sop = streamOps.get(i);
160             if (sop instanceof MapStreamOp) {
161                 body.inline(sop.op(), List.of(element), (block, value) -> {
162                     fuseIntermediateOperation(i + 1, block, value, continueBlock, terminalConsumer);
163                 });
164             } else if (sop instanceof FilterStreamOp) {
165                 body.inline(sop.op(), List.of(element), (block, p) -> {
166                     Block.Builder _if = block.block();
167                     Block.Builder _else = continueBlock;
168                     if (continueBlock == null) {
169                         _else = block.block();
170                         _else.op(_continue());
171                     }
172 
173                     block.op(conditionalBranch(p, _if.successor(), _else.successor()));
174 
175                     fuseIntermediateOperation(i + 1, _if, element, _else, terminalConsumer);
176                 });
177             } else if (sop instanceof FlatMapStreamOp) {
178                 body.inline(sop.op(), List.of(element), (block, iterable) -> {
179                     JavaEnhancedForOp forOp = enhancedFor(block.parentBody(),
180                             iterable.type(), ((ClassType) iterable.type()).typeArguments().get(0))
181                             .expression(b -> {
182                                 b.op(_yield(iterable));
183                             })
184                             .definition(b -> {
185                                 b.op(_yield(b.parameters().get(0)));
186                             })
187                             .body(b -> {
188                                 fuseIntermediateOperation(i + 1,
189                                         b,
190                                         b.parameters().get(0),
191                                         null, terminalConsumer);
192                             });
193 
194                     block.op(forOp);
195                     block.op(_continue());
196                 });
197             }
198         }
199 
200         public FuncOp forEach(QuotableConsumer<T> quotableConsumer) {
201             if (!(Op.ofQuotable(quotableConsumer).get().op() instanceof JavaOp.LambdaOp consumer)) {
202                 throw new IllegalArgumentException("Quotable consumer is not lambda operation");
203             }
204             if (!(Op.ofQuotable(quotableConsumer).get().capturedValues().isEmpty())) {
205                 throw new IllegalArgumentException("Quotable consumer captures values");
206             }
207 
208             return func("fused.forEach", CoreType.functionType(JavaType.VOID, sourceType))
209                     .body(b -> {
210                         Value source = b.parameters().get(0);
211 
212                         Op sourceLoop = loopSupplier.apply(b.parentBody(), source)
213                                 .apply(loopBlock -> {
214                                     fuseIntermediateOperations(loopBlock, (terminalBlock, resultValue) -> {
215                                         terminalBlock.inline(consumer, List.of(resultValue),
216                                                 (_, _) -> {
217                                                 });
218                                         terminalBlock.op(_continue());
219                                     });
220 
221                                 });
222                         b.op(sourceLoop);
223                         b.op(_return());
224                     });
225         }
226 
227         public <C> FuncOp collect(QuotableSupplier<C> quotableSupplier, QuotableBiConsumer<C, T> quotableAccumulator) {
228             if (!(Op.ofQuotable(quotableSupplier).get().op() instanceof JavaOp.LambdaOp supplier)) {
229                 throw new IllegalArgumentException("Quotable supplier is not lambda operation");
230             }
231             if (!(Op.ofQuotable(quotableSupplier).get().capturedValues().isEmpty())) {
232                 throw new IllegalArgumentException("Quotable supplier captures values");
233             }
234             if (!(Op.ofQuotable(quotableAccumulator).get().op() instanceof JavaOp.LambdaOp accumulator)) {
235                 throw new IllegalArgumentException("Quotable accumulator is not lambda operation");
236             }
237             if (!(Op.ofQuotable(quotableAccumulator).get().capturedValues().isEmpty())) {
238                 throw new IllegalArgumentException("Quotable accumulator captures values");
239             }
240 
241             JavaType collectType = (JavaType) supplier.invokableType().returnType();
242             return func("fused.collect", CoreType.functionType(collectType, sourceType))
243                     .body(b -> {
244                         Value source = b.parameters().get(0);
245 
246                         b.inline(supplier, List.of(), (block, collect) -> {
247                             Op sourceLoop = loopSupplier.apply(block.parentBody(), source)
248                                     .apply(loopBlock -> {
249                                         fuseIntermediateOperations(loopBlock, (terminalBlock, resultValue) -> {
250                                             terminalBlock.inline(accumulator, List.of(collect, resultValue),
251                                                     (_, _) -> {
252                                                     });
253                                             terminalBlock.op(_continue());
254                                         });
255                                     });
256                             block.op(sourceLoop);
257                             block.op(_return(collect));
258                         });
259                     });
260         }
261 
262     }
263 }
264