1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 import jdk.incubator.code.*;
 25 import jdk.incubator.code.Reflect;
 26 import jdk.incubator.code.analysis.Inliner;
 27 import jdk.incubator.code.dialect.core.CoreType;
 28 import jdk.incubator.code.dialect.java.JavaOp;
 29 import jdk.incubator.code.dialect.java.JavaOp.EnhancedForOp;
 30 import jdk.incubator.code.dialect.java.ClassType;
 31 import jdk.incubator.code.dialect.java.JavaType;
 32 import java.util.ArrayList;
 33 import java.util.List;
 34 import java.util.function.*;
 35 
 36 import static jdk.incubator.code.dialect.core.CoreOp.*;
 37 import static jdk.incubator.code.dialect.java.JavaOp.continue_;
 38 import static jdk.incubator.code.dialect.java.JavaOp.enhancedFor;
 39 import static jdk.incubator.code.dialect.java.JavaType.parameterized;
 40 import static jdk.incubator.code.dialect.java.JavaType.type;
 41 
 42 public final class StreamFuser {
 43 
 44     StreamFuser() {}
 45 
 46     public static <T> StreamExprBuilder<T> fromList(Class<T> elementClass) {
 47         JavaType elementType = type(elementClass);
 48         // java.util.List<E>
 49         JavaType listType = parameterized(type(List.class), elementType);
 50         return new StreamExprBuilder<>(listType, elementType,
 51                 (b, v) -> StreamExprBuilder.enhancedForLoop(b, elementType, v)::body);
 52     }
 53 
 54     public static class StreamExprBuilder<T> {
 55         static class StreamOp {
 56             final JavaOp.LambdaOp lambdaOp;
 57 
 58             StreamOp(Object quotedLambda) {
 59                 if (!(Op.ofQuotable(quotedLambda).get().op() instanceof JavaOp.LambdaOp lambdaOp)) {
 60                     throw new IllegalArgumentException("Quotable operation is not lambda operation");
 61                 }
 62                 if (!(Op.ofQuotable(quotedLambda).get().capturedValues().isEmpty())) {
 63                     throw new IllegalArgumentException("Quotable operation captures values");
 64                 }
 65                 this.lambdaOp = lambdaOp;
 66             }
 67 
 68             JavaOp.LambdaOp op() {
 69                 return lambdaOp;
 70             }
 71         }
 72 
 73         static class MapStreamOp extends StreamOp {
 74             public MapStreamOp(Object quotedLambda) {
 75                 super(quotedLambda);
 76             }
 77         }
 78 
 79         static class FlatMapStreamOp extends StreamOp {
 80             public FlatMapStreamOp(Object quotedLambda) {
 81                 super(quotedLambda);
 82             }
 83         }
 84 
 85         static class FilterStreamOp extends StreamOp {
 86             public FilterStreamOp(Object quotedLambda) {
 87                 super(quotedLambda);
 88             }
 89         }
 90 
 91         final JavaType sourceType;
 92         final JavaType sourceElementType;
 93         final BiFunction<Body.Builder, Value, Function<Consumer<Block.Builder>, Op>> loopSupplier;
 94         final List<StreamOp> streamOps;
 95 
 96         StreamExprBuilder(JavaType sourceType, JavaType sourceElementType,
 97                           BiFunction<Body.Builder, Value, Function<Consumer<Block.Builder>, Op>> loopSupplier) {
 98             this.sourceType = sourceType;
 99             this.sourceElementType = sourceElementType;
100             this.loopSupplier = loopSupplier;
101             this.streamOps = new ArrayList<>();
102         }
103 
104         static EnhancedForOp.BodyBuilder enhancedForLoop(Body.Builder ancestorBody, JavaType elementType,
105                                                          Value iterable) {
106             return enhancedFor(ancestorBody, iterable.type(), elementType)
107                     .expression(b -> {
108                         b.op(core_yield(iterable));
109                     })
110                     .definition(b -> {
111                         b.op(core_yield(b.parameters().get(0)));
112                     });
113         }
114 
115         @SuppressWarnings("unchecked")
116         public <R> StreamExprBuilder<R> map(Function<T, R> f) {
117             streamOps.add(new MapStreamOp(f));
118             return (StreamExprBuilder<R>) this;
119         }
120 
121         @SuppressWarnings("unchecked")
122         public <R> StreamExprBuilder<R> flatMap(Function<T, Iterable<R>> f) {
123             streamOps.add(new FlatMapStreamOp(f));
124             return (StreamExprBuilder<R>) this;
125         }
126 
127         public StreamExprBuilder<T> filter(Predicate<T> f) {
128             streamOps.add(new FilterStreamOp(f));
129             return this;
130         }
131 
132         void fuseIntermediateOperations(Block.Builder body, BiConsumer<Block.Builder, Value> terminalConsumer) {
133             fuseIntermediateOperation(0, body, body.parameters().get(0), null, terminalConsumer);
134         }
135 
136         void fuseIntermediateOperation(int i, Block.Builder body, Value element, Block.Builder continueBlock,
137                                        BiConsumer<Block.Builder, Value> terminalConsumer) {
138             if (i == streamOps.size()) {
139                 terminalConsumer.accept(body, element);
140                 return;
141             }
142 
143             StreamOp sop = streamOps.get(i);
144             if (sop instanceof MapStreamOp) {
145                 Inliner.inline(body, sop.op(), List.of(element), (block, value) -> {
146                     fuseIntermediateOperation(i + 1, block, value, continueBlock, terminalConsumer);
147                 });
148             } else if (sop instanceof FilterStreamOp) {
149                 Inliner.inline(body, sop.op(), List.of(element), (block, p) -> {
150                     Block.Builder _if = block.block();
151                     Block.Builder _else = continueBlock;
152                     if (continueBlock == null) {
153                         _else = block.block();
154                         _else.op(JavaOp.continue_());
155                     }
156 
157                     block.op(conditionalBranch(p, _if.successor(), _else.successor()));
158 
159                     fuseIntermediateOperation(i + 1, _if, element, _else, terminalConsumer);
160                 });
161             } else if (sop instanceof FlatMapStreamOp) {
162                 Inliner.inline(body, sop.op(), List.of(element), (block, iterable) -> {
163                     EnhancedForOp forOp = enhancedFor(block.parentBody(),
164                             iterable.type(), ((ClassType) iterable.type()).typeArguments().get(0))
165                             .expression(b -> {
166                                 b.op(core_yield(iterable));
167                             })
168                             .definition(b -> {
169                                 b.op(core_yield(b.parameters().get(0)));
170                             })
171                             .body(b -> {
172                                 fuseIntermediateOperation(i + 1,
173                                         b,
174                                         b.parameters().get(0),
175                                         null, terminalConsumer);
176                             });
177 
178                     block.op(forOp);
179                     block.op(JavaOp.continue_());
180                 });
181             }
182         }
183 
184         public FuncOp forEach(Consumer<T> quotableConsumer) {
185             if (!(Op.ofQuotable(quotableConsumer).get().op() instanceof JavaOp.LambdaOp consumer)) {
186                 throw new IllegalArgumentException("Quotable consumer is not lambda operation");
187             }
188             if (!(Op.ofQuotable(quotableConsumer).get().capturedValues().isEmpty())) {
189                 throw new IllegalArgumentException("Quotable consumer captures values");
190             }
191 
192             return func("fused.forEach", CoreType.functionType(JavaType.VOID, sourceType))
193                     .body(b -> {
194                         Value source = b.parameters().get(0);
195 
196                         Op sourceLoop = loopSupplier.apply(b.parentBody(), source)
197                                 .apply(loopBlock -> {
198                                     fuseIntermediateOperations(loopBlock, (terminalBlock, resultValue) -> {
199                                         Inliner.inline(terminalBlock, consumer, List.of(resultValue),
200                                                 (_, _) -> {
201                                                 });
202                                         terminalBlock.op(JavaOp.continue_());
203                                     });
204 
205                                 });
206                         b.op(sourceLoop);
207                         b.op(return_());
208                     });
209         }
210 
211         public <C> FuncOp collect(Supplier<C> quotableSupplier, BiConsumer<C, T> quotableAccumulator) {
212             if (!(Op.ofQuotable(quotableSupplier).get().op() instanceof JavaOp.LambdaOp supplier)) {
213                 throw new IllegalArgumentException("Quotable supplier is not lambda operation");
214             }
215             if (!(Op.ofQuotable(quotableSupplier).get().capturedValues().isEmpty())) {
216                 throw new IllegalArgumentException("Quotable supplier captures values");
217             }
218             if (!(Op.ofQuotable(quotableAccumulator).get().op() instanceof JavaOp.LambdaOp accumulator)) {
219                 throw new IllegalArgumentException("Quotable accumulator is not lambda operation");
220             }
221             if (!(Op.ofQuotable(quotableAccumulator).get().capturedValues().isEmpty())) {
222                 throw new IllegalArgumentException("Quotable accumulator captures values");
223             }
224 
225             JavaType collectType = (JavaType) supplier.invokableType().returnType();
226             return func("fused.collect", CoreType.functionType(collectType, sourceType))
227                     .body(b -> {
228                         Value source = b.parameters().get(0);
229 
230                         Inliner.inline(b, supplier, List.of(), (block, collect) -> {
231                             Op sourceLoop = loopSupplier.apply(block.parentBody(), source)
232                                     .apply(loopBlock -> {
233                                         fuseIntermediateOperations(loopBlock, (terminalBlock, resultValue) -> {
234                                             Inliner.inline(terminalBlock, accumulator, List.of(collect, resultValue),
235                                                     (_, _) -> {
236                                                     });
237                                             terminalBlock.op(JavaOp.continue_());
238                                         });
239                                     });
240                             block.op(sourceLoop);
241                             block.op(return_(collect));
242                         });
243                     });
244         }
245 
246     }
247 }
248