1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 import jdk.incubator.code.*;
 25 import jdk.incubator.code.op.ExtendedOp.JavaEnhancedForOp;
 26 import jdk.incubator.code.type.ClassType;
 27 import jdk.incubator.code.type.FunctionType;
 28 import jdk.incubator.code.type.JavaType;
 29 import java.util.ArrayList;
 30 import java.util.List;
 31 import java.util.function.*;
 32 
 33 import static jdk.incubator.code.op.CoreOp.*;
 34 import static jdk.incubator.code.op.ExtendedOp._continue;
 35 import static jdk.incubator.code.op.ExtendedOp.enhancedFor;
 36 import static jdk.incubator.code.type.JavaType.parameterized;
 37 import static jdk.incubator.code.type.JavaType.type;
 38 
 39 public final class StreamFuserUsingQuotable {
 40 
 41     // Quotable functional interfaces
 42 
 43     public interface QuotablePredicate<T> extends Quotable, Predicate<T> {
 44     }
 45 
 46     public interface QuotableFunction<T, R> extends Quotable, Function<T, R> {
 47     }
 48 
 49     public interface QuotableSupplier<T> extends Quotable, Supplier<T> {
 50     }
 51 
 52     public interface QuotableConsumer<T> extends Quotable, Consumer<T> {
 53     }
 54 
 55     public interface QuotableBiConsumer<T, U> extends Quotable, BiConsumer<T, U> {
 56     }
 57 
 58 
 59     StreamFuserUsingQuotable() {}
 60 
 61     public static <T> StreamExprBuilder<T> fromList(Class<T> elementClass) {
 62         JavaType elementType = type(elementClass);
 63         // java.util.List<E>
 64         JavaType listType = parameterized(type(List.class), elementType);
 65         return new StreamExprBuilder<>(listType, elementType,
 66                 (b, v) -> StreamExprBuilder.enhancedForLoop(b, elementType, v)::body);
 67     }
 68 
 69     public static class StreamExprBuilder<T> {
 70         static class StreamOp {
 71             final LambdaOp lambdaOp;
 72 
 73             StreamOp(Quotable quotedLambda) {
 74                 if (!(quotedLambda.quoted().op() instanceof LambdaOp lambdaOp)) {
 75                     throw new IllegalArgumentException("Quotable operation is not lambda operation");
 76                 }
 77                 if (!(quotedLambda.quoted().capturedValues().isEmpty())) {
 78                     throw new IllegalArgumentException("Quotable operation captures values");
 79                 }
 80                 this.lambdaOp = lambdaOp;
 81             }
 82 
 83             LambdaOp op() {
 84                 return lambdaOp;
 85             }
 86         }
 87 
 88         static class MapStreamOp extends StreamOp {
 89             public MapStreamOp(Quotable quotedLambda) {
 90                 super(quotedLambda);
 91             }
 92         }
 93 
 94         static class FlatMapStreamOp extends StreamOp {
 95             public FlatMapStreamOp(Quotable quotedLambda) {
 96                 super(quotedLambda);
 97             }
 98         }
 99 
100         static class FilterStreamOp extends StreamOp {
101             public FilterStreamOp(Quotable quotedLambda) {
102                 super(quotedLambda);
103             }
104         }
105 
106         final JavaType sourceType;
107         final JavaType sourceElementType;
108         final BiFunction<Body.Builder, Value, Function<Consumer<Block.Builder>, Op>> loopSupplier;
109         final List<StreamOp> streamOps;
110 
111         StreamExprBuilder(JavaType sourceType, JavaType sourceElementType,
112                           BiFunction<Body.Builder, Value, Function<Consumer<Block.Builder>, Op>> loopSupplier) {
113             this.sourceType = sourceType;
114             this.sourceElementType = sourceElementType;
115             this.loopSupplier = loopSupplier;
116             this.streamOps = new ArrayList<>();
117         }
118 
119         static JavaEnhancedForOp.BodyBuilder enhancedForLoop(Body.Builder ancestorBody, JavaType elementType,
120                                                              Value iterable) {
121             return enhancedFor(ancestorBody, iterable.type(), elementType)
122                     .expression(b -> {
123                         b.op(_yield(iterable));
124                     })
125                     .definition(b -> {
126                         b.op(_yield(b.parameters().get(0)));
127                     });
128         }
129 
130         @SuppressWarnings("unchecked")
131         public <R> StreamExprBuilder<R> map(QuotableFunction<T, R> f) {
132             streamOps.add(new MapStreamOp(f));
133             return (StreamExprBuilder<R>) this;
134         }
135 
136         @SuppressWarnings("unchecked")
137         public <R> StreamExprBuilder<R> flatMap(QuotableFunction<T, Iterable<R>> f) {
138             streamOps.add(new FlatMapStreamOp(f));
139             return (StreamExprBuilder<R>) this;
140         }
141 
142         public StreamExprBuilder<T> filter(QuotablePredicate<T> f) {
143             streamOps.add(new FilterStreamOp(f));
144             return this;
145         }
146 
147         void fuseIntermediateOperations(Block.Builder body, BiConsumer<Block.Builder, Value> terminalConsumer) {
148             fuseIntermediateOperation(0, body, body.parameters().get(0), null, terminalConsumer);
149         }
150 
151         void fuseIntermediateOperation(int i, Block.Builder body, Value element, Block.Builder continueBlock,
152                                        BiConsumer<Block.Builder, Value> terminalConsumer) {
153             if (i == streamOps.size()) {
154                 terminalConsumer.accept(body, element);
155                 return;
156             }
157 
158             StreamOp sop = streamOps.get(i);
159             if (sop instanceof MapStreamOp) {
160                 body.inline(sop.op(), List.of(element), (block, value) -> {
161                     fuseIntermediateOperation(i + 1, block, value, continueBlock, terminalConsumer);
162                 });
163             } else if (sop instanceof FilterStreamOp) {
164                 body.inline(sop.op(), List.of(element), (block, p) -> {
165                     Block.Builder _if = block.block();
166                     Block.Builder _else = continueBlock;
167                     if (continueBlock == null) {
168                         _else = block.block();
169                         _else.op(_continue());
170                     }
171 
172                     block.op(conditionalBranch(p, _if.successor(), _else.successor()));
173 
174                     fuseIntermediateOperation(i + 1, _if, element, _else, terminalConsumer);
175                 });
176             } else if (sop instanceof FlatMapStreamOp) {
177                 body.inline(sop.op(), List.of(element), (block, iterable) -> {
178                     JavaEnhancedForOp forOp = enhancedFor(block.parentBody(),
179                             iterable.type(), ((ClassType) iterable.type()).typeArguments().get(0))
180                             .expression(b -> {
181                                 b.op(_yield(iterable));
182                             })
183                             .definition(b -> {
184                                 b.op(_yield(b.parameters().get(0)));
185                             })
186                             .body(b -> {
187                                 fuseIntermediateOperation(i + 1,
188                                         b,
189                                         b.parameters().get(0),
190                                         null, terminalConsumer);
191                             });
192 
193                     block.op(forOp);
194                     block.op(_continue());
195                 });
196             }
197         }
198 
199         public FuncOp forEach(QuotableConsumer<T> quotableConsumer) {
200             if (!(quotableConsumer.quoted().op() instanceof LambdaOp consumer)) {
201                 throw new IllegalArgumentException("Quotable consumer is not lambda operation");
202             }
203             if (!(quotableConsumer.quoted().capturedValues().isEmpty())) {
204                 throw new IllegalArgumentException("Quotable consumer captures values");
205             }
206 
207             return func("fused.forEach", FunctionType.functionType(JavaType.VOID, sourceType))
208                     .body(b -> {
209                         Value source = b.parameters().get(0);
210 
211                         Op sourceLoop = loopSupplier.apply(b.parentBody(), source)
212                                 .apply(loopBlock -> {
213                                     fuseIntermediateOperations(loopBlock, (terminalBlock, resultValue) -> {
214                                         terminalBlock.inline(consumer, List.of(resultValue),
215                                                 (_, _) -> {
216                                                 });
217                                         terminalBlock.op(_continue());
218                                     });
219 
220                                 });
221                         b.op(sourceLoop);
222                         b.op(_return());
223                     });
224         }
225 
226         public <C> FuncOp collect(QuotableSupplier<C> quotableSupplier, QuotableBiConsumer<C, T> quotableAccumulator) {
227             if (!(quotableSupplier.quoted().op() instanceof LambdaOp supplier)) {
228                 throw new IllegalArgumentException("Quotable supplier is not lambda operation");
229             }
230             if (!(quotableSupplier.quoted().capturedValues().isEmpty())) {
231                 throw new IllegalArgumentException("Quotable supplier captures values");
232             }
233             if (!(quotableAccumulator.quoted().op() instanceof LambdaOp accumulator)) {
234                 throw new IllegalArgumentException("Quotable accumulator is not lambda operation");
235             }
236             if (!(quotableAccumulator.quoted().capturedValues().isEmpty())) {
237                 throw new IllegalArgumentException("Quotable accumulator captures values");
238             }
239 
240             JavaType collectType = (JavaType) supplier.invokableType().returnType();
241             return func("fused.collect", FunctionType.functionType(collectType, sourceType))
242                     .body(b -> {
243                         Value source = b.parameters().get(0);
244 
245                         b.inline(supplier, List.of(), (block, collect) -> {
246                             Op sourceLoop = loopSupplier.apply(block.parentBody(), source)
247                                     .apply(loopBlock -> {
248                                         fuseIntermediateOperations(loopBlock, (terminalBlock, resultValue) -> {
249                                             terminalBlock.inline(accumulator, List.of(collect, resultValue),
250                                                     (_, _) -> {
251                                                     });
252                                             terminalBlock.op(_continue());
253                                         });
254                                     });
255                             block.op(sourceLoop);
256                             block.op(_return(collect));
257                         });
258                     });
259         }
260 
261     }
262 }
263