1 /* 2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. 8 * 9 * This code is distributed in the hope that it will be useful, but WITHOUT 10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 12 * version 2 for more details (a copy is included in the LICENSE file that 13 * accompanied this code). 14 * 15 * You should have received a copy of the GNU General Public License version 16 * 2 along with this work; if not, write to the Free Software Foundation, 17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 18 * 19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 20 * or visit www.oracle.com if you need additional information or have any 21 * questions. 22 */ 23 24 import jdk.incubator.code.*; 25 import jdk.incubator.code.dialect.core.CoreType; 26 import jdk.incubator.code.dialect.java.JavaOp; 27 import jdk.incubator.code.dialect.java.JavaOp.JavaEnhancedForOp; 28 import jdk.incubator.code.dialect.java.ClassType; 29 import jdk.incubator.code.dialect.java.JavaType; 30 import java.util.ArrayList; 31 import java.util.List; 32 import java.util.function.*; 33 34 import static jdk.incubator.code.dialect.core.CoreOp.*; 35 import static jdk.incubator.code.dialect.java.JavaOp._continue; 36 import static jdk.incubator.code.dialect.java.JavaOp.enhancedFor; 37 import static jdk.incubator.code.dialect.java.JavaType.parameterized; 38 import static jdk.incubator.code.dialect.java.JavaType.type; 39 40 public final class StreamFuserUsingQuotable { 41 42 // Quotable functional interfaces 43 44 public interface QuotablePredicate<T> extends Quotable, Predicate<T> { 45 } 46 47 public interface QuotableFunction<T, R> extends Quotable, Function<T, R> { 48 } 49 50 public interface QuotableSupplier<T> extends Quotable, Supplier<T> { 51 } 52 53 public interface QuotableConsumer<T> extends Quotable, Consumer<T> { 54 } 55 56 public interface QuotableBiConsumer<T, U> extends Quotable, BiConsumer<T, U> { 57 } 58 59 60 StreamFuserUsingQuotable() {} 61 62 public static <T> StreamExprBuilder<T> fromList(Class<T> elementClass) { 63 JavaType elementType = type(elementClass); 64 // java.util.List<E> 65 JavaType listType = parameterized(type(List.class), elementType); 66 return new StreamExprBuilder<>(listType, elementType, 67 (b, v) -> StreamExprBuilder.enhancedForLoop(b, elementType, v)::body); 68 } 69 70 public static class StreamExprBuilder<T> { 71 static class StreamOp { 72 final JavaOp.LambdaOp lambdaOp; 73 74 StreamOp(Quotable quotedLambda) { 75 if (!(Op.ofQuotable(quotedLambda).get().op() instanceof JavaOp.LambdaOp lambdaOp)) { 76 throw new IllegalArgumentException("Quotable operation is not lambda operation"); 77 } 78 if (!(Op.ofQuotable(quotedLambda).get().capturedValues().isEmpty())) { 79 throw new IllegalArgumentException("Quotable operation captures values"); 80 } 81 this.lambdaOp = lambdaOp; 82 } 83 84 JavaOp.LambdaOp op() { 85 return lambdaOp; 86 } 87 } 88 89 static class MapStreamOp extends StreamOp { 90 public MapStreamOp(Quotable quotedLambda) { 91 super(quotedLambda); 92 } 93 } 94 95 static class FlatMapStreamOp extends StreamOp { 96 public FlatMapStreamOp(Quotable quotedLambda) { 97 super(quotedLambda); 98 } 99 } 100 101 static class FilterStreamOp extends StreamOp { 102 public FilterStreamOp(Quotable quotedLambda) { 103 super(quotedLambda); 104 } 105 } 106 107 final JavaType sourceType; 108 final JavaType sourceElementType; 109 final BiFunction<Body.Builder, Value, Function<Consumer<Block.Builder>, Op>> loopSupplier; 110 final List<StreamOp> streamOps; 111 112 StreamExprBuilder(JavaType sourceType, JavaType sourceElementType, 113 BiFunction<Body.Builder, Value, Function<Consumer<Block.Builder>, Op>> loopSupplier) { 114 this.sourceType = sourceType; 115 this.sourceElementType = sourceElementType; 116 this.loopSupplier = loopSupplier; 117 this.streamOps = new ArrayList<>(); 118 } 119 120 static JavaEnhancedForOp.BodyBuilder enhancedForLoop(Body.Builder ancestorBody, JavaType elementType, 121 Value iterable) { 122 return enhancedFor(ancestorBody, iterable.type(), elementType) 123 .expression(b -> { 124 b.op(_yield(iterable)); 125 }) 126 .definition(b -> { 127 b.op(_yield(b.parameters().get(0))); 128 }); 129 } 130 131 @SuppressWarnings("unchecked") 132 public <R> StreamExprBuilder<R> map(QuotableFunction<T, R> f) { 133 streamOps.add(new MapStreamOp(f)); 134 return (StreamExprBuilder<R>) this; 135 } 136 137 @SuppressWarnings("unchecked") 138 public <R> StreamExprBuilder<R> flatMap(QuotableFunction<T, Iterable<R>> f) { 139 streamOps.add(new FlatMapStreamOp(f)); 140 return (StreamExprBuilder<R>) this; 141 } 142 143 public StreamExprBuilder<T> filter(QuotablePredicate<T> f) { 144 streamOps.add(new FilterStreamOp(f)); 145 return this; 146 } 147 148 void fuseIntermediateOperations(Block.Builder body, BiConsumer<Block.Builder, Value> terminalConsumer) { 149 fuseIntermediateOperation(0, body, body.parameters().get(0), null, terminalConsumer); 150 } 151 152 void fuseIntermediateOperation(int i, Block.Builder body, Value element, Block.Builder continueBlock, 153 BiConsumer<Block.Builder, Value> terminalConsumer) { 154 if (i == streamOps.size()) { 155 terminalConsumer.accept(body, element); 156 return; 157 } 158 159 StreamOp sop = streamOps.get(i); 160 if (sop instanceof MapStreamOp) { 161 body.inline(sop.op(), List.of(element), (block, value) -> { 162 fuseIntermediateOperation(i + 1, block, value, continueBlock, terminalConsumer); 163 }); 164 } else if (sop instanceof FilterStreamOp) { 165 body.inline(sop.op(), List.of(element), (block, p) -> { 166 Block.Builder _if = block.block(); 167 Block.Builder _else = continueBlock; 168 if (continueBlock == null) { 169 _else = block.block(); 170 _else.op(_continue()); 171 } 172 173 block.op(conditionalBranch(p, _if.successor(), _else.successor())); 174 175 fuseIntermediateOperation(i + 1, _if, element, _else, terminalConsumer); 176 }); 177 } else if (sop instanceof FlatMapStreamOp) { 178 body.inline(sop.op(), List.of(element), (block, iterable) -> { 179 JavaEnhancedForOp forOp = enhancedFor(block.parentBody(), 180 iterable.type(), ((ClassType) iterable.type()).typeArguments().get(0)) 181 .expression(b -> { 182 b.op(_yield(iterable)); 183 }) 184 .definition(b -> { 185 b.op(_yield(b.parameters().get(0))); 186 }) 187 .body(b -> { 188 fuseIntermediateOperation(i + 1, 189 b, 190 b.parameters().get(0), 191 null, terminalConsumer); 192 }); 193 194 block.op(forOp); 195 block.op(_continue()); 196 }); 197 } 198 } 199 200 public FuncOp forEach(QuotableConsumer<T> quotableConsumer) { 201 if (!(Op.ofQuotable(quotableConsumer).get().op() instanceof JavaOp.LambdaOp consumer)) { 202 throw new IllegalArgumentException("Quotable consumer is not lambda operation"); 203 } 204 if (!(Op.ofQuotable(quotableConsumer).get().capturedValues().isEmpty())) { 205 throw new IllegalArgumentException("Quotable consumer captures values"); 206 } 207 208 return func("fused.forEach", CoreType.functionType(JavaType.VOID, sourceType)) 209 .body(b -> { 210 Value source = b.parameters().get(0); 211 212 Op sourceLoop = loopSupplier.apply(b.parentBody(), source) 213 .apply(loopBlock -> { 214 fuseIntermediateOperations(loopBlock, (terminalBlock, resultValue) -> { 215 terminalBlock.inline(consumer, List.of(resultValue), 216 (_, _) -> { 217 }); 218 terminalBlock.op(_continue()); 219 }); 220 221 }); 222 b.op(sourceLoop); 223 b.op(_return()); 224 }); 225 } 226 227 public <C> FuncOp collect(QuotableSupplier<C> quotableSupplier, QuotableBiConsumer<C, T> quotableAccumulator) { 228 if (!(Op.ofQuotable(quotableSupplier).get().op() instanceof JavaOp.LambdaOp supplier)) { 229 throw new IllegalArgumentException("Quotable supplier is not lambda operation"); 230 } 231 if (!(Op.ofQuotable(quotableSupplier).get().capturedValues().isEmpty())) { 232 throw new IllegalArgumentException("Quotable supplier captures values"); 233 } 234 if (!(Op.ofQuotable(quotableAccumulator).get().op() instanceof JavaOp.LambdaOp accumulator)) { 235 throw new IllegalArgumentException("Quotable accumulator is not lambda operation"); 236 } 237 if (!(Op.ofQuotable(quotableAccumulator).get().capturedValues().isEmpty())) { 238 throw new IllegalArgumentException("Quotable accumulator captures values"); 239 } 240 241 JavaType collectType = (JavaType) supplier.invokableType().returnType(); 242 return func("fused.collect", CoreType.functionType(collectType, sourceType)) 243 .body(b -> { 244 Value source = b.parameters().get(0); 245 246 b.inline(supplier, List.of(), (block, collect) -> { 247 Op sourceLoop = loopSupplier.apply(block.parentBody(), source) 248 .apply(loopBlock -> { 249 fuseIntermediateOperations(loopBlock, (terminalBlock, resultValue) -> { 250 terminalBlock.inline(accumulator, List.of(collect, resultValue), 251 (_, _) -> { 252 }); 253 terminalBlock.op(_continue()); 254 }); 255 }); 256 block.op(sourceLoop); 257 block.op(_return(collect)); 258 }); 259 }); 260 } 261 262 } 263 } 264