1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 
 26 package jdk.incubator.code;
 27 
 28 import jdk.incubator.code.dialect.core.CoreOp;
 29 import jdk.incubator.code.dialect.core.CoreType;
 30 import jdk.incubator.code.dialect.core.FunctionType;
 31 import jdk.incubator.code.dialect.core.VarType;
 32 
 33 import java.util.*;
 34 import java.util.function.Consumer;
 35 
 36 /**
 37  * The quoted form of an operation.
 38  * <p>
 39  * The quoted form is utilized when the code model of some code is to be obtained rather than obtaining the result of
 40  * executing that code. For example passing the of a lambda expression in quoted form rather than the expression being
 41  * targeted to a functional interface from which it can be invoked.
 42  * @param <T> the type of operation that is quoted
 43  */
 44 public final class Quoted<T extends Op> {
 45     private final T op;
 46     private final SequencedMap<Value, Object> operandsAndCapturedValues;
 47 
 48     static final SequencedMap<Value, Object> EMPTY_SEQUENCED_MAP = new LinkedHashMap<>();
 49     /**
 50      * Constructs the quoted form of a given operation.
 51      *
 52      * @param op the invokable operation.
 53      */
 54     public Quoted(T op) {
 55         this(op, EMPTY_SEQUENCED_MAP);
 56     }
 57 
 58     /**
 59      * Constructs the quoted form of a given operation.
 60      * <p>
 61      * The {@code operandsAndCapturedValues} key set must be equal to
 62      * the sequenced set of operation's operands + captured values, in order.
 63      *
 64      * @param op                        the operation.
 65      * @param operandsAndCapturedValues sequenced map of {@link Value} to {@link Object}, with the requirement defined above
 66      * @throws IllegalArgumentException If {@code operandsAndCapturedValues} doesn't satisfy the requirement described above
 67      * @see Op#capturedValues()
 68      * @see Op#operands()
 69      */
 70     public Quoted(T op, SequencedMap<Value, Object> operandsAndCapturedValues) {
 71         // @@@ This check is potentially expensive, remove or keep ?
 72         // @@@ Or make Quoted an interface, with a module private implementation?
 73         SequencedSet<Value> s = new LinkedHashSet<>(op.operands());
 74         s.addAll(op.capturedValues());
 75         if (!s.stream().toList().equals(operandsAndCapturedValues.keySet().stream().toList())) {
 76             throw new IllegalArgumentException("The map key set isn't equal to the sequenced set of operands + captured values");
 77         }
 78         this.op = op;
 79         this.operandsAndCapturedValues = Collections.unmodifiableSequencedMap(new LinkedHashMap<>(operandsAndCapturedValues));
 80     }
 81 
 82     /**
 83      * Returns the operation.
 84      *
 85      * @return the operation.
 86      */
 87     public T op() {
 88         return op;
 89     }
 90 
 91     /**
 92      * Returns the captured values.
 93      * <p>
 94      * The captured values key set has the same elements and same encounter order as
 95      * operation's captured values, specifically the following expression evaluates to true:
 96      * {@snippet lang=java :
 97      * op().capturedValues().equals(new ArrayList<>(capturedValues().keySet()));
 98      * }
 99      *
100      * @return the captured values.
101      */
102     public SequencedMap<Value, Object> capturedValues() {
103         SequencedMap<Value, Object> m = new LinkedHashMap<>();
104         for (Value cv : op.capturedValues()) {
105             m.put(cv, operandsAndCapturedValues.get(cv));
106         }
107         return m;
108     }
109 
110     /**
111      * Returns the operands.
112      * <p>
113      * The result key set has the same elements and same encounter order as the sequenced set of operation's operands,
114      * specifically the following expression evaluates to true:
115      * {@snippet lang = java:
116      * new LinkedHashSet<>(op.operands()).equals(operands().keySet());
117      *}
118      *
119      * @return the operands.
120      */
121     public SequencedMap<Value, Object> operands() {
122         SequencedMap<Value, Object> m = new LinkedHashMap<>();
123         for (Value operand : op.operands()) {
124             // putIfAbsent is used because a value may be used as operand more than once
125             m.putIfAbsent(operand, operandsAndCapturedValues.get(operand));
126         }
127         return m;
128     }
129 
130     /**
131      * Returns the operands and captured values.
132      * The result key set is equal to the sequenced set of operands + captured values.
133      *
134      * @return the operands + captured values, as an unmodifiable map.
135      */
136     public SequencedMap<Value, Object> operandsAndCapturedValues() {
137         return operandsAndCapturedValues;
138     }
139 
140     /**
141      * Embeds the given {@code op}, copying it from its original context to a new one,
142      * where its operands and captured values will be parameters.
143      * <p>
144      * The result is a {@link jdk.incubator.code.dialect.core.CoreOp.FuncOp FuncOp}
145      * that has one body with one block (<i>fblock</i>).
146      * <br>
147      * <i>fblock</i> will have a parameter for every element in the sequenced set of {@code op}'s operands + captured values.
148      * If the operand or capture is a result of a {@link jdk.incubator.code.dialect.core.CoreOp.VarOp VarOp},
149      * <i>fblock</i> will have a {@link jdk.incubator.code.dialect.core.CoreOp.VarOp VarOp}
150      * whose initial value is the parameter.
151      * <br>
152      * Then <i>fblock</i> has a {@link jdk.incubator.code.dialect.core.CoreOp.QuotedOp QuotedOp}
153      * that has one body with one block (<i>qblock</i>).
154      * Inside <i>qblock</i> we find a copy of {@code op}
155      * and a {@link jdk.incubator.code.dialect.core.CoreOp.YieldOp YieldOp}
156      * whose yield value is the result of the {@code op}'s copy.
157      * <br>
158      * <i>fblock</i> terminates with a {@link jdk.incubator.code.dialect.core.CoreOp.ReturnOp ReturnOp},
159      * the returned value is the result of the {@link jdk.incubator.code.dialect.core.CoreOp.QuotedOp QuotedOp}
160      * object described previously.
161      *
162      * @param op The operation to embed
163      * @return The model that represent the quoting of {@code op}
164      * @throws IllegalArgumentException if {@code op} is not bound
165      */
166     public static CoreOp.FuncOp embedOp(Op op) {
167         if (op.result() == null) {
168             throw new IllegalArgumentException("Op not bound");
169         }
170 
171         // if we don't remove duplicate operands we will have unused params in the new model
172         // if we don't remove captured values that are operands we will have unused params in the new model
173         SequencedSet<Value> s = new LinkedHashSet<>(op.operands());
174         s.addAll(op.capturedValues());
175         List<Value> inputOperandsAndCaptures = s.stream().toList();
176 
177         // Build the function type
178         List<TypeElement> params = inputOperandsAndCaptures.stream()
179                 .map(v -> v.type() instanceof VarType vt ? vt.valueType() : v.type())
180                 .toList();
181         FunctionType ft = CoreType.functionType(CoreOp.QuotedOp.QUOTED_OP_TYPE, params);
182 
183         // Build the function that quotes the lambda
184         return CoreOp.func("q", ft).body(b -> {
185             // Create variables as needed and obtain the operands and captured values for the copied lambda
186             List<Value> outputOperandsAndCaptures = new ArrayList<>();
187             for (int i = 0; i < inputOperandsAndCaptures.size(); i++) {
188                 Value inputValue = inputOperandsAndCaptures.get(i);
189                 Value outputValue = b.parameters().get(i);
190                 if (inputValue.type() instanceof VarType) {
191                     outputValue = b.op(CoreOp.var(String.valueOf(i), outputValue));
192                 }
193                 outputOperandsAndCaptures.add(outputValue);
194             }
195 
196             // Quoted the lambda expression
197             Value q = b.op(CoreOp.quoted(b.parentBody(), qb -> {
198                 // Map the entry block of the op's ancestor body to the quoted block
199                 // We are copying op in the context of the quoted block, the block mapping
200                 // ensures the use of operands and captured values are reachable when building
201                 qb.context().mapBlock(op.ancestorBody().entryBlock(), qb);
202                 // Map the op's operands and captured values
203                 qb.context().mapValues(inputOperandsAndCaptures, outputOperandsAndCaptures);
204                 // Return the op to be copied in the quoted operation
205                 return op;
206             }));
207             b.op(CoreOp.return_(q));
208         });
209     }
210 
211     private static RuntimeException invalidQuotedModel(CoreOp.FuncOp model) {
212         return new RuntimeException("Invalid code model for quoted operation : " + model);
213     }
214 
215     /**
216      * Extracts the quoted operation from {@code funcOp}
217      * and map its operands and captured values to the runtime values in {@code args}.
218      * <p>
219      * {@code funcOp} must have the same structure as if it's produced by {@link #embedOp(Op)}.
220      *
221      * @param funcOp Model to extract the quoted op from
222      * @param args   Runtime values for {@code funcOp} parameters
223      * @return Quoted instance that wraps the quoted operation,
224      * plus the mapping of its operands and captured values to the given runtime values
225      * @throws RuntimeException If {@code funcOp} isn't a valid code model
226      * @throws RuntimeException If {@code funcOp} parameters size is different from {@code args} length
227      */
228     public static Quoted<Op> extractOp(CoreOp.FuncOp funcOp, List<Object> args) {
229         if (funcOp.body().blocks().size() != 1) {
230             throw invalidQuotedModel(funcOp);
231         }
232         Block fblock = funcOp.body().entryBlock();
233         if (fblock.ops().size() < 2) {
234             throw invalidQuotedModel(funcOp);
235         }
236         if (!(fblock.ops().get(fblock.ops().size() - 2) instanceof CoreOp.QuotedOp qop)) {
237             throw invalidQuotedModel(funcOp);
238         }
239         if (!(fblock.ops().getLast() instanceof CoreOp.ReturnOp returnOp)) {
240             throw invalidQuotedModel(funcOp);
241         }
242         if (returnOp.returnValue() == null) {
243             throw invalidQuotedModel(funcOp);
244         }
245         if (!returnOp.returnValue().equals(qop.result())) {
246             throw invalidQuotedModel(funcOp);
247         }
248 
249         Op op = qop.quotedOp();
250 
251         SequencedSet<Value> operandsAndCaptures = new LinkedHashSet<>();
252         operandsAndCaptures.addAll(op.operands());
253         operandsAndCaptures.addAll(op.capturedValues());
254 
255         // validation rule of block params and ConstantOp result
256         // let v be a block param or ConstantOp result
257         // if v not used -> throw
258         // if v used once and user is VarOp and VarOp not used or VarOp used in funcOp entry block -> throw
259         // if v is used once and user is not a VarOp and usage isn't as operand or capture -> throw
260         // if v is used more than once and one of the uses is in funcOp entry block -> throw
261         Consumer<Value> validate = v -> {
262             if (v.uses().isEmpty()) {
263                 throw invalidQuotedModel(funcOp);
264             } else if (v.uses().size() == 1 && v.uses().iterator().next().op() instanceof CoreOp.VarOp vop
265                     && (vop.result().uses().isEmpty() ||
266                     vop.result().uses().stream().anyMatch(u -> u.op().ancestorBlock() == fblock))) {
267                 throw invalidQuotedModel(funcOp);
268             } else if (v.uses().size() == 1 && !(v.uses().iterator().next().op() instanceof CoreOp.VarOp)
269                     && !operandsAndCaptures.contains(v)) {
270                 throw invalidQuotedModel(funcOp);
271             } else if (v.uses().size() > 1 && v.uses().stream().anyMatch(u -> u.op().ancestorBlock() == fblock)) {
272                 throw invalidQuotedModel(funcOp);
273             }
274         };
275 
276         for (Block.Parameter p : fblock.parameters()) {
277             validate.accept(p);
278         }
279 
280         List<Op> ops = fblock.ops().subList(0, fblock.ops().size() - 2);
281         for (Op o : ops) {
282             switch (o) {
283                 case CoreOp.VarOp varOp -> {
284                     if (varOp.isUninitialized()) {
285                         throw invalidQuotedModel(funcOp);
286                     }
287                     if (varOp.initOperand() instanceof Op.Result opr && !(opr.op() instanceof CoreOp.ConstantOp)) {
288                         throw invalidQuotedModel(funcOp);
289                     }
290                 }
291                 case CoreOp.ConstantOp cop -> validate.accept(cop.result());
292                 default -> throw invalidQuotedModel(funcOp);
293             }
294         }
295 
296         // map operands and captures to their corresponding runtime values
297         // operand and capture can be:
298         // 1- block param
299         // 2- result of VarOp whose initial value is constant
300         // 3- result of VarOp whose initial value is block param
301         // 4- result of ConstantOp
302         List<Block.Parameter> params = funcOp.parameters();
303         if (params.size() != args.size()) {
304             throw invalidQuotedModel(funcOp);
305         }
306         SequencedMap<Value, Object> m = new LinkedHashMap<>();
307         for (Value v : operandsAndCaptures) {
308             switch (v) {
309                 case Block.Parameter p -> {
310                     Object rv = args.get(p.index());
311                     m.put(v, rv);
312                 }
313                 case Op.Result opr when opr.op() instanceof CoreOp.VarOp varOp -> {
314                     if (varOp.initOperand() instanceof Op.Result r && r.op() instanceof CoreOp.ConstantOp cop) {
315                         m.put(v, CoreOp.Var.of(cop.value()));
316                     } else if (varOp.initOperand() instanceof Block.Parameter p) {
317                         Object rv = args.get(p.index());
318                         m.put(v, CoreOp.Var.of(rv));
319                     }
320                 }
321                 case Op.Result opr when opr.op() instanceof CoreOp.ConstantOp cop -> {
322                     m.put(v, cop.value());
323                 }
324                 default -> throw invalidQuotedModel(funcOp);
325             }
326         }
327 
328         return new Quoted<>(op, m);
329     }
330 
331     /**
332      * Extracts the quoted operation from {@code funcOp}
333      * and map its operands and captured values to the runtime values in {@code args}.
334      * <p>
335      * {@code funcOp} must have the same structure as if it's produced by {@link #embedOp(Op)}.
336      *
337      * @param funcOp Model to extract the quoted op from
338      * @param args   Runtime values for {@code funcOp} parameters
339      * @return Quoted instance that wraps the quoted operation,
340      * plus the mapping of its operands and captured values to the given runtime values
341      * @throws RuntimeException If {@code funcOp} isn't a valid code model
342      * @throws RuntimeException If {@code funcOp} parameters size is different from {@code args} length
343      * @see Quoted#extractOp(CoreOp.FuncOp, List)
344      */
345     public static Quoted<Op> extractOp(CoreOp.FuncOp funcOp, Object... args) {
346         return extractOp(funcOp, List.of(args));
347     }
348 }