1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 import jdk.incubator.code.*;
 25 import jdk.incubator.code.dialect.core.Inliner;
 26 import jdk.incubator.code.dialect.core.CoreType;
 27 import jdk.incubator.code.dialect.java.JavaOp;
 28 import jdk.incubator.code.dialect.java.JavaOp.EnhancedForOp;
 29 import jdk.incubator.code.dialect.java.ClassType;
 30 import jdk.incubator.code.dialect.java.JavaType;
 31 import java.util.ArrayList;
 32 import java.util.List;
 33 import java.util.function.*;
 34 
 35 import static jdk.incubator.code.dialect.core.CoreOp.*;
 36 import static jdk.incubator.code.dialect.java.JavaOp.enhancedFor;
 37 import static jdk.incubator.code.dialect.java.JavaType.parameterized;
 38 import static jdk.incubator.code.dialect.java.JavaType.type;
 39 
 40 public final class StreamFuser {
 41 
 42     StreamFuser() {}
 43 
 44     public static <T> StreamExprBuilder<T> fromList(Class<T> elementClass) {
 45         JavaType elementType = type(elementClass);
 46         // java.util.List<E>
 47         JavaType listType = parameterized(type(List.class), elementType);
 48         return new StreamExprBuilder<>(listType, elementType,
 49                 (b, v) -> StreamExprBuilder.enhancedForLoop(b, elementType, v)::body);
 50     }
 51 
 52     public static class StreamExprBuilder<T> {
 53         static class StreamOp {
 54             final JavaOp.LambdaOp lambdaOp;
 55 
 56             StreamOp(Object reflectableLambda) {
 57                 Quoted<JavaOp.LambdaOp> quotedLambda = Op.ofLambda(reflectableLambda).orElseThrow();
 58                 if (!(quotedLambda.capturedValues().isEmpty())) {
 59                     throw new IllegalArgumentException("Reflectable lambda captures values");
 60                 }
 61                 this.lambdaOp = quotedLambda.op();
 62             }
 63 
 64             JavaOp.LambdaOp op() {
 65                 return lambdaOp;
 66             }
 67         }
 68 
 69         static class MapStreamOp extends StreamOp {
 70             public MapStreamOp(Object reflectableLambda) {
 71                 super(reflectableLambda);
 72             }
 73         }
 74 
 75         static class FlatMapStreamOp extends StreamOp {
 76             public FlatMapStreamOp(Object reflectableLambda) {
 77                 super(reflectableLambda);
 78             }
 79         }
 80 
 81         static class FilterStreamOp extends StreamOp {
 82             public FilterStreamOp(Object reflectableLambda) {
 83                 super(reflectableLambda);
 84             }
 85         }
 86 
 87         final JavaType sourceType;
 88         final JavaType sourceElementType;
 89         final BiFunction<Body.Builder, Value, Function<Consumer<Block.Builder>, Op>> loopSupplier;
 90         final List<StreamOp> streamOps;
 91 
 92         StreamExprBuilder(JavaType sourceType, JavaType sourceElementType,
 93                           BiFunction<Body.Builder, Value, Function<Consumer<Block.Builder>, Op>> loopSupplier) {
 94             this.sourceType = sourceType;
 95             this.sourceElementType = sourceElementType;
 96             this.loopSupplier = loopSupplier;
 97             this.streamOps = new ArrayList<>();
 98         }
 99 
100         static EnhancedForOp.BodyBuilder enhancedForLoop(Body.Builder ancestorBody, JavaType elementType,
101                                                          Value iterable) {
102             return enhancedFor(ancestorBody, iterable.type(), elementType)
103                     .expression(b -> {
104                         b.op(core_yield(iterable));
105                     })
106                     .definition(b -> {
107                         b.op(core_yield(b.parameters().get(0)));
108                     });
109         }
110 
111         @SuppressWarnings("unchecked")
112         public <R> StreamExprBuilder<R> map(Function<T, R> f) {
113             streamOps.add(new MapStreamOp(f));
114             return (StreamExprBuilder<R>) this;
115         }
116 
117         @SuppressWarnings("unchecked")
118         public <R> StreamExprBuilder<R> flatMap(Function<T, Iterable<R>> f) {
119             streamOps.add(new FlatMapStreamOp(f));
120             return (StreamExprBuilder<R>) this;
121         }
122 
123         public StreamExprBuilder<T> filter(Predicate<T> f) {
124             streamOps.add(new FilterStreamOp(f));
125             return this;
126         }
127 
128         void fuseIntermediateOperations(Block.Builder body, BiConsumer<Block.Builder, Value> terminalConsumer) {
129             fuseIntermediateOperation(0, body, body.parameters().get(0), null, terminalConsumer);
130         }
131 
132         void fuseIntermediateOperation(int i, Block.Builder body, Value element, Block.Builder continueBlock,
133                                        BiConsumer<Block.Builder, Value> terminalConsumer) {
134             if (i == streamOps.size()) {
135                 terminalConsumer.accept(body, element);
136                 return;
137             }
138 
139             StreamOp sop = streamOps.get(i);
140             if (sop instanceof MapStreamOp) {
141                 Inliner.inline(body, sop.op(), List.of(element), (block, value) -> {
142                     fuseIntermediateOperation(i + 1, block, value, continueBlock, terminalConsumer);
143                 });
144             } else if (sop instanceof FilterStreamOp) {
145                 Inliner.inline(body, sop.op(), List.of(element), (block, p) -> {
146                     Block.Builder _if = block.block();
147                     Block.Builder _else = continueBlock;
148                     if (continueBlock == null) {
149                         _else = block.block();
150                         _else.op(JavaOp.continue_());
151                     }
152 
153                     block.op(conditionalBranch(p, _if.successor(), _else.successor()));
154 
155                     fuseIntermediateOperation(i + 1, _if, element, _else, terminalConsumer);
156                 });
157             } else if (sop instanceof FlatMapStreamOp) {
158                 Inliner.inline(body, sop.op(), List.of(element), (block, iterable) -> {
159                     EnhancedForOp forOp = enhancedFor(block.parentBody(),
160                             iterable.type(), ((ClassType) iterable.type()).typeArguments().get(0))
161                             .expression(b -> {
162                                 b.op(core_yield(iterable));
163                             })
164                             .definition(b -> {
165                                 b.op(core_yield(b.parameters().get(0)));
166                             })
167                             .body(b -> {
168                                 fuseIntermediateOperation(i + 1,
169                                         b,
170                                         b.parameters().get(0),
171                                         null, terminalConsumer);
172                             });
173 
174                     block.op(forOp);
175                     block.op(JavaOp.continue_());
176                 });
177             }
178         }
179 
180         public FuncOp forEach(Consumer<T> reflectableConsumer) {
181             Quoted<JavaOp.LambdaOp> quotedConsumer = Op.ofLambda(reflectableConsumer).orElseThrow();
182             if (!(quotedConsumer.capturedValues().isEmpty())) {
183                 throw new IllegalArgumentException("Reflectable consumer captures values");
184             }
185             JavaOp.LambdaOp consumer = quotedConsumer.op();
186 
187             return func("fused.forEach", CoreType.functionType(JavaType.VOID, sourceType))
188                     .body(b -> {
189                         Value source = b.parameters().get(0);
190 
191                         Op sourceLoop = loopSupplier.apply(b.parentBody(), source)
192                                 .apply(loopBlock -> {
193                                     fuseIntermediateOperations(loopBlock, (terminalBlock, resultValue) -> {
194                                         Inliner.inline(terminalBlock, consumer, List.of(resultValue),
195                                                 (_, _) -> {
196                                                 });
197                                         terminalBlock.op(JavaOp.continue_());
198                                     });
199 
200                                 });
201                         b.op(sourceLoop);
202                         b.op(return_());
203                     });
204         }
205 
206         public <C> FuncOp collect(Supplier<C> reflectableSupplier, BiConsumer<C, T> reflectableAccumulator) {
207             Quoted<JavaOp.LambdaOp> quotedSupplier = Op.ofLambda(reflectableSupplier).orElseThrow();
208             if (!(quotedSupplier.capturedValues().isEmpty())) {
209                 throw new IllegalArgumentException("Reflectable supplier captures values");
210             }
211             JavaOp.LambdaOp supplier = quotedSupplier.op();
212             Quoted<JavaOp.LambdaOp> quotedAccumulator = Op.ofLambda(reflectableAccumulator).orElseThrow();
213             if (!(quotedAccumulator.capturedValues().isEmpty())) {
214                 throw new IllegalArgumentException("Reflectable accumulator captures values");
215             }
216             JavaOp.LambdaOp accumulator = quotedAccumulator.op();
217 
218             JavaType collectType = (JavaType) supplier.invokableType().returnType();
219             return func("fused.collect", CoreType.functionType(collectType, sourceType))
220                     .body(b -> {
221                         Value source = b.parameters().get(0);
222 
223                         Inliner.inline(b, supplier, List.of(), (block, collect) -> {
224                             Op sourceLoop = loopSupplier.apply(block.parentBody(), source)
225                                     .apply(loopBlock -> {
226                                         fuseIntermediateOperations(loopBlock, (terminalBlock, resultValue) -> {
227                                             Inliner.inline(terminalBlock, accumulator, List.of(collect, resultValue),
228                                                     (_, _) -> {
229                                                     });
230                                             terminalBlock.op(JavaOp.continue_());
231                                         });
232                                     });
233                             block.op(sourceLoop);
234                             block.op(return_(collect));
235                         });
236                     });
237         }
238 
239     }
240 }
241