1 /*
2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation.
8 *
9 * This code is distributed in the hope that it will be useful, but WITHOUT
10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
12 * version 2 for more details (a copy is included in the LICENSE file that
13 * accompanied this code).
14 *
15 * You should have received a copy of the GNU General Public License version
16 * 2 along with this work; if not, write to the Free Software Foundation,
17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18 *
19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20 * or visit www.oracle.com if you need additional information or have any
21 * questions.
22 */
23
24 import jdk.incubator.code.*;
25 import jdk.incubator.code.analysis.Inliner;
26 import jdk.incubator.code.dialect.core.CoreType;
27 import jdk.incubator.code.dialect.java.JavaOp;
28 import jdk.incubator.code.dialect.java.JavaOp.EnhancedForOp;
29 import jdk.incubator.code.dialect.java.ClassType;
30 import jdk.incubator.code.dialect.java.JavaType;
31 import java.util.ArrayList;
32 import java.util.List;
33 import java.util.function.*;
34
35 import static jdk.incubator.code.dialect.core.CoreOp.*;
36 import static jdk.incubator.code.dialect.java.JavaOp.continue_;
37 import static jdk.incubator.code.dialect.java.JavaOp.enhancedFor;
38 import static jdk.incubator.code.dialect.java.JavaType.parameterized;
39 import static jdk.incubator.code.dialect.java.JavaType.type;
40
41 public final class StreamFuserUsingQuotable {
42
43 // Quotable functional interfaces
44
45 public interface QuotablePredicate<T> extends Quotable, Predicate<T> {
46 }
47
48 public interface QuotableFunction<T, R> extends Quotable, Function<T, R> {
49 }
50
51 public interface QuotableSupplier<T> extends Quotable, Supplier<T> {
52 }
53
54 public interface QuotableConsumer<T> extends Quotable, Consumer<T> {
55 }
56
57 public interface QuotableBiConsumer<T, U> extends Quotable, BiConsumer<T, U> {
58 }
59
60
61 StreamFuserUsingQuotable() {}
62
63 public static <T> StreamExprBuilder<T> fromList(Class<T> elementClass) {
64 JavaType elementType = type(elementClass);
65 // java.util.List<E>
66 JavaType listType = parameterized(type(List.class), elementType);
67 return new StreamExprBuilder<>(listType, elementType,
68 (b, v) -> StreamExprBuilder.enhancedForLoop(b, elementType, v)::body);
69 }
70
71 public static class StreamExprBuilder<T> {
72 static class StreamOp {
73 final JavaOp.LambdaOp lambdaOp;
74
75 StreamOp(Quotable quotedLambda) {
76 if (!(Op.ofQuotable(quotedLambda).get().op() instanceof JavaOp.LambdaOp lambdaOp)) {
77 throw new IllegalArgumentException("Quotable operation is not lambda operation");
78 }
79 if (!(Op.ofQuotable(quotedLambda).get().capturedValues().isEmpty())) {
80 throw new IllegalArgumentException("Quotable operation captures values");
81 }
82 this.lambdaOp = lambdaOp;
83 }
84
85 JavaOp.LambdaOp op() {
86 return lambdaOp;
87 }
88 }
89
90 static class MapStreamOp extends StreamOp {
91 public MapStreamOp(Quotable quotedLambda) {
92 super(quotedLambda);
93 }
94 }
95
96 static class FlatMapStreamOp extends StreamOp {
97 public FlatMapStreamOp(Quotable quotedLambda) {
98 super(quotedLambda);
99 }
100 }
101
102 static class FilterStreamOp extends StreamOp {
103 public FilterStreamOp(Quotable quotedLambda) {
104 super(quotedLambda);
105 }
106 }
107
108 final JavaType sourceType;
109 final JavaType sourceElementType;
110 final BiFunction<Body.Builder, Value, Function<Consumer<Block.Builder>, Op>> loopSupplier;
111 final List<StreamOp> streamOps;
112
113 StreamExprBuilder(JavaType sourceType, JavaType sourceElementType,
114 BiFunction<Body.Builder, Value, Function<Consumer<Block.Builder>, Op>> loopSupplier) {
115 this.sourceType = sourceType;
116 this.sourceElementType = sourceElementType;
117 this.loopSupplier = loopSupplier;
118 this.streamOps = new ArrayList<>();
119 }
120
121 static EnhancedForOp.BodyBuilder enhancedForLoop(Body.Builder ancestorBody, JavaType elementType,
122 Value iterable) {
123 return enhancedFor(ancestorBody, iterable.type(), elementType)
124 .expression(b -> {
125 b.op(core_yield(iterable));
126 })
127 .definition(b -> {
128 b.op(core_yield(b.parameters().get(0)));
129 });
130 }
131
132 @SuppressWarnings("unchecked")
133 public <R> StreamExprBuilder<R> map(QuotableFunction<T, R> f) {
134 streamOps.add(new MapStreamOp(f));
135 return (StreamExprBuilder<R>) this;
136 }
137
138 @SuppressWarnings("unchecked")
139 public <R> StreamExprBuilder<R> flatMap(QuotableFunction<T, Iterable<R>> f) {
140 streamOps.add(new FlatMapStreamOp(f));
141 return (StreamExprBuilder<R>) this;
142 }
143
144 public StreamExprBuilder<T> filter(QuotablePredicate<T> f) {
145 streamOps.add(new FilterStreamOp(f));
146 return this;
147 }
148
149 void fuseIntermediateOperations(Block.Builder body, BiConsumer<Block.Builder, Value> terminalConsumer) {
150 fuseIntermediateOperation(0, body, body.parameters().get(0), null, terminalConsumer);
151 }
152
153 void fuseIntermediateOperation(int i, Block.Builder body, Value element, Block.Builder continueBlock,
154 BiConsumer<Block.Builder, Value> terminalConsumer) {
155 if (i == streamOps.size()) {
156 terminalConsumer.accept(body, element);
157 return;
158 }
159
160 StreamOp sop = streamOps.get(i);
161 if (sop instanceof MapStreamOp) {
162 Inliner.inline(body, sop.op(), List.of(element), (block, value) -> {
163 fuseIntermediateOperation(i + 1, block, value, continueBlock, terminalConsumer);
164 });
165 } else if (sop instanceof FilterStreamOp) {
166 Inliner.inline(body, sop.op(), List.of(element), (block, p) -> {
167 Block.Builder _if = block.block();
168 Block.Builder _else = continueBlock;
169 if (continueBlock == null) {
170 _else = block.block();
171 _else.op(JavaOp.continue_());
172 }
173
174 block.op(conditionalBranch(p, _if.successor(), _else.successor()));
175
176 fuseIntermediateOperation(i + 1, _if, element, _else, terminalConsumer);
177 });
178 } else if (sop instanceof FlatMapStreamOp) {
179 Inliner.inline(body, sop.op(), List.of(element), (block, iterable) -> {
180 EnhancedForOp forOp = enhancedFor(block.parentBody(),
181 iterable.type(), ((ClassType) iterable.type()).typeArguments().get(0))
182 .expression(b -> {
183 b.op(core_yield(iterable));
184 })
185 .definition(b -> {
186 b.op(core_yield(b.parameters().get(0)));
187 })
188 .body(b -> {
189 fuseIntermediateOperation(i + 1,
190 b,
191 b.parameters().get(0),
192 null, terminalConsumer);
193 });
194
195 block.op(forOp);
196 block.op(JavaOp.continue_());
197 });
198 }
199 }
200
201 public FuncOp forEach(QuotableConsumer<T> quotableConsumer) {
202 if (!(Op.ofQuotable(quotableConsumer).get().op() instanceof JavaOp.LambdaOp consumer)) {
203 throw new IllegalArgumentException("Quotable consumer is not lambda operation");
204 }
205 if (!(Op.ofQuotable(quotableConsumer).get().capturedValues().isEmpty())) {
206 throw new IllegalArgumentException("Quotable consumer captures values");
207 }
208
209 return func("fused.forEach", CoreType.functionType(JavaType.VOID, sourceType))
210 .body(b -> {
211 Value source = b.parameters().get(0);
212
213 Op sourceLoop = loopSupplier.apply(b.parentBody(), source)
214 .apply(loopBlock -> {
215 fuseIntermediateOperations(loopBlock, (terminalBlock, resultValue) -> {
216 Inliner.inline(terminalBlock, consumer, List.of(resultValue),
217 (_, _) -> {
218 });
219 terminalBlock.op(JavaOp.continue_());
220 });
221
222 });
223 b.op(sourceLoop);
224 b.op(return_());
225 });
226 }
227
228 public <C> FuncOp collect(QuotableSupplier<C> quotableSupplier, QuotableBiConsumer<C, T> quotableAccumulator) {
229 if (!(Op.ofQuotable(quotableSupplier).get().op() instanceof JavaOp.LambdaOp supplier)) {
230 throw new IllegalArgumentException("Quotable supplier is not lambda operation");
231 }
232 if (!(Op.ofQuotable(quotableSupplier).get().capturedValues().isEmpty())) {
233 throw new IllegalArgumentException("Quotable supplier captures values");
234 }
235 if (!(Op.ofQuotable(quotableAccumulator).get().op() instanceof JavaOp.LambdaOp accumulator)) {
236 throw new IllegalArgumentException("Quotable accumulator is not lambda operation");
237 }
238 if (!(Op.ofQuotable(quotableAccumulator).get().capturedValues().isEmpty())) {
239 throw new IllegalArgumentException("Quotable accumulator captures values");
240 }
241
242 JavaType collectType = (JavaType) supplier.invokableType().returnType();
243 return func("fused.collect", CoreType.functionType(collectType, sourceType))
244 .body(b -> {
245 Value source = b.parameters().get(0);
246
247 Inliner.inline(b, supplier, List.of(), (block, collect) -> {
248 Op sourceLoop = loopSupplier.apply(block.parentBody(), source)
249 .apply(loopBlock -> {
250 fuseIntermediateOperations(loopBlock, (terminalBlock, resultValue) -> {
251 Inliner.inline(terminalBlock, accumulator, List.of(collect, resultValue),
252 (_, _) -> {
253 });
254 terminalBlock.op(JavaOp.continue_());
255 });
256 });
257 block.op(sourceLoop);
258 block.op(return_(collect));
259 });
260 });
261 }
262
263 }
264 }
265