1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 import jdk.incubator.code.*;
 25 import jdk.incubator.code.analysis.Inliner;
 26 import jdk.incubator.code.dialect.core.CoreOp;
 27 import jdk.incubator.code.dialect.core.CoreType;
 28 import jdk.incubator.code.dialect.java.JavaOp;
 29 import jdk.incubator.code.dialect.java.JavaOp.EnhancedForOp;
 30 import jdk.incubator.code.dialect.java.ClassType;
 31 import jdk.incubator.code.dialect.java.JavaType;
 32 import java.util.ArrayList;
 33 import java.util.List;
 34 import java.util.function.BiConsumer;
 35 import java.util.function.BiFunction;
 36 import java.util.function.Consumer;
 37 import java.util.function.Function;
 38 
 39 import static jdk.incubator.code.dialect.core.CoreOp.*;
 40 import static jdk.incubator.code.dialect.java.JavaOp.continue_;
 41 import static jdk.incubator.code.dialect.java.JavaOp.enhancedFor;
 42 import static jdk.incubator.code.dialect.java.JavaType.parameterized;
 43 import static jdk.incubator.code.dialect.java.JavaType.type;
 44 
 45 public final class StreamFuser {
 46 
 47     StreamFuser() {}
 48 
 49     public static StreamExprBuilder fromList(JavaType elementType) {
 50         // java.util.List<E>
 51         JavaType listType = parameterized(type(List.class), elementType);
 52         return new StreamExprBuilder(listType, elementType,
 53                 (b, v) -> StreamExprBuilder.enhancedForLoop(b, elementType, v)::body);
 54     }
 55 
 56     public static class StreamExprBuilder {
 57         static class StreamOp {
 58             final Quoted quotedClosure;
 59 
 60             StreamOp(Quoted quotedClosure) {
 61                 if (!(quotedClosure.op() instanceof CoreOp.ClosureOp)) {
 62                     throw new IllegalArgumentException("Quoted operation is not closure operation");
 63                 }
 64                 this.quotedClosure = quotedClosure;
 65             }
 66 
 67             CoreOp.ClosureOp op() {
 68                 return (CoreOp.ClosureOp) quotedClosure.op();
 69             }
 70         }
 71 
 72         static class MapStreamOp extends StreamOp {
 73             public MapStreamOp(Quoted quotedClosure) {
 74                 super(quotedClosure);
 75                 // @@@ Check closure signature
 76             }
 77         }
 78 
 79         static class FlatMapStreamOp extends StreamOp {
 80             public FlatMapStreamOp(Quoted quotedClosure) {
 81                 super(quotedClosure);
 82                 // @@@ Check closure signature
 83             }
 84         }
 85 
 86         static class FilterStreamOp extends StreamOp {
 87             public FilterStreamOp(Quoted quotedClosure) {
 88                 super(quotedClosure);
 89                 // @@@ Check closure signature
 90             }
 91         }
 92 
 93         final JavaType sourceType;
 94         final JavaType sourceElementType;
 95         final BiFunction<Body.Builder, Value, Function<Consumer<Block.Builder>, Op>> loopSupplier;
 96         final List<StreamOp> streamOps;
 97 
 98         StreamExprBuilder(JavaType sourceType, JavaType sourceElementType,
 99                           BiFunction<Body.Builder, Value, Function<Consumer<Block.Builder>, Op>> loopSupplier) {
100             this.sourceType = sourceType;
101             this.sourceElementType = sourceElementType;
102             this.loopSupplier = loopSupplier;
103             this.streamOps = new ArrayList<>();
104         }
105 
106         static EnhancedForOp.BodyBuilder enhancedForLoop(Body.Builder ancestorBody, JavaType elementType,
107                                                          Value iterable) {
108             return enhancedFor(ancestorBody, iterable.type(), elementType)
109                     .expression(b -> {
110                         b.op(core_yield(iterable));
111                     })
112                     .definition(b -> {
113                         b.op(core_yield(b.parameters().get(0)));
114                     });
115         }
116 
117         public StreamExprBuilder map(Quoted f) {
118             streamOps.add(new MapStreamOp(f));
119             return this;
120         }
121 
122         public StreamExprBuilder flatMap(Quoted f) {
123             streamOps.add(new FlatMapStreamOp(f));
124             return this;
125         }
126 
127         public StreamExprBuilder filter(Quoted f) {
128             streamOps.add(new FilterStreamOp(f));
129             return this;
130         }
131 
132         void fuseIntermediateOperations(Block.Builder body, BiConsumer<Block.Builder, Value> terminalConsumer) {
133             fuseIntermediateOperation(0, body, body.parameters().get(0), null, terminalConsumer);
134         }
135 
136         void fuseIntermediateOperation(int i, Block.Builder body, Value element, Block.Builder continueBlock,
137                                        BiConsumer<Block.Builder, Value> terminalConsumer) {
138             if (i == streamOps.size()) {
139                 terminalConsumer.accept(body, element);
140                 return;
141             }
142 
143             StreamOp sop = streamOps.get(i);
144             if (sop instanceof MapStreamOp) {
145                 Inliner.inline(body, sop.op(), List.of(element), (block, value) -> {
146                     fuseIntermediateOperation(i + 1, block, value, continueBlock, terminalConsumer);
147                 });
148             } else if (sop instanceof FilterStreamOp) {
149                 Inliner.inline(body, sop.op(), List.of(element), (block, p) -> {
150                     Block.Builder _if = block.block();
151                     Block.Builder _else = continueBlock;
152                     if (continueBlock == null) {
153                         _else = block.block();
154                         _else.op(JavaOp.continue_());
155                     }
156 
157                     block.op(conditionalBranch(p, _if.successor(), _else.successor()));
158 
159                     fuseIntermediateOperation(i + 1, _if, element, _else, terminalConsumer);
160                 });
161             } else if (sop instanceof FlatMapStreamOp) {
162                 Inliner.inline(body, sop.op(), List.of(element), (block, iterable) -> {
163                     EnhancedForOp forOp = enhancedFor(block.parentBody(),
164                             iterable.type(), ((ClassType) iterable.type()).typeArguments().get(0))
165                             .expression(b -> {
166                                 b.op(core_yield(iterable));
167                             })
168                             .definition(b -> {
169                                 b.op(core_yield(b.parameters().get(0)));
170                             })
171                             .body(b -> {
172                                 fuseIntermediateOperation(i + 1,
173                                         b,
174                                         b.parameters().get(0),
175                                         null, terminalConsumer);
176                             });
177 
178                     block.op(forOp);
179                     block.op(JavaOp.continue_());
180                 });
181             }
182         }
183 
184         public FuncOp forEach(Quoted quotedConsumer) {
185             if (!(quotedConsumer.op() instanceof CoreOp.ClosureOp consumer)) {
186                 throw new IllegalArgumentException("Quoted consumer is not closure operation");
187             }
188 
189             return func("fused.forEach", CoreType.functionType(JavaType.VOID, sourceType))
190                     .body(b -> {
191                         Value source = b.parameters().get(0);
192 
193                         Op sourceLoop = loopSupplier.apply(b.parentBody(), source)
194                                 .apply(loopBlock -> {
195                                     fuseIntermediateOperations(loopBlock, (terminalBlock, resultValue) -> {
196                                         Inliner.inline(terminalBlock, consumer, List.of(resultValue),
197                                                 (_, _) -> {
198                                                 });
199                                         terminalBlock.op(JavaOp.continue_());
200                                     });
201 
202                                 });
203                         b.op(sourceLoop);
204                         b.op(return_());
205                     });
206         }
207 
208         // Supplier<C> supplier, BiConsumer<C, T> accumulator
209         public FuncOp collect(Quoted quotedSupplier, Quoted quotedAccumulator) {
210             if (!(quotedSupplier.op() instanceof CoreOp.ClosureOp supplier)) {
211                 throw new IllegalArgumentException("Quoted supplier is not closure operation");
212             }
213             if (!(quotedAccumulator.op() instanceof CoreOp.ClosureOp accumulator)) {
214                 throw new IllegalArgumentException("Quoted accumulator is not closure operation");
215             }
216 
217             JavaType collectType = (JavaType) supplier.invokableType().returnType();
218             return func("fused.collect", CoreType.functionType(collectType, sourceType))
219                     .body(b -> {
220                         Value source = b.parameters().get(0);
221 
222                         Inliner.inline(b, supplier, List.of(), (block, collect) -> {
223                             Op sourceLoop = loopSupplier.apply(block.parentBody(), source)
224                                     .apply(loopBlock -> {
225                                         fuseIntermediateOperations(loopBlock, (terminalBlock, resultValue) -> {
226                                             Inliner.inline(terminalBlock, accumulator, List.of(collect, resultValue),
227                                                     (_, _) -> {
228                                                     });
229                                             terminalBlock.op(JavaOp.continue_());
230                                         });
231                                     });
232                             block.op(sourceLoop);
233                             block.op(return_(collect));
234                         });
235                     });
236         }
237 
238     }
239 }