/*
 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
22 */ 23 24 import jdk.incubator.code.Block; 25 import jdk.incubator.code.CopyContext; 26 import jdk.incubator.code.Op; 27 import jdk.incubator.code.Value; 28 import jdk.incubator.code.op.CoreOp; 29 import jdk.incubator.code.type.MethodRef; 30 import jdk.incubator.code.type.FunctionType; 31 import jdk.incubator.code.type.JavaType; 32 import java.util.HashMap; 33 import java.util.List; 34 import java.util.Map; 35 import java.util.Set; 36 import java.util.concurrent.atomic.AtomicBoolean; 37 38 import static jdk.incubator.code.op.CoreOp.*; 39 import static jdk.incubator.code.type.JavaType.DOUBLE; 40 41 public final class ForwardDifferentiation { 42 // The function to differentiate 43 final FuncOp fcm; 44 // The independent variable 45 final Block.Parameter ind; 46 // The active set for the independent variable 47 final Set<Value> activeSet; 48 // The map of input value to it's (output) differentiated value 49 final Map<Value, Value> diffValueMapping; 50 51 // The constant value 0.0d 52 // Declared in the (output) function's entry block 53 Value zero; 54 55 private ForwardDifferentiation(FuncOp fcm, Block.Parameter ind) { 56 int indI = fcm.body().entryBlock().parameters().indexOf(ind); 57 if (indI == -1) { 58 throw new IllegalArgumentException("Independent argument not defined by function"); 59 } 60 this.fcm = fcm; 61 this.ind = ind; 62 63 // Calculate the active set of dependent values for the independent value 64 this.activeSet = ActiveSet.activeSet(fcm, ind); 65 // A mapping of input values to their (output) differentiated values 66 this.diffValueMapping = new HashMap<>(); 67 } 68 69 public static FuncOp partialDiff(FuncOp fcm, Block.Parameter ind) { 70 return new ForwardDifferentiation(fcm, ind).partialDiff(); 71 } 72 73 FuncOp partialDiff() { 74 int indI = fcm.body().entryBlock().parameters().indexOf(ind); 75 76 AtomicBoolean first = new AtomicBoolean(true); 77 FuncOp dfcm = fcm.transform(String.format("d%s_darg%d", fcm.funcName(), indI), 78 (block, op) -> { 79 if 
(first.getAndSet(false)) { 80 // Initialize 81 processBlocks(block); 82 } 83 84 // If the result of the operation is in the active set, 85 // then differentiate it, otherwise copy it 86 if (activeSet.contains(op.result())) { 87 Value dor = diffOp(block, op); 88 // Map the input result to its (output) differentiated result 89 // so that it can be used when differentiating subsequent operations 90 diffValueMapping.put(op.result(), dor); 91 } else { 92 block.apply(op); 93 } 94 return block; 95 }); 96 97 return dfcm; 98 } 99 100 void processBlocks(Block.Builder block) { 101 // Declare constants at start 102 zero = block.op(constant(ind.type(), 0.0d)); 103 // The differential of ind is 1 104 Value one = block.op(constant(ind.type(), 1.0d)); 105 diffValueMapping.put(ind, one); 106 107 // Append differential block parameters to blocks 108 for (Value v : activeSet) { 109 if (v instanceof Block.Parameter ba) { 110 if (ba != ind) { 111 // Get the output block builder for the input (declaring) block 112 Block.Builder b = block.context().getBlock(ba.declaringBlock()); 113 // Add a new block parameter for differential parameter 114 Block.Parameter dba = b.parameter(ba.type()); 115 // Place in mapping 116 diffValueMapping.put(ba, dba); 117 } 118 } 119 } 120 } 121 122 123 static final JavaType J_L_MATH = JavaType.type(Math.class); 124 static final FunctionType D_D = FunctionType.functionType(DOUBLE, DOUBLE); 125 static final MethodRef J_L_MATH_SIN = MethodRef.method(J_L_MATH, "sin", D_D); 126 static final MethodRef J_L_MATH_COS = MethodRef.method(J_L_MATH, "cos", D_D); 127 128 Value diffOp(Block.Builder block, Op op) { 129 // Switch on the op, using pattern matching 130 return switch (op) { 131 case CoreOp.NegOp _ -> { 132 // Copy input operation 133 block.op(op); 134 135 // -diff(expr) 136 Value a = op.operands().get(0); 137 Value da = diffValueMapping.getOrDefault(a, zero); 138 yield block.op(neg(da)); 139 } 140 case CoreOp.AddOp _ -> { 141 // Copy input operation 142 
block.op(op); 143 144 // diff(l) + diff(r) 145 Value lhs = op.operands().get(0); 146 Value rhs = op.operands().get(1); 147 Value dlhs = diffValueMapping.getOrDefault(lhs, zero); 148 Value drhs = diffValueMapping.getOrDefault(rhs, zero); 149 yield block.op(add(dlhs, drhs)); 150 } 151 case CoreOp.MulOp _ -> { 152 // Copy input operation 153 block.op(op); 154 155 // Product rule 156 // diff(l) * r + l * diff(r) 157 Value lhs = op.operands().get(0); 158 Value rhs = op.operands().get(1); 159 Value dlhs = diffValueMapping.getOrDefault(lhs, zero); 160 Value drhs = diffValueMapping.getOrDefault(rhs, zero); 161 Value outputLhs = block.context().getValue(lhs); 162 Value outputRhs = block.context().getValue(rhs); 163 yield block.op(add( 164 block.op(mul(dlhs, outputRhs)), 165 block.op(mul(outputLhs, drhs)))); 166 } 167 case CoreOp.ConstantOp _ -> { 168 // Copy input operation 169 block.op(op); 170 // Differential of constant is zero 171 yield zero; 172 } 173 case CoreOp.InvokeOp c -> { 174 MethodRef md = c.invokeDescriptor(); 175 String operationName = null; 176 if (md.refType().equals(J_L_MATH)) { 177 operationName = md.name(); 178 } 179 // Differentiate sin(x) 180 if ("sin".equals(operationName)) { 181 // Copy input operation 182 block.op(op); 183 184 // Chain rule 185 // cos(expr) * diff(expr) 186 Value a = op.operands().get(0); 187 Value da = diffValueMapping.getOrDefault(a, zero); 188 Value outputA = block.context().getValue(a); 189 Op.Result cosx = block.op(invoke(J_L_MATH_COS, outputA)); 190 yield block.op(mul(cosx, da)); 191 } else { 192 throw new UnsupportedOperationException("Operation not supported: " + op.opName()); 193 } 194 } 195 case CoreOp.ReturnOp _ -> { 196 // Replace with return of differentiated value 197 Value a = op.operands().get(0); 198 Value da = diffValueMapping.getOrDefault(a, zero); 199 yield block.op(_return(da)); 200 } 201 case Op.BlockTerminating _ -> { 202 // Update with differentiated block arguments 203 op.successors().forEach(s -> 
adaptSuccessor(block.context(), s)); 204 yield block.op(op); 205 } 206 default -> throw new UnsupportedOperationException("Operation not supported: " + op.opName()); 207 }; 208 } 209 210 void adaptSuccessor(CopyContext cc, Block.Reference from) { 211 List<Value> as = from.arguments().stream() 212 .filter(activeSet::contains) 213 .toList(); 214 if (!as.isEmpty()) { 215 // Get the successor arguments 216 List<Value> outputArgs = cc.getValues(from.arguments()); 217 // Append the differential arguments, if any 218 for (Value a : as) { 219 Value da = diffValueMapping.get(a); 220 outputArgs.add(da); 221 } 222 223 // Map successor with appended arguments 224 Block.Reference to = cc.getBlock(from.targetBlock()).successor(outputArgs); 225 cc.mapSuccessor(from, to); 226 } 227 } 228 229 }