1 /*
2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation.
8 *
9 * This code is distributed in the hope that it will be useful, but WITHOUT
10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
12 * version 2 for more details (a copy is included in the LICENSE file that
13 * accompanied this code).
14 *
15 * You should have received a copy of the GNU General Public License version
16 * 2 along with this work; if not, write to the Free Software Foundation,
17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18 *
19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20 * or visit www.oracle.com if you need additional information or have any
21 * questions.
22 */
23
24 import jdk.incubator.code.*;
25 import jdk.incubator.code.analysis.Inliner;
26 import jdk.incubator.code.dialect.core.CoreOp;
27 import jdk.incubator.code.dialect.core.CoreType;
28 import jdk.incubator.code.dialect.java.JavaOp;
29 import jdk.incubator.code.dialect.java.JavaOp.EnhancedForOp;
30 import jdk.incubator.code.dialect.java.ClassType;
31 import jdk.incubator.code.dialect.java.JavaType;
32 import java.util.ArrayList;
33 import java.util.List;
34 import java.util.function.BiConsumer;
35 import java.util.function.BiFunction;
36 import java.util.function.Consumer;
37 import java.util.function.Function;
38
39 import static jdk.incubator.code.dialect.core.CoreOp.*;
40 import static jdk.incubator.code.dialect.java.JavaOp.continue_;
41 import static jdk.incubator.code.dialect.java.JavaOp.enhancedFor;
42 import static jdk.incubator.code.dialect.java.JavaType.parameterized;
43 import static jdk.incubator.code.dialect.java.JavaType.type;
44
45 public final class StreamFuser {
46
47 StreamFuser() {}
48
49 public static StreamExprBuilder fromList(JavaType elementType) {
50 // java.util.List<E>
51 JavaType listType = parameterized(type(List.class), elementType);
52 return new StreamExprBuilder(listType, elementType,
53 (b, v) -> StreamExprBuilder.enhancedForLoop(b, elementType, v)::body);
54 }
55
56 public static class StreamExprBuilder {
57 static class StreamOp {
58 final Quoted quotedClosure;
59
60 StreamOp(Quoted quotedClosure) {
61 if (!(quotedClosure.op() instanceof CoreOp.ClosureOp)) {
62 throw new IllegalArgumentException("Quoted operation is not closure operation");
63 }
64 this.quotedClosure = quotedClosure;
65 }
66
67 CoreOp.ClosureOp op() {
68 return (CoreOp.ClosureOp) quotedClosure.op();
69 }
70 }
71
72 static class MapStreamOp extends StreamOp {
73 public MapStreamOp(Quoted quotedClosure) {
74 super(quotedClosure);
75 // @@@ Check closure signature
76 }
77 }
78
79 static class FlatMapStreamOp extends StreamOp {
80 public FlatMapStreamOp(Quoted quotedClosure) {
81 super(quotedClosure);
82 // @@@ Check closure signature
83 }
84 }
85
86 static class FilterStreamOp extends StreamOp {
87 public FilterStreamOp(Quoted quotedClosure) {
88 super(quotedClosure);
89 // @@@ Check closure signature
90 }
91 }
92
93 final JavaType sourceType;
94 final JavaType sourceElementType;
95 final BiFunction<Body.Builder, Value, Function<Consumer<Block.Builder>, Op>> loopSupplier;
96 final List<StreamOp> streamOps;
97
98 StreamExprBuilder(JavaType sourceType, JavaType sourceElementType,
99 BiFunction<Body.Builder, Value, Function<Consumer<Block.Builder>, Op>> loopSupplier) {
100 this.sourceType = sourceType;
101 this.sourceElementType = sourceElementType;
102 this.loopSupplier = loopSupplier;
103 this.streamOps = new ArrayList<>();
104 }
105
106 static EnhancedForOp.BodyBuilder enhancedForLoop(Body.Builder ancestorBody, JavaType elementType,
107 Value iterable) {
108 return enhancedFor(ancestorBody, iterable.type(), elementType)
109 .expression(b -> {
110 b.op(core_yield(iterable));
111 })
112 .definition(b -> {
113 b.op(core_yield(b.parameters().get(0)));
114 });
115 }
116
117 public StreamExprBuilder map(Quoted f) {
118 streamOps.add(new MapStreamOp(f));
119 return this;
120 }
121
122 public StreamExprBuilder flatMap(Quoted f) {
123 streamOps.add(new FlatMapStreamOp(f));
124 return this;
125 }
126
127 public StreamExprBuilder filter(Quoted f) {
128 streamOps.add(new FilterStreamOp(f));
129 return this;
130 }
131
132 void fuseIntermediateOperations(Block.Builder body, BiConsumer<Block.Builder, Value> terminalConsumer) {
133 fuseIntermediateOperation(0, body, body.parameters().get(0), null, terminalConsumer);
134 }
135
136 void fuseIntermediateOperation(int i, Block.Builder body, Value element, Block.Builder continueBlock,
137 BiConsumer<Block.Builder, Value> terminalConsumer) {
138 if (i == streamOps.size()) {
139 terminalConsumer.accept(body, element);
140 return;
141 }
142
143 StreamOp sop = streamOps.get(i);
144 if (sop instanceof MapStreamOp) {
145 Inliner.inline(body, sop.op(), List.of(element), (block, value) -> {
146 fuseIntermediateOperation(i + 1, block, value, continueBlock, terminalConsumer);
147 });
148 } else if (sop instanceof FilterStreamOp) {
149 Inliner.inline(body, sop.op(), List.of(element), (block, p) -> {
150 Block.Builder _if = block.block();
151 Block.Builder _else = continueBlock;
152 if (continueBlock == null) {
153 _else = block.block();
154 _else.op(JavaOp.continue_());
155 }
156
157 block.op(conditionalBranch(p, _if.successor(), _else.successor()));
158
159 fuseIntermediateOperation(i + 1, _if, element, _else, terminalConsumer);
160 });
161 } else if (sop instanceof FlatMapStreamOp) {
162 Inliner.inline(body, sop.op(), List.of(element), (block, iterable) -> {
163 EnhancedForOp forOp = enhancedFor(block.parentBody(),
164 iterable.type(), ((ClassType) iterable.type()).typeArguments().get(0))
165 .expression(b -> {
166 b.op(core_yield(iterable));
167 })
168 .definition(b -> {
169 b.op(core_yield(b.parameters().get(0)));
170 })
171 .body(b -> {
172 fuseIntermediateOperation(i + 1,
173 b,
174 b.parameters().get(0),
175 null, terminalConsumer);
176 });
177
178 block.op(forOp);
179 block.op(JavaOp.continue_());
180 });
181 }
182 }
183
184 public FuncOp forEach(Quoted quotedConsumer) {
185 if (!(quotedConsumer.op() instanceof CoreOp.ClosureOp consumer)) {
186 throw new IllegalArgumentException("Quoted consumer is not closure operation");
187 }
188
189 return func("fused.forEach", CoreType.functionType(JavaType.VOID, sourceType))
190 .body(b -> {
191 Value source = b.parameters().get(0);
192
193 Op sourceLoop = loopSupplier.apply(b.parentBody(), source)
194 .apply(loopBlock -> {
195 fuseIntermediateOperations(loopBlock, (terminalBlock, resultValue) -> {
196 Inliner.inline(terminalBlock, consumer, List.of(resultValue),
197 (_, _) -> {
198 });
199 terminalBlock.op(JavaOp.continue_());
200 });
201
202 });
203 b.op(sourceLoop);
204 b.op(return_());
205 });
206 }
207
208 // Supplier<C> supplier, BiConsumer<C, T> accumulator
209 public FuncOp collect(Quoted quotedSupplier, Quoted quotedAccumulator) {
210 if (!(quotedSupplier.op() instanceof CoreOp.ClosureOp supplier)) {
211 throw new IllegalArgumentException("Quoted supplier is not closure operation");
212 }
213 if (!(quotedAccumulator.op() instanceof CoreOp.ClosureOp accumulator)) {
214 throw new IllegalArgumentException("Quoted accumulator is not closure operation");
215 }
216
217 JavaType collectType = (JavaType) supplier.invokableType().returnType();
218 return func("fused.collect", CoreType.functionType(collectType, sourceType))
219 .body(b -> {
220 Value source = b.parameters().get(0);
221
222 Inliner.inline(b, supplier, List.of(), (block, collect) -> {
223 Op sourceLoop = loopSupplier.apply(block.parentBody(), source)
224 .apply(loopBlock -> {
225 fuseIntermediateOperations(loopBlock, (terminalBlock, resultValue) -> {
226 Inliner.inline(terminalBlock, accumulator, List.of(collect, resultValue),
227 (_, _) -> {
228 });
229 terminalBlock.op(JavaOp.continue_());
230 });
231 });
232 block.op(sourceLoop);
233 block.op(return_(collect));
234 });
235 });
236 }
237
238 }
239 }