1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 import jdk.incubator.code.Block;
 25 import jdk.incubator.code.CopyContext;
 26 import jdk.incubator.code.Op;
 27 import jdk.incubator.code.Value;
 28 import jdk.incubator.code.dialect.core.CoreOp;
 29 import jdk.incubator.code.dialect.core.CoreType;
 30 import jdk.incubator.code.dialect.java.JavaOp;
 31 import jdk.incubator.code.dialect.java.MethodRef;
 32 import jdk.incubator.code.dialect.core.FunctionType;
 33 import jdk.incubator.code.dialect.java.JavaType;
 34 import java.util.HashMap;
 35 import java.util.List;
 36 import java.util.Map;
 37 import java.util.Set;
 38 import java.util.concurrent.atomic.AtomicBoolean;
 39 
 40 import static jdk.incubator.code.dialect.core.CoreOp.*;
 41 import static jdk.incubator.code.dialect.java.JavaType.DOUBLE;
 42 
 43 public final class ForwardDifferentiation {
 44     // The function to differentiate
 45     final FuncOp fcm;
 46     // The independent variable
 47     final Block.Parameter ind;
 48     // The active set for the independent variable
 49     final Set<Value> activeSet;
 50     // The map of input value to it's (output) differentiated value
 51     final Map<Value, Value> diffValueMapping;
 52 
 53     // The constant value 0.0d
 54     // Declared in the (output) function's entry block
 55     Value zero;
 56 
 57     private ForwardDifferentiation(FuncOp fcm, Block.Parameter ind) {
 58         int indI = fcm.body().entryBlock().parameters().indexOf(ind);
 59         if (indI == -1) {
 60             throw new IllegalArgumentException("Independent argument not defined by function");
 61         }
 62         this.fcm = fcm;
 63         this.ind = ind;
 64 
 65         // Calculate the active set of dependent values for the independent value
 66         this.activeSet = ActiveSet.activeSet(fcm, ind);
 67         // A mapping of input values to their (output) differentiated values
 68         this.diffValueMapping = new HashMap<>();
 69     }
 70 
 71     public static FuncOp partialDiff(FuncOp fcm, Block.Parameter ind) {
 72         return new ForwardDifferentiation(fcm, ind).partialDiff();
 73     }
 74 
 75     FuncOp partialDiff() {
 76         int indI = fcm.body().entryBlock().parameters().indexOf(ind);
 77 
 78         AtomicBoolean first = new AtomicBoolean(true);
 79         FuncOp dfcm = fcm.transform(String.format("d%s_darg%d", fcm.funcName(), indI),
 80                 (block, op) -> {
 81                     if (first.getAndSet(false)) {
 82                         // Initialize
 83                         processBlocks(block);
 84                     }
 85 
 86                     // If the result of the operation is in the active set,
 87                     // then differentiate it, otherwise copy it
 88                     if (activeSet.contains(op.result())) {
 89                         Value dor = diffOp(block, op);
 90                         // Map the input result to its (output) differentiated result
 91                         // so that it can be used when differentiating subsequent operations
 92                         diffValueMapping.put(op.result(), dor);
 93                     } else {
 94                         block.apply(op);
 95                     }
 96                     return block;
 97                 });
 98 
 99         return dfcm;
100     }
101 
102     void processBlocks(Block.Builder block) {
103         // Declare constants at start
104         zero = block.op(constant(ind.type(), 0.0d));
105         // The differential of ind is 1
106         Value one = block.op(constant(ind.type(), 1.0d));
107         diffValueMapping.put(ind, one);
108 
109         // Append differential block parameters to blocks
110         for (Value v : activeSet) {
111             if (v instanceof Block.Parameter ba) {
112                 if (ba != ind) {
113                     // Get the output block builder for the input (declaring) block
114                     Block.Builder b = block.context().getBlock(ba.declaringBlock());
115                     // Add a new block parameter for differential parameter
116                     Block.Parameter dba = b.parameter(ba.type());
117                     // Place in mapping
118                     diffValueMapping.put(ba, dba);
119                 }
120             }
121         }
122     }
123 
124 
125     static final JavaType J_L_MATH = JavaType.type(Math.class);
126     static final FunctionType D_D = CoreType.functionType(DOUBLE, DOUBLE);
127     static final MethodRef J_L_MATH_SIN = MethodRef.method(J_L_MATH, "sin", D_D);
128     static final MethodRef J_L_MATH_COS = MethodRef.method(J_L_MATH, "cos", D_D);
129 
130     Value diffOp(Block.Builder block, Op op) {
131         // Switch on the op, using pattern matching
132         return switch (op) {
133             case JavaOp.NegOp _ -> {
134                 // Copy input operation
135                 block.op(op);
136 
137                 // -diff(expr)
138                 Value a = op.operands().get(0);
139                 Value da = diffValueMapping.getOrDefault(a, zero);
140                 yield block.op(JavaOp.neg(da));
141             }
142             case JavaOp.AddOp _ -> {
143                 // Copy input operation
144                 block.op(op);
145 
146                 // diff(l) + diff(r)
147                 Value lhs = op.operands().get(0);
148                 Value rhs = op.operands().get(1);
149                 Value dlhs = diffValueMapping.getOrDefault(lhs, zero);
150                 Value drhs = diffValueMapping.getOrDefault(rhs, zero);
151                 yield block.op(JavaOp.add(dlhs, drhs));
152             }
153             case JavaOp.MulOp _ -> {
154                 // Copy input operation
155                 block.op(op);
156 
157                 // Product rule
158                 // diff(l) * r + l * diff(r)
159                 Value lhs = op.operands().get(0);
160                 Value rhs = op.operands().get(1);
161                 Value dlhs = diffValueMapping.getOrDefault(lhs, zero);
162                 Value drhs = diffValueMapping.getOrDefault(rhs, zero);
163                 Value outputLhs = block.context().getValue(lhs);
164                 Value outputRhs = block.context().getValue(rhs);
165                 yield block.op(JavaOp.add(
166                         block.op(JavaOp.mul(dlhs, outputRhs)),
167                         block.op(JavaOp.mul(outputLhs, drhs))));
168             }
169             case CoreOp.ConstantOp _ -> {
170                 // Copy input operation
171                 block.op(op);
172                 // Differential of constant is zero
173                 yield zero;
174             }
175             case JavaOp.InvokeOp c -> {
176                 MethodRef md = c.invokeDescriptor();
177                 String operationName = null;
178                 if (md.refType().equals(J_L_MATH)) {
179                     operationName = md.name();
180                 }
181                 // Differentiate sin(x)
182                 if ("sin".equals(operationName)) {
183                     // Copy input operation
184                     block.op(op);
185 
186                     // Chain rule
187                     // cos(expr) * diff(expr)
188                     Value a = op.operands().get(0);
189                     Value da = diffValueMapping.getOrDefault(a, zero);
190                     Value outputA = block.context().getValue(a);
191                     Op.Result cosx = block.op(JavaOp.invoke(J_L_MATH_COS, outputA));
192                     yield block.op(JavaOp.mul(cosx, da));
193                 } else {
194                     throw new UnsupportedOperationException("Operation not supported: " + op.opName());
195                 }
196             }
197             case CoreOp.ReturnOp _ -> {
198                 // Replace with return of differentiated value
199                 Value a = op.operands().get(0);
200                 Value da = diffValueMapping.getOrDefault(a, zero);
201                 yield block.op(return_(da));
202             }
203             case Op.BlockTerminating _ -> {
204                 // Update with differentiated block arguments
205                 op.successors().forEach(s -> adaptSuccessor(block.context(), s));
206                 yield block.op(op);
207             }
208             default -> throw new UnsupportedOperationException("Operation not supported: " + op.opName());
209         };
210     }
211 
212     void adaptSuccessor(CopyContext cc, Block.Reference from) {
213         List<Value> as = from.arguments().stream()
214                 .filter(activeSet::contains)
215                 .toList();
216         if (!as.isEmpty()) {
217             // Get the successor arguments
218             List<Value> outputArgs = cc.getValues(from.arguments());
219             // Append the differential arguments, if any
220             for (Value a : as) {
221                 Value da = diffValueMapping.get(a);
222                 outputArgs.add(da);
223             }
224 
225             // Map successor with appended arguments
226             Block.Reference to = cc.getBlock(from.targetBlock()).successor(outputArgs);
227             cc.mapSuccessor(from, to);
228         }
229     }
230 
231 }