1 /*
2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation.
8 *
9 * This code is distributed in the hope that it will be useful, but WITHOUT
10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
12 * version 2 for more details (a copy is included in the LICENSE file that
13 * accompanied this code).
14 *
15 * You should have received a copy of the GNU General Public License version
16 * 2 along with this work; if not, write to the Free Software Foundation,
17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18 *
19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20 * or visit www.oracle.com if you need additional information or have any
21 * questions.
22 */
23
24 import jdk.incubator.code.*;
25 import jdk.incubator.code.dialect.core.Inliner;
26 import jdk.incubator.code.dialect.core.CoreType;
27 import jdk.incubator.code.dialect.java.JavaOp;
28 import jdk.incubator.code.dialect.java.JavaOp.EnhancedForOp;
29 import jdk.incubator.code.dialect.java.ClassType;
30 import jdk.incubator.code.dialect.java.JavaType;
31 import java.util.ArrayList;
32 import java.util.List;
33 import java.util.function.*;
34
35 import static jdk.incubator.code.dialect.core.CoreOp.*;
36 import static jdk.incubator.code.dialect.java.JavaOp.enhancedFor;
37 import static jdk.incubator.code.dialect.java.JavaType.parameterized;
38 import static jdk.incubator.code.dialect.java.JavaType.type;
39
40 public final class StreamFuser {
41
42 StreamFuser() {}
43
44 public static <T> StreamExprBuilder<T> fromList(Class<T> elementClass) {
45 JavaType elementType = type(elementClass);
46 // java.util.List<E>
47 JavaType listType = parameterized(type(List.class), elementType);
48 return new StreamExprBuilder<>(listType, elementType,
49 (b, v) -> StreamExprBuilder.enhancedForLoop(b, elementType, v)::body);
50 }
51
52 public static class StreamExprBuilder<T> {
53 static class StreamOp {
54 final JavaOp.LambdaOp lambdaOp;
55
56 StreamOp(Object reflectableLambda) {
57 Quoted<JavaOp.LambdaOp> quotedLambda = Op.ofLambda(reflectableLambda).orElseThrow();
58 if (!(quotedLambda.capturedValues().isEmpty())) {
59 throw new IllegalArgumentException("Reflectable lambda captures values");
60 }
61 this.lambdaOp = quotedLambda.op();
62 }
63
64 JavaOp.LambdaOp op() {
65 return lambdaOp;
66 }
67 }
68
69 static class MapStreamOp extends StreamOp {
70 public MapStreamOp(Object reflectableLambda) {
71 super(reflectableLambda);
72 }
73 }
74
75 static class FlatMapStreamOp extends StreamOp {
76 public FlatMapStreamOp(Object reflectableLambda) {
77 super(reflectableLambda);
78 }
79 }
80
81 static class FilterStreamOp extends StreamOp {
82 public FilterStreamOp(Object reflectableLambda) {
83 super(reflectableLambda);
84 }
85 }
86
87 final JavaType sourceType;
88 final JavaType sourceElementType;
89 final BiFunction<Body.Builder, Value, Function<Consumer<Block.Builder>, Op>> loopSupplier;
90 final List<StreamOp> streamOps;
91
92 StreamExprBuilder(JavaType sourceType, JavaType sourceElementType,
93 BiFunction<Body.Builder, Value, Function<Consumer<Block.Builder>, Op>> loopSupplier) {
94 this.sourceType = sourceType;
95 this.sourceElementType = sourceElementType;
96 this.loopSupplier = loopSupplier;
97 this.streamOps = new ArrayList<>();
98 }
99
100 static EnhancedForOp.BodyBuilder enhancedForLoop(Body.Builder ancestorBody, JavaType elementType,
101 Value iterable) {
102 return enhancedFor(ancestorBody, iterable.type(), elementType)
103 .expression(b -> {
104 b.op(core_yield(iterable));
105 })
106 .definition(b -> {
107 b.op(core_yield(b.parameters().get(0)));
108 });
109 }
110
111 @SuppressWarnings("unchecked")
112 public <R> StreamExprBuilder<R> map(Function<T, R> f) {
113 streamOps.add(new MapStreamOp(f));
114 return (StreamExprBuilder<R>) this;
115 }
116
117 @SuppressWarnings("unchecked")
118 public <R> StreamExprBuilder<R> flatMap(Function<T, Iterable<R>> f) {
119 streamOps.add(new FlatMapStreamOp(f));
120 return (StreamExprBuilder<R>) this;
121 }
122
123 public StreamExprBuilder<T> filter(Predicate<T> f) {
124 streamOps.add(new FilterStreamOp(f));
125 return this;
126 }
127
128 void fuseIntermediateOperations(Block.Builder body, BiConsumer<Block.Builder, Value> terminalConsumer) {
129 fuseIntermediateOperation(0, body, body.parameters().get(0), null, terminalConsumer);
130 }
131
132 void fuseIntermediateOperation(int i, Block.Builder body, Value element, Block.Builder continueBlock,
133 BiConsumer<Block.Builder, Value> terminalConsumer) {
134 if (i == streamOps.size()) {
135 terminalConsumer.accept(body, element);
136 return;
137 }
138
139 StreamOp sop = streamOps.get(i);
140 if (sop instanceof MapStreamOp) {
141 Inliner.inline(body, sop.op(), List.of(element), (block, value) -> {
142 fuseIntermediateOperation(i + 1, block, value, continueBlock, terminalConsumer);
143 });
144 } else if (sop instanceof FilterStreamOp) {
145 Inliner.inline(body, sop.op(), List.of(element), (block, p) -> {
146 Block.Builder _if = block.block();
147 Block.Builder _else = continueBlock;
148 if (continueBlock == null) {
149 _else = block.block();
150 _else.op(JavaOp.continue_());
151 }
152
153 block.op(conditionalBranch(p, _if.successor(), _else.successor()));
154
155 fuseIntermediateOperation(i + 1, _if, element, _else, terminalConsumer);
156 });
157 } else if (sop instanceof FlatMapStreamOp) {
158 Inliner.inline(body, sop.op(), List.of(element), (block, iterable) -> {
159 EnhancedForOp forOp = enhancedFor(block.parentBody(),
160 iterable.type(), ((ClassType) iterable.type()).typeArguments().get(0))
161 .expression(b -> {
162 b.op(core_yield(iterable));
163 })
164 .definition(b -> {
165 b.op(core_yield(b.parameters().get(0)));
166 })
167 .body(b -> {
168 fuseIntermediateOperation(i + 1,
169 b,
170 b.parameters().get(0),
171 null, terminalConsumer);
172 });
173
174 block.op(forOp);
175 block.op(JavaOp.continue_());
176 });
177 }
178 }
179
180 public FuncOp forEach(Consumer<T> reflectableConsumer) {
181 Quoted<JavaOp.LambdaOp> quotedConsumer = Op.ofLambda(reflectableConsumer).orElseThrow();
182 if (!(quotedConsumer.capturedValues().isEmpty())) {
183 throw new IllegalArgumentException("Reflectable consumer captures values");
184 }
185 JavaOp.LambdaOp consumer = quotedConsumer.op();
186
187 return func("fused.forEach", CoreType.functionType(JavaType.VOID, sourceType))
188 .body(b -> {
189 Value source = b.parameters().get(0);
190
191 Op sourceLoop = loopSupplier.apply(b.parentBody(), source)
192 .apply(loopBlock -> {
193 fuseIntermediateOperations(loopBlock, (terminalBlock, resultValue) -> {
194 Inliner.inline(terminalBlock, consumer, List.of(resultValue),
195 (_, _) -> {
196 });
197 terminalBlock.op(JavaOp.continue_());
198 });
199
200 });
201 b.op(sourceLoop);
202 b.op(return_());
203 });
204 }
205
206 public <C> FuncOp collect(Supplier<C> reflectableSupplier, BiConsumer<C, T> reflectableAccumulator) {
207 Quoted<JavaOp.LambdaOp> quotedSupplier = Op.ofLambda(reflectableSupplier).orElseThrow();
208 if (!(quotedSupplier.capturedValues().isEmpty())) {
209 throw new IllegalArgumentException("Reflectable supplier captures values");
210 }
211 JavaOp.LambdaOp supplier = quotedSupplier.op();
212 Quoted<JavaOp.LambdaOp> quotedAccumulator = Op.ofLambda(reflectableAccumulator).orElseThrow();
213 if (!(quotedAccumulator.capturedValues().isEmpty())) {
214 throw new IllegalArgumentException("Reflectable accumulator captures values");
215 }
216 JavaOp.LambdaOp accumulator = quotedAccumulator.op();
217
218 JavaType collectType = (JavaType) supplier.invokableType().returnType();
219 return func("fused.collect", CoreType.functionType(collectType, sourceType))
220 .body(b -> {
221 Value source = b.parameters().get(0);
222
223 Inliner.inline(b, supplier, List.of(), (block, collect) -> {
224 Op sourceLoop = loopSupplier.apply(block.parentBody(), source)
225 .apply(loopBlock -> {
226 fuseIntermediateOperations(loopBlock, (terminalBlock, resultValue) -> {
227 Inliner.inline(terminalBlock, accumulator, List.of(collect, resultValue),
228 (_, _) -> {
229 });
230 terminalBlock.op(JavaOp.continue_());
231 });
232 });
233 block.op(sourceLoop);
234 block.op(return_(collect));
235 });
236 });
237 }
238
239 }
240 }
241