1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
import java.lang.reflect.code.Block;
import java.lang.reflect.code.CopyContext;
import java.lang.reflect.code.Op;
import java.lang.reflect.code.Value;
import java.lang.reflect.code.op.CoreOp;
import java.lang.reflect.code.type.FunctionType;
import java.lang.reflect.code.type.JavaType;
import java.lang.reflect.code.type.MethodRef;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;

import static java.lang.reflect.code.op.CoreOp.*;
import static java.lang.reflect.code.type.JavaType.DOUBLE;
 40 
 41 public final class ForwardDifferentiation {
 42     // The function to differentiate
 43     final FuncOp fcm;
 44     // The independent variable
 45     final Block.Parameter ind;
 46     // The active set for the independent variable
 47     final Set<Value> activeSet;
 48     // The map of input value to it's (output) differentiated value
 49     final Map<Value, Value> diffValueMapping;
 50 
 51     // The constant value 0.0d
 52     // Declared in the (output) function's entry block
 53     Value zero;
 54 
 55     private ForwardDifferentiation(FuncOp fcm, Block.Parameter ind) {
 56         int indI = fcm.body().entryBlock().parameters().indexOf(ind);
 57         if (indI == -1) {
 58             throw new IllegalArgumentException("Independent argument not defined by function");
 59         }
 60         this.fcm = fcm;
 61         this.ind = ind;
 62 
 63         // Calculate the active set of dependent values for the independent value
 64         this.activeSet = ActiveSet.activeSet(fcm, ind);
 65         // A mapping of input values to their (output) differentiated values
 66         this.diffValueMapping = new HashMap<>();
 67     }
 68 
 69     public static FuncOp partialDiff(FuncOp fcm, Block.Parameter ind) {
 70         return new ForwardDifferentiation(fcm, ind).partialDiff();
 71     }
 72 
 73     FuncOp partialDiff() {
 74         int indI = fcm.body().entryBlock().parameters().indexOf(ind);
 75 
 76         AtomicBoolean first = new AtomicBoolean(true);
 77         FuncOp dfcm = fcm.transform(String.format("d%s_darg%d", fcm.funcName(), indI),
 78                 (block, op) -> {
 79                     if (first.getAndSet(false)) {
 80                         // Initialize
 81                         processBlocks(block);
 82                     }
 83 
 84                     // If the result of the operation is in the active set,
 85                     // then differentiate it, otherwise copy it
 86                     if (activeSet.contains(op.result())) {
 87                         Value dor = diffOp(block, op);
 88                         // Map the input result to its (output) differentiated result
 89                         // so that it can be used when differentiating subsequent operations
 90                         diffValueMapping.put(op.result(), dor);
 91                     } else {
 92                         block.apply(op);
 93                     }
 94                     return block;
 95                 });
 96 
 97         return dfcm;
 98     }
 99 
100     void processBlocks(Block.Builder block) {
101         // Declare constants at start
102         zero = block.op(constant(ind.type(), 0.0d));
103         // The differential of ind is 1
104         Value one = block.op(constant(ind.type(), 1.0d));
105         diffValueMapping.put(ind, one);
106 
107         // Append differential block parameters to blocks
108         for (Value v : activeSet) {
109             if (v instanceof Block.Parameter ba) {
110                 if (ba != ind) {
111                     // Get the output block builder for the input (declaring) block
112                     Block.Builder b = block.context().getBlock(ba.declaringBlock());
113                     // Add a new block parameter for differential parameter
114                     Block.Parameter dba = b.parameter(ba.type());
115                     // Place in mapping
116                     diffValueMapping.put(ba, dba);
117                 }
118             }
119         }
120     }
121 
122 
123     static final JavaType J_L_MATH = JavaType.type(Math.class);
124     static final FunctionType D_D = FunctionType.functionType(DOUBLE, DOUBLE);
125     static final MethodRef J_L_MATH_SIN = MethodRef.method(J_L_MATH, "sin", D_D);
126     static final MethodRef J_L_MATH_COS = MethodRef.method(J_L_MATH, "cos", D_D);
127 
128     Value diffOp(Block.Builder block, Op op) {
129         // Switch on the op, using pattern matching
130         return switch (op) {
131             case CoreOp.NegOp _ -> {
132                 // Copy input operation
133                 block.op(op);
134 
135                 // -diff(expr)
136                 Value a = op.operands().get(0);
137                 Value da = diffValueMapping.getOrDefault(a, zero);
138                 yield block.op(neg(da));
139             }
140             case CoreOp.AddOp _ -> {
141                 // Copy input operation
142                 block.op(op);
143 
144                 // diff(l) + diff(r)
145                 Value lhs = op.operands().get(0);
146                 Value rhs = op.operands().get(1);
147                 Value dlhs = diffValueMapping.getOrDefault(lhs, zero);
148                 Value drhs = diffValueMapping.getOrDefault(rhs, zero);
149                 yield block.op(add(dlhs, drhs));
150             }
151             case CoreOp.MulOp _ -> {
152                 // Copy input operation
153                 block.op(op);
154 
155                 // Product rule
156                 // diff(l) * r + l * diff(r)
157                 Value lhs = op.operands().get(0);
158                 Value rhs = op.operands().get(1);
159                 Value dlhs = diffValueMapping.getOrDefault(lhs, zero);
160                 Value drhs = diffValueMapping.getOrDefault(rhs, zero);
161                 Value outputLhs = block.context().getValue(lhs);
162                 Value outputRhs = block.context().getValue(rhs);
163                 yield block.op(add(
164                         block.op(mul(dlhs, outputRhs)),
165                         block.op(mul(outputLhs, drhs))));
166             }
167             case CoreOp.ConstantOp _ -> {
168                 // Copy input operation
169                 block.op(op);
170                 // Differential of constant is zero
171                 yield zero;
172             }
173             case CoreOp.InvokeOp c -> {
174                 MethodRef md = c.invokeDescriptor();
175                 String operationName = null;
176                 if (md.refType().equals(J_L_MATH)) {
177                     operationName = md.name();
178                 }
179                 // Differentiate sin(x)
180                 if ("sin".equals(operationName)) {
181                     // Copy input operation
182                     block.op(op);
183 
184                     // Chain rule
185                     // cos(expr) * diff(expr)
186                     Value a = op.operands().get(0);
187                     Value da = diffValueMapping.getOrDefault(a, zero);
188                     Value outputA = block.context().getValue(a);
189                     Op.Result cosx = block.op(invoke(J_L_MATH_COS, outputA));
190                     yield block.op(mul(cosx, da));
191                 } else {
192                     throw new UnsupportedOperationException("Operation not supported: " + op.opName());
193                 }
194             }
195             case CoreOp.ReturnOp _ -> {
196                 // Replace with return of differentiated value
197                 Value a = op.operands().get(0);
198                 Value da = diffValueMapping.getOrDefault(a, zero);
199                 yield block.op(_return(da));
200             }
201             case Op.BlockTerminating _ -> {
202                 // Update with differentiated block arguments
203                 op.successors().forEach(s -> adaptSuccessor(block.context(), s));
204                 yield block.op(op);
205             }
206             default -> throw new UnsupportedOperationException("Operation not supported: " + op.opName());
207         };
208     }
209 
210     void adaptSuccessor(CopyContext cc, Block.Reference from) {
211         List<Value> as = from.arguments().stream()
212                 .filter(activeSet::contains)
213                 .toList();
214         if (!as.isEmpty()) {
215             // Get the successor arguments
216             List<Value> outputArgs = cc.getValues(from.arguments());
217             // Append the differential arguments, if any
218             for (Value a : as) {
219                 Value da = diffValueMapping.get(a);
220                 outputArgs.add(da);
221             }
222 
223             // Map successor with appended arguments
224             Block.Reference to = cc.getBlock(from.targetBlock()).successor(outputArgs);
225             cc.mapSuccessor(from, to);
226         }
227     }
228 
229 }