1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 import jdk.incubator.code.*;
 25 import jdk.incubator.code.op.CoreOp;
 26 import jdk.incubator.code.op.ExtendedOp.JavaEnhancedForOp;
 27 import jdk.incubator.code.type.ClassType;
 28 import jdk.incubator.code.type.FunctionType;
 29 import jdk.incubator.code.type.JavaType;
 30 import java.util.ArrayList;
 31 import java.util.List;
 32 import java.util.function.BiConsumer;
 33 import java.util.function.BiFunction;
 34 import java.util.function.Consumer;
 35 import java.util.function.Function;
 36 
 37 import static jdk.incubator.code.op.CoreOp.*;
 38 import static jdk.incubator.code.op.ExtendedOp._continue;
 39 import static jdk.incubator.code.op.ExtendedOp.enhancedFor;
 40 import static jdk.incubator.code.type.JavaType.parameterized;
 41 import static jdk.incubator.code.type.JavaType.type;
 42 
 43 public final class StreamFuser {
 44 
 45     StreamFuser() {}
 46 
 47     public static StreamExprBuilder fromList(JavaType elementType) {
 48         // java.util.List<E>
 49         JavaType listType = parameterized(type(List.class), elementType);
 50         return new StreamExprBuilder(listType, elementType,
 51                 (b, v) -> StreamExprBuilder.enhancedForLoop(b, elementType, v)::body);
 52     }
 53 
 54     public static class StreamExprBuilder {
 55         static class StreamOp {
 56             final Quoted quotedClosure;
 57 
 58             StreamOp(Quoted quotedClosure) {
 59                 if (!(quotedClosure.op() instanceof CoreOp.ClosureOp)) {
 60                     throw new IllegalArgumentException("Quoted operation is not closure operation");
 61                 }
 62                 this.quotedClosure = quotedClosure;
 63             }
 64 
 65             CoreOp.ClosureOp op() {
 66                 return (CoreOp.ClosureOp) quotedClosure.op();
 67             }
 68         }
 69 
 70         static class MapStreamOp extends StreamOp {
 71             public MapStreamOp(Quoted quotedClosure) {
 72                 super(quotedClosure);
 73                 // @@@ Check closure signature
 74             }
 75         }
 76 
 77         static class FlatMapStreamOp extends StreamOp {
 78             public FlatMapStreamOp(Quoted quotedClosure) {
 79                 super(quotedClosure);
 80                 // @@@ Check closure signature
 81             }
 82         }
 83 
 84         static class FilterStreamOp extends StreamOp {
 85             public FilterStreamOp(Quoted quotedClosure) {
 86                 super(quotedClosure);
 87                 // @@@ Check closure signature
 88             }
 89         }
 90 
 91         final JavaType sourceType;
 92         final JavaType sourceElementType;
 93         final BiFunction<Body.Builder, Value, Function<Consumer<Block.Builder>, Op>> loopSupplier;
 94         final List<StreamOp> streamOps;
 95 
 96         StreamExprBuilder(JavaType sourceType, JavaType sourceElementType,
 97                           BiFunction<Body.Builder, Value, Function<Consumer<Block.Builder>, Op>> loopSupplier) {
 98             this.sourceType = sourceType;
 99             this.sourceElementType = sourceElementType;
100             this.loopSupplier = loopSupplier;
101             this.streamOps = new ArrayList<>();
102         }
103 
104         static JavaEnhancedForOp.BodyBuilder enhancedForLoop(Body.Builder ancestorBody, JavaType elementType,
105                                                              Value iterable) {
106             return enhancedFor(ancestorBody, iterable.type(), elementType)
107                     .expression(b -> {
108                         b.op(_yield(iterable));
109                     })
110                     .definition(b -> {
111                         b.op(_yield(b.parameters().get(0)));
112                     });
113         }
114 
115         public StreamExprBuilder map(Quoted f) {
116             streamOps.add(new MapStreamOp(f));
117             return this;
118         }
119 
120         public StreamExprBuilder flatMap(Quoted f) {
121             streamOps.add(new FlatMapStreamOp(f));
122             return this;
123         }
124 
125         public StreamExprBuilder filter(Quoted f) {
126             streamOps.add(new FilterStreamOp(f));
127             return this;
128         }
129 
130         void fuseIntermediateOperations(Block.Builder body, BiConsumer<Block.Builder, Value> terminalConsumer) {
131             fuseIntermediateOperation(0, body, body.parameters().get(0), null, terminalConsumer);
132         }
133 
134         void fuseIntermediateOperation(int i, Block.Builder body, Value element, Block.Builder continueBlock,
135                                        BiConsumer<Block.Builder, Value> terminalConsumer) {
136             if (i == streamOps.size()) {
137                 terminalConsumer.accept(body, element);
138                 return;
139             }
140 
141             StreamOp sop = streamOps.get(i);
142             if (sop instanceof MapStreamOp) {
143                 body.inline(sop.op(), List.of(element), (block, value) -> {
144                     fuseIntermediateOperation(i + 1, block, value, continueBlock, terminalConsumer);
145                 });
146             } else if (sop instanceof FilterStreamOp) {
147                 body.inline(sop.op(), List.of(element), (block, p) -> {
148                     Block.Builder _if = block.block();
149                     Block.Builder _else = continueBlock;
150                     if (continueBlock == null) {
151                         _else = block.block();
152                         _else.op(_continue());
153                     }
154 
155                     block.op(conditionalBranch(p, _if.successor(), _else.successor()));
156 
157                     fuseIntermediateOperation(i + 1, _if, element, _else, terminalConsumer);
158                 });
159             } else if (sop instanceof FlatMapStreamOp) {
160                 body.inline(sop.op(), List.of(element), (block, iterable) -> {
161                     JavaEnhancedForOp forOp = enhancedFor(block.parentBody(),
162                             iterable.type(), ((ClassType) iterable.type()).typeArguments().get(0))
163                             .expression(b -> {
164                                 b.op(_yield(iterable));
165                             })
166                             .definition(b -> {
167                                 b.op(_yield(b.parameters().get(0)));
168                             })
169                             .body(b -> {
170                                 fuseIntermediateOperation(i + 1,
171                                         b,
172                                         b.parameters().get(0),
173                                         null, terminalConsumer);
174                             });
175 
176                     block.op(forOp);
177                     block.op(_continue());
178                 });
179             }
180         }
181 
182         public FuncOp forEach(Quoted quotedConsumer) {
183             if (!(quotedConsumer.op() instanceof CoreOp.ClosureOp consumer)) {
184                 throw new IllegalArgumentException("Quoted consumer is not closure operation");
185             }
186 
187             return func("fused.forEach", FunctionType.functionType(JavaType.VOID, sourceType))
188                     .body(b -> {
189                         Value source = b.parameters().get(0);
190 
191                         Op sourceLoop = loopSupplier.apply(b.parentBody(), source)
192                                 .apply(loopBlock -> {
193                                     fuseIntermediateOperations(loopBlock, (terminalBlock, resultValue) -> {
194                                         terminalBlock.inline(consumer, List.of(resultValue),
195                                                 (_, _) -> {
196                                                 });
197                                         terminalBlock.op(_continue());
198                                     });
199 
200                                 });
201                         b.op(sourceLoop);
202                         b.op(_return());
203                     });
204         }
205 
206         // Supplier<C> supplier, BiConsumer<C, T> accumulator
207         public FuncOp collect(Quoted quotedSupplier, Quoted quotedAccumulator) {
208             if (!(quotedSupplier.op() instanceof CoreOp.ClosureOp supplier)) {
209                 throw new IllegalArgumentException("Quoted supplier is not closure operation");
210             }
211             if (!(quotedAccumulator.op() instanceof CoreOp.ClosureOp accumulator)) {
212                 throw new IllegalArgumentException("Quoted accumulator is not closure operation");
213             }
214 
215             JavaType collectType = (JavaType) supplier.invokableType().returnType();
216             return func("fused.collect", FunctionType.functionType(collectType, sourceType))
217                     .body(b -> {
218                         Value source = b.parameters().get(0);
219 
220                         b.inline(supplier, List.of(), (block, collect) -> {
221                             Op sourceLoop = loopSupplier.apply(block.parentBody(), source)
222                                     .apply(loopBlock -> {
223                                         fuseIntermediateOperations(loopBlock, (terminalBlock, resultValue) -> {
224                                             terminalBlock.inline(accumulator, List.of(collect, resultValue),
225                                                     (_, _) -> {
226                                                     });
227                                             terminalBlock.op(_continue());
228                                         });
229                                     });
230                             block.op(sourceLoop);
231                             block.op(_return(collect));
232                         });
233                     });
234         }
235 
236     }
237 }