1 /* 2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. 8 * 9 * This code is distributed in the hope that it will be useful, but WITHOUT 10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 12 * version 2 for more details (a copy is included in the LICENSE file that 13 * accompanied this code). 14 * 15 * You should have received a copy of the GNU General Public License version 16 * 2 along with this work; if not, write to the Free Software Foundation, 17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 18 * 19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 20 * or visit www.oracle.com if you need additional information or have any 21 * questions. 22 */ 23 24 import jdk.incubator.code.*; 25 import jdk.incubator.code.analysis.Inliner; 26 import jdk.incubator.code.dialect.core.CoreOp; 27 import jdk.incubator.code.dialect.core.CoreType; 28 import jdk.incubator.code.dialect.java.JavaOp; 29 import jdk.incubator.code.dialect.java.JavaOp.EnhancedForOp; 30 import jdk.incubator.code.dialect.java.ClassType; 31 import jdk.incubator.code.dialect.java.JavaType; 32 import java.util.ArrayList; 33 import java.util.List; 34 import java.util.function.BiConsumer; 35 import java.util.function.BiFunction; 36 import java.util.function.Consumer; 37 import java.util.function.Function; 38 39 import static jdk.incubator.code.dialect.core.CoreOp.*; 40 import static jdk.incubator.code.dialect.java.JavaOp.continue_; 41 import static jdk.incubator.code.dialect.java.JavaOp.enhancedFor; 42 import static jdk.incubator.code.dialect.java.JavaType.parameterized; 43 import static jdk.incubator.code.dialect.java.JavaType.type; 44 45 public final class StreamFuser { 46 47 StreamFuser() {} 48 49 public static StreamExprBuilder fromList(JavaType elementType) { 50 // java.util.List<E> 51 JavaType listType = parameterized(type(List.class), elementType); 52 return new StreamExprBuilder(listType, elementType, 53 (b, v) -> StreamExprBuilder.enhancedForLoop(b, elementType, v)::body); 54 } 55 56 public static class StreamExprBuilder { 57 static class StreamOp { 58 final Quoted quotedClosure; 59 60 StreamOp(Quoted quotedClosure) { 61 if (!(quotedClosure.op() instanceof CoreOp.ClosureOp)) { 62 throw new IllegalArgumentException("Quoted operation is not closure operation"); 63 } 64 this.quotedClosure = quotedClosure; 65 } 66 67 CoreOp.ClosureOp op() { 68 return (CoreOp.ClosureOp) quotedClosure.op(); 69 } 70 } 71 72 static class MapStreamOp extends StreamOp { 73 public MapStreamOp(Quoted quotedClosure) { 74 super(quotedClosure); 75 // @@@ Check closure signature 76 } 77 } 78 79 static class FlatMapStreamOp extends StreamOp { 80 public FlatMapStreamOp(Quoted quotedClosure) { 81 super(quotedClosure); 82 // @@@ Check closure signature 83 } 84 } 85 86 static class FilterStreamOp extends StreamOp { 87 public FilterStreamOp(Quoted quotedClosure) { 88 super(quotedClosure); 89 // @@@ Check closure signature 90 } 91 } 92 93 final JavaType sourceType; 94 final JavaType sourceElementType; 95 final BiFunction<Body.Builder, Value, Function<Consumer<Block.Builder>, Op>> loopSupplier; 96 final List<StreamOp> streamOps; 97 98 StreamExprBuilder(JavaType sourceType, JavaType sourceElementType, 99 BiFunction<Body.Builder, Value, Function<Consumer<Block.Builder>, Op>> loopSupplier) { 100 this.sourceType = sourceType; 101 this.sourceElementType = sourceElementType; 102 this.loopSupplier = loopSupplier; 103 this.streamOps = new ArrayList<>(); 104 } 105 106 static EnhancedForOp.BodyBuilder enhancedForLoop(Body.Builder ancestorBody, JavaType elementType, 107 Value iterable) { 108 return enhancedFor(ancestorBody, iterable.type(), elementType) 109 .expression(b -> { 110 b.op(core_yield(iterable)); 111 }) 112 .definition(b -> { 113 b.op(core_yield(b.parameters().get(0))); 114 }); 115 } 116 117 public StreamExprBuilder map(Quoted f) { 118 streamOps.add(new MapStreamOp(f)); 119 return this; 120 } 121 122 public StreamExprBuilder flatMap(Quoted f) { 123 streamOps.add(new FlatMapStreamOp(f)); 124 return this; 125 } 126 127 public StreamExprBuilder filter(Quoted f) { 128 streamOps.add(new FilterStreamOp(f)); 129 return this; 130 } 131 132 void fuseIntermediateOperations(Block.Builder body, BiConsumer<Block.Builder, Value> terminalConsumer) { 133 fuseIntermediateOperation(0, body, body.parameters().get(0), null, terminalConsumer); 134 } 135 136 void fuseIntermediateOperation(int i, Block.Builder body, Value element, Block.Builder continueBlock, 137 BiConsumer<Block.Builder, Value> terminalConsumer) { 138 if (i == streamOps.size()) { 139 terminalConsumer.accept(body, element); 140 return; 141 } 142 143 StreamOp sop = streamOps.get(i); 144 if (sop instanceof MapStreamOp) { 145 Inliner.inline(body, sop.op(), List.of(element), (block, value) -> { 146 fuseIntermediateOperation(i + 1, block, value, continueBlock, terminalConsumer); 147 }); 148 } else if (sop instanceof FilterStreamOp) { 149 Inliner.inline(body, sop.op(), List.of(element), (block, p) -> { 150 Block.Builder _if = block.block(); 151 Block.Builder _else = continueBlock; 152 if (continueBlock == null) { 153 _else = block.block(); 154 _else.op(JavaOp.continue_()); 155 } 156 157 block.op(conditionalBranch(p, _if.successor(), _else.successor())); 158 159 fuseIntermediateOperation(i + 1, _if, element, _else, terminalConsumer); 160 }); 161 } else if (sop instanceof FlatMapStreamOp) { 162 Inliner.inline(body, sop.op(), List.of(element), (block, iterable) -> { 163 EnhancedForOp forOp = enhancedFor(block.parentBody(), 164 iterable.type(), ((ClassType) iterable.type()).typeArguments().get(0)) 165 .expression(b -> { 166 b.op(core_yield(iterable)); 167 }) 168 .definition(b -> { 169 b.op(core_yield(b.parameters().get(0))); 170 }) 171 .body(b -> { 172 fuseIntermediateOperation(i + 1, 173 b, 174 b.parameters().get(0), 175 null, terminalConsumer); 176 }); 177 178 block.op(forOp); 179 block.op(JavaOp.continue_()); 180 }); 181 } 182 } 183 184 public FuncOp forEach(Quoted quotedConsumer) { 185 if (!(quotedConsumer.op() instanceof CoreOp.ClosureOp consumer)) { 186 throw new IllegalArgumentException("Quoted consumer is not closure operation"); 187 } 188 189 return func("fused.forEach", CoreType.functionType(JavaType.VOID, sourceType)) 190 .body(b -> { 191 Value source = b.parameters().get(0); 192 193 Op sourceLoop = loopSupplier.apply(b.parentBody(), source) 194 .apply(loopBlock -> { 195 fuseIntermediateOperations(loopBlock, (terminalBlock, resultValue) -> { 196 Inliner.inline(terminalBlock, consumer, List.of(resultValue), 197 (_, _) -> { 198 }); 199 terminalBlock.op(JavaOp.continue_()); 200 }); 201 202 }); 203 b.op(sourceLoop); 204 b.op(return_()); 205 }); 206 } 207 208 // Supplier<C> supplier, BiConsumer<C, T> accumulator 209 public FuncOp collect(Quoted quotedSupplier, Quoted quotedAccumulator) { 210 if (!(quotedSupplier.op() instanceof CoreOp.ClosureOp supplier)) { 211 throw new IllegalArgumentException("Quoted supplier is not closure operation"); 212 } 213 if (!(quotedAccumulator.op() instanceof CoreOp.ClosureOp accumulator)) { 214 throw new IllegalArgumentException("Quoted accumulator is not closure operation"); 215 } 216 217 JavaType collectType = (JavaType) supplier.invokableType().returnType(); 218 return func("fused.collect", CoreType.functionType(collectType, sourceType)) 219 .body(b -> { 220 Value source = b.parameters().get(0); 221 222 Inliner.inline(b, supplier, List.of(), (block, collect) -> { 223 Op sourceLoop = loopSupplier.apply(block.parentBody(), source) 224 .apply(loopBlock -> { 225 fuseIntermediateOperations(loopBlock, (terminalBlock, resultValue) -> { 226 Inliner.inline(terminalBlock, accumulator, List.of(collect, resultValue), 227 (_, _) -> { 228 }); 229 terminalBlock.op(JavaOp.continue_()); 230 }); 231 }); 232 block.op(sourceLoop); 233 block.op(return_(collect)); 234 }); 235 }); 236 } 237 238 } 239 }