1 /*
2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation.
8 *
9 * This code is distributed in the hope that it will be useful, but WITHOUT
10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
12 * version 2 for more details (a copy is included in the LICENSE file that
13 * accompanied this code).
14 *
15 * You should have received a copy of the GNU General Public License version
16 * 2 along with this work; if not, write to the Free Software Foundation,
17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18 *
19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20 * or visit www.oracle.com if you need additional information or have any
21 * questions.
22 */
23
24 import jdk.incubator.code.*;
25 import jdk.incubator.code.Reflect;
26 import jdk.incubator.code.analysis.Inliner;
27 import jdk.incubator.code.dialect.core.CoreType;
28 import jdk.incubator.code.dialect.java.JavaOp;
29 import jdk.incubator.code.dialect.java.JavaOp.EnhancedForOp;
30 import jdk.incubator.code.dialect.java.ClassType;
31 import jdk.incubator.code.dialect.java.JavaType;
32 import java.util.ArrayList;
33 import java.util.List;
34 import java.util.function.*;
35
36 import static jdk.incubator.code.dialect.core.CoreOp.*;
37 import static jdk.incubator.code.dialect.java.JavaOp.continue_;
38 import static jdk.incubator.code.dialect.java.JavaOp.enhancedFor;
39 import static jdk.incubator.code.dialect.java.JavaType.parameterized;
40 import static jdk.incubator.code.dialect.java.JavaType.type;
41
42 public final class StreamFuser {
43
44 StreamFuser() {}
45
46 public static <T> StreamExprBuilder<T> fromList(Class<T> elementClass) {
47 JavaType elementType = type(elementClass);
48 // java.util.List<E>
49 JavaType listType = parameterized(type(List.class), elementType);
50 return new StreamExprBuilder<>(listType, elementType,
51 (b, v) -> StreamExprBuilder.enhancedForLoop(b, elementType, v)::body);
52 }
53
54 public static class StreamExprBuilder<T> {
55 static class StreamOp {
56 final JavaOp.LambdaOp lambdaOp;
57
58 StreamOp(Object quotedLambda) {
59 if (!(Op.ofQuotable(quotedLambda).get().op() instanceof JavaOp.LambdaOp lambdaOp)) {
60 throw new IllegalArgumentException("Quotable operation is not lambda operation");
61 }
62 if (!(Op.ofQuotable(quotedLambda).get().capturedValues().isEmpty())) {
63 throw new IllegalArgumentException("Quotable operation captures values");
64 }
65 this.lambdaOp = lambdaOp;
66 }
67
68 JavaOp.LambdaOp op() {
69 return lambdaOp;
70 }
71 }
72
73 static class MapStreamOp extends StreamOp {
74 public MapStreamOp(Object quotedLambda) {
75 super(quotedLambda);
76 }
77 }
78
79 static class FlatMapStreamOp extends StreamOp {
80 public FlatMapStreamOp(Object quotedLambda) {
81 super(quotedLambda);
82 }
83 }
84
85 static class FilterStreamOp extends StreamOp {
86 public FilterStreamOp(Object quotedLambda) {
87 super(quotedLambda);
88 }
89 }
90
91 final JavaType sourceType;
92 final JavaType sourceElementType;
93 final BiFunction<Body.Builder, Value, Function<Consumer<Block.Builder>, Op>> loopSupplier;
94 final List<StreamOp> streamOps;
95
96 StreamExprBuilder(JavaType sourceType, JavaType sourceElementType,
97 BiFunction<Body.Builder, Value, Function<Consumer<Block.Builder>, Op>> loopSupplier) {
98 this.sourceType = sourceType;
99 this.sourceElementType = sourceElementType;
100 this.loopSupplier = loopSupplier;
101 this.streamOps = new ArrayList<>();
102 }
103
104 static EnhancedForOp.BodyBuilder enhancedForLoop(Body.Builder ancestorBody, JavaType elementType,
105 Value iterable) {
106 return enhancedFor(ancestorBody, iterable.type(), elementType)
107 .expression(b -> {
108 b.op(core_yield(iterable));
109 })
110 .definition(b -> {
111 b.op(core_yield(b.parameters().get(0)));
112 });
113 }
114
115 @SuppressWarnings("unchecked")
116 public <R> StreamExprBuilder<R> map(Function<T, R> f) {
117 streamOps.add(new MapStreamOp(f));
118 return (StreamExprBuilder<R>) this;
119 }
120
121 @SuppressWarnings("unchecked")
122 public <R> StreamExprBuilder<R> flatMap(Function<T, Iterable<R>> f) {
123 streamOps.add(new FlatMapStreamOp(f));
124 return (StreamExprBuilder<R>) this;
125 }
126
127 public StreamExprBuilder<T> filter(Predicate<T> f) {
128 streamOps.add(new FilterStreamOp(f));
129 return this;
130 }
131
132 void fuseIntermediateOperations(Block.Builder body, BiConsumer<Block.Builder, Value> terminalConsumer) {
133 fuseIntermediateOperation(0, body, body.parameters().get(0), null, terminalConsumer);
134 }
135
136 void fuseIntermediateOperation(int i, Block.Builder body, Value element, Block.Builder continueBlock,
137 BiConsumer<Block.Builder, Value> terminalConsumer) {
138 if (i == streamOps.size()) {
139 terminalConsumer.accept(body, element);
140 return;
141 }
142
143 StreamOp sop = streamOps.get(i);
144 if (sop instanceof MapStreamOp) {
145 Inliner.inline(body, sop.op(), List.of(element), (block, value) -> {
146 fuseIntermediateOperation(i + 1, block, value, continueBlock, terminalConsumer);
147 });
148 } else if (sop instanceof FilterStreamOp) {
149 Inliner.inline(body, sop.op(), List.of(element), (block, p) -> {
150 Block.Builder _if = block.block();
151 Block.Builder _else = continueBlock;
152 if (continueBlock == null) {
153 _else = block.block();
154 _else.op(JavaOp.continue_());
155 }
156
157 block.op(conditionalBranch(p, _if.successor(), _else.successor()));
158
159 fuseIntermediateOperation(i + 1, _if, element, _else, terminalConsumer);
160 });
161 } else if (sop instanceof FlatMapStreamOp) {
162 Inliner.inline(body, sop.op(), List.of(element), (block, iterable) -> {
163 EnhancedForOp forOp = enhancedFor(block.parentBody(),
164 iterable.type(), ((ClassType) iterable.type()).typeArguments().get(0))
165 .expression(b -> {
166 b.op(core_yield(iterable));
167 })
168 .definition(b -> {
169 b.op(core_yield(b.parameters().get(0)));
170 })
171 .body(b -> {
172 fuseIntermediateOperation(i + 1,
173 b,
174 b.parameters().get(0),
175 null, terminalConsumer);
176 });
177
178 block.op(forOp);
179 block.op(JavaOp.continue_());
180 });
181 }
182 }
183
184 public FuncOp forEach(Consumer<T> quotableConsumer) {
185 if (!(Op.ofQuotable(quotableConsumer).get().op() instanceof JavaOp.LambdaOp consumer)) {
186 throw new IllegalArgumentException("Quotable consumer is not lambda operation");
187 }
188 if (!(Op.ofQuotable(quotableConsumer).get().capturedValues().isEmpty())) {
189 throw new IllegalArgumentException("Quotable consumer captures values");
190 }
191
192 return func("fused.forEach", CoreType.functionType(JavaType.VOID, sourceType))
193 .body(b -> {
194 Value source = b.parameters().get(0);
195
196 Op sourceLoop = loopSupplier.apply(b.parentBody(), source)
197 .apply(loopBlock -> {
198 fuseIntermediateOperations(loopBlock, (terminalBlock, resultValue) -> {
199 Inliner.inline(terminalBlock, consumer, List.of(resultValue),
200 (_, _) -> {
201 });
202 terminalBlock.op(JavaOp.continue_());
203 });
204
205 });
206 b.op(sourceLoop);
207 b.op(return_());
208 });
209 }
210
211 public <C> FuncOp collect(Supplier<C> quotableSupplier, BiConsumer<C, T> quotableAccumulator) {
212 if (!(Op.ofQuotable(quotableSupplier).get().op() instanceof JavaOp.LambdaOp supplier)) {
213 throw new IllegalArgumentException("Quotable supplier is not lambda operation");
214 }
215 if (!(Op.ofQuotable(quotableSupplier).get().capturedValues().isEmpty())) {
216 throw new IllegalArgumentException("Quotable supplier captures values");
217 }
218 if (!(Op.ofQuotable(quotableAccumulator).get().op() instanceof JavaOp.LambdaOp accumulator)) {
219 throw new IllegalArgumentException("Quotable accumulator is not lambda operation");
220 }
221 if (!(Op.ofQuotable(quotableAccumulator).get().capturedValues().isEmpty())) {
222 throw new IllegalArgumentException("Quotable accumulator captures values");
223 }
224
225 JavaType collectType = (JavaType) supplier.invokableType().returnType();
226 return func("fused.collect", CoreType.functionType(collectType, sourceType))
227 .body(b -> {
228 Value source = b.parameters().get(0);
229
230 Inliner.inline(b, supplier, List.of(), (block, collect) -> {
231 Op sourceLoop = loopSupplier.apply(block.parentBody(), source)
232 .apply(loopBlock -> {
233 fuseIntermediateOperations(loopBlock, (terminalBlock, resultValue) -> {
234 Inliner.inline(terminalBlock, accumulator, List.of(collect, resultValue),
235 (_, _) -> {
236 });
237 terminalBlock.op(JavaOp.continue_());
238 });
239 });
240 block.op(sourceLoop);
241 block.op(return_(collect));
242 });
243 });
244 }
245
246 }
247 }
248