1 /*
2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation.
8 *
9 * This code is distributed in the hope that it will be useful, but WITHOUT
10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
12 * version 2 for more details (a copy is included in the LICENSE file that
13 * accompanied this code).
14 *
15 * You should have received a copy of the GNU General Public License version
16 * 2 along with this work; if not, write to the Free Software Foundation,
17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18 *
19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20 * or visit www.oracle.com if you need additional information or have any
21 * questions.
22 */
23
import jdk.incubator.code.Block;
import jdk.incubator.code.CopyContext;
import jdk.incubator.code.Op;
import jdk.incubator.code.Value;
import jdk.incubator.code.dialect.core.CoreOp;
import jdk.incubator.code.dialect.core.CoreType;
import jdk.incubator.code.dialect.core.FunctionType;
import jdk.incubator.code.dialect.java.JavaOp;
import jdk.incubator.code.dialect.java.JavaType;
import jdk.incubator.code.dialect.java.MethodRef;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;

import static jdk.incubator.code.dialect.core.CoreOp.*;
import static jdk.incubator.code.dialect.java.JavaType.DOUBLE;
42
43 public final class ForwardDifferentiation {
44 // The function to differentiate
45 final FuncOp fcm;
46 // The independent variable
47 final Block.Parameter ind;
48 // The active set for the independent variable
49 final Set<Value> activeSet;
50 // The map of input value to it's (output) differentiated value
51 final Map<Value, Value> diffValueMapping;
52
53 // The constant value 0.0d
54 // Declared in the (output) function's entry block
55 Value zero;
56
57 private ForwardDifferentiation(FuncOp fcm, Block.Parameter ind) {
58 int indI = fcm.body().entryBlock().parameters().indexOf(ind);
59 if (indI == -1) {
60 throw new IllegalArgumentException("Independent argument not defined by function");
61 }
62 this.fcm = fcm;
63 this.ind = ind;
64
65 // Calculate the active set of dependent values for the independent value
66 this.activeSet = ActiveSet.activeSet(fcm, ind);
67 // A mapping of input values to their (output) differentiated values
68 this.diffValueMapping = new HashMap<>();
69 }
70
71 public static FuncOp partialDiff(FuncOp fcm, Block.Parameter ind) {
72 return new ForwardDifferentiation(fcm, ind).partialDiff();
73 }
74
75 FuncOp partialDiff() {
76 int indI = fcm.body().entryBlock().parameters().indexOf(ind);
77
78 AtomicBoolean first = new AtomicBoolean(true);
79 FuncOp dfcm = fcm.transform(String.format("d%s_darg%d", fcm.funcName(), indI),
80 (block, op) -> {
81 if (first.getAndSet(false)) {
82 // Initialize
83 processBlocks(block);
84 }
85
86 // If the result of the operation is in the active set,
87 // then differentiate it, otherwise copy it
88 if (activeSet.contains(op.result())) {
89 Value dor = diffOp(block, op);
90 // Map the input result to its (output) differentiated result
91 // so that it can be used when differentiating subsequent operations
92 diffValueMapping.put(op.result(), dor);
93 } else {
94 block.op(op);
95 }
96 return block;
97 });
98
99 return dfcm;
100 }
101
102 void processBlocks(Block.Builder block) {
103 // Declare constants at start
104 zero = block.op(constant(ind.type(), 0.0d));
105 // The differential of ind is 1
106 Value one = block.op(constant(ind.type(), 1.0d));
107 diffValueMapping.put(ind, one);
108
109 // Append differential block parameters to blocks
110 for (Value v : activeSet) {
111 if (v instanceof Block.Parameter ba) {
112 if (ba != ind) {
113 // Get the output block builder for the input (declaring) block
114 Block.Builder b = block.context().getBlock(ba.declaringBlock());
115 // Add a new block parameter for differential parameter
116 Block.Parameter dba = b.parameter(ba.type());
117 // Place in mapping
118 diffValueMapping.put(ba, dba);
119 }
120 }
121 }
122 }
123
124
125 static final JavaType J_L_MATH = JavaType.type(Math.class);
126 static final FunctionType D_D = CoreType.functionType(DOUBLE, DOUBLE);
127 static final MethodRef J_L_MATH_SIN = MethodRef.method(J_L_MATH, "sin", D_D);
128 static final MethodRef J_L_MATH_COS = MethodRef.method(J_L_MATH, "cos", D_D);
129
130 Value diffOp(Block.Builder block, Op op) {
131 // Switch on the op, using pattern matching
132 return switch (op) {
133 case JavaOp.NegOp _ -> {
134 // Copy input operation
135 block.op(op);
136
137 // -diff(expr)
138 Value a = op.operands().get(0);
139 Value da = diffValueMapping.getOrDefault(a, zero);
140 yield block.op(JavaOp.neg(da));
141 }
142 case JavaOp.AddOp _ -> {
143 // Copy input operation
144 block.op(op);
145
146 // diff(l) + diff(r)
147 Value lhs = op.operands().get(0);
148 Value rhs = op.operands().get(1);
149 Value dlhs = diffValueMapping.getOrDefault(lhs, zero);
150 Value drhs = diffValueMapping.getOrDefault(rhs, zero);
151 yield block.op(JavaOp.add(dlhs, drhs));
152 }
153 case JavaOp.MulOp _ -> {
154 // Copy input operation
155 block.op(op);
156
157 // Product rule
158 // diff(l) * r + l * diff(r)
159 Value lhs = op.operands().get(0);
160 Value rhs = op.operands().get(1);
161 Value dlhs = diffValueMapping.getOrDefault(lhs, zero);
162 Value drhs = diffValueMapping.getOrDefault(rhs, zero);
163 Value outputLhs = block.context().getValue(lhs);
164 Value outputRhs = block.context().getValue(rhs);
165 yield block.op(JavaOp.add(
166 block.op(JavaOp.mul(dlhs, outputRhs)),
167 block.op(JavaOp.mul(outputLhs, drhs))));
168 }
169 case CoreOp.ConstantOp _ -> {
170 // Copy input operation
171 block.op(op);
172 // Differential of constant is zero
173 yield zero;
174 }
175 case JavaOp.InvokeOp c -> {
176 MethodRef md = c.invokeDescriptor();
177 String operationName = null;
178 if (md.refType().equals(J_L_MATH)) {
179 operationName = md.name();
180 }
181 // Differentiate sin(x)
182 if ("sin".equals(operationName)) {
183 // Copy input operation
184 block.op(op);
185
186 // Chain rule
187 // cos(expr) * diff(expr)
188 Value a = op.operands().get(0);
189 Value da = diffValueMapping.getOrDefault(a, zero);
190 Value outputA = block.context().getValue(a);
191 Op.Result cosx = block.op(JavaOp.invoke(J_L_MATH_COS, outputA));
192 yield block.op(JavaOp.mul(cosx, da));
193 } else {
194 throw new UnsupportedOperationException("Operation not supported: " + op);
195 }
196 }
197 case CoreOp.ReturnOp _ -> {
198 // Replace with return of differentiated value
199 Value a = op.operands().get(0);
200 Value da = diffValueMapping.getOrDefault(a, zero);
201 yield block.op(return_(da));
202 }
203 case Op.BlockTerminating _ -> {
204 // Update with differentiated block arguments
205 op.successors().forEach(s -> adaptSuccessor(block.context(), s));
206 yield block.op(op);
207 }
208 default -> throw new UnsupportedOperationException("Operation not supported: " + op);
209 };
210 }
211
212 void adaptSuccessor(CopyContext cc, Block.Reference from) {
213 List<Value> as = from.arguments().stream()
214 .filter(activeSet::contains)
215 .toList();
216 if (!as.isEmpty()) {
217 // Get the successor arguments
218 List<Value> outputArgs = cc.getValues(from.arguments());
219 // Append the differential arguments, if any
220 for (Value a : as) {
221 Value da = diffValueMapping.get(a);
222 outputArgs.add(da);
223 }
224
225 // Map successor with appended arguments
226 Block.Reference to = cc.getBlock(from.targetBlock()).successor(outputArgs);
227 cc.mapSuccessor(from, to);
228 }
229 }
230
231 }