1 /* 2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. 8 * 9 * This code is distributed in the hope that it will be useful, but WITHOUT 10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 12 * version 2 for more details (a copy is included in the LICENSE file that 13 * accompanied this code). 14 * 15 * You should have received a copy of the GNU General Public License version 16 * 2 along with this work; if not, write to the Free Software Foundation, 17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 18 * 19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 20 * or visit www.oracle.com if you need additional information or have any 21 * questions. 22 */ 23 24 import jdk.incubator.code.Block; 25 import jdk.incubator.code.CopyContext; 26 import jdk.incubator.code.Op; 27 import jdk.incubator.code.Value; 28 import jdk.incubator.code.dialect.core.CoreOp; 29 import jdk.incubator.code.dialect.core.CoreType; 30 import jdk.incubator.code.dialect.java.JavaOp; 31 import jdk.incubator.code.dialect.java.MethodRef; 32 import jdk.incubator.code.dialect.core.FunctionType; 33 import jdk.incubator.code.dialect.java.JavaType; 34 import java.util.HashMap; 35 import java.util.List; 36 import java.util.Map; 37 import java.util.Set; 38 import java.util.concurrent.atomic.AtomicBoolean; 39 40 import static jdk.incubator.code.dialect.core.CoreOp.*; 41 import static jdk.incubator.code.dialect.java.JavaType.DOUBLE; 42 43 public final class ForwardDifferentiation { 44 // The function to differentiate 45 final FuncOp fcm; 46 // The independent variable 47 final Block.Parameter ind; 48 // The active set for the independent variable 49 final Set<Value> activeSet; 50 // The map of input value to it's (output) differentiated value 51 final Map<Value, Value> diffValueMapping; 52 53 // The constant value 0.0d 54 // Declared in the (output) function's entry block 55 Value zero; 56 57 private ForwardDifferentiation(FuncOp fcm, Block.Parameter ind) { 58 int indI = fcm.body().entryBlock().parameters().indexOf(ind); 59 if (indI == -1) { 60 throw new IllegalArgumentException("Independent argument not defined by function"); 61 } 62 this.fcm = fcm; 63 this.ind = ind; 64 65 // Calculate the active set of dependent values for the independent value 66 this.activeSet = ActiveSet.activeSet(fcm, ind); 67 // A mapping of input values to their (output) differentiated values 68 this.diffValueMapping = new HashMap<>(); 69 } 70 71 public static FuncOp partialDiff(FuncOp fcm, Block.Parameter ind) { 72 return new ForwardDifferentiation(fcm, ind).partialDiff(); 73 } 74 75 FuncOp partialDiff() { 76 int indI = fcm.body().entryBlock().parameters().indexOf(ind); 77 78 AtomicBoolean first = new AtomicBoolean(true); 79 FuncOp dfcm = fcm.transform(String.format("d%s_darg%d", fcm.funcName(), indI), 80 (block, op) -> { 81 if (first.getAndSet(false)) { 82 // Initialize 83 processBlocks(block); 84 } 85 86 // If the result of the operation is in the active set, 87 // then differentiate it, otherwise copy it 88 if (activeSet.contains(op.result())) { 89 Value dor = diffOp(block, op); 90 // Map the input result to its (output) differentiated result 91 // so that it can be used when differentiating subsequent operations 92 diffValueMapping.put(op.result(), dor); 93 } else { 94 block.apply(op); 95 } 96 return block; 97 }); 98 99 return dfcm; 100 } 101 102 void processBlocks(Block.Builder block) { 103 // Declare constants at start 104 zero = block.op(constant(ind.type(), 0.0d)); 105 // The differential of ind is 1 106 Value one = block.op(constant(ind.type(), 1.0d)); 107 diffValueMapping.put(ind, one); 108 109 // Append differential block parameters to blocks 110 for (Value v : activeSet) { 111 if (v instanceof Block.Parameter ba) { 112 if (ba != ind) { 113 // Get the output block builder for the input (declaring) block 114 Block.Builder b = block.context().getBlock(ba.declaringBlock()); 115 // Add a new block parameter for differential parameter 116 Block.Parameter dba = b.parameter(ba.type()); 117 // Place in mapping 118 diffValueMapping.put(ba, dba); 119 } 120 } 121 } 122 } 123 124 125 static final JavaType J_L_MATH = JavaType.type(Math.class); 126 static final FunctionType D_D = CoreType.functionType(DOUBLE, DOUBLE); 127 static final MethodRef J_L_MATH_SIN = MethodRef.method(J_L_MATH, "sin", D_D); 128 static final MethodRef J_L_MATH_COS = MethodRef.method(J_L_MATH, "cos", D_D); 129 130 Value diffOp(Block.Builder block, Op op) { 131 // Switch on the op, using pattern matching 132 return switch (op) { 133 case JavaOp.NegOp _ -> { 134 // Copy input operation 135 block.op(op); 136 137 // -diff(expr) 138 Value a = op.operands().get(0); 139 Value da = diffValueMapping.getOrDefault(a, zero); 140 yield block.op(JavaOp.neg(da)); 141 } 142 case JavaOp.AddOp _ -> { 143 // Copy input operation 144 block.op(op); 145 146 // diff(l) + diff(r) 147 Value lhs = op.operands().get(0); 148 Value rhs = op.operands().get(1); 149 Value dlhs = diffValueMapping.getOrDefault(lhs, zero); 150 Value drhs = diffValueMapping.getOrDefault(rhs, zero); 151 yield block.op(JavaOp.add(dlhs, drhs)); 152 } 153 case JavaOp.MulOp _ -> { 154 // Copy input operation 155 block.op(op); 156 157 // Product rule 158 // diff(l) * r + l * diff(r) 159 Value lhs = op.operands().get(0); 160 Value rhs = op.operands().get(1); 161 Value dlhs = diffValueMapping.getOrDefault(lhs, zero); 162 Value drhs = diffValueMapping.getOrDefault(rhs, zero); 163 Value outputLhs = block.context().getValue(lhs); 164 Value outputRhs = block.context().getValue(rhs); 165 yield block.op(JavaOp.add( 166 block.op(JavaOp.mul(dlhs, outputRhs)), 167 block.op(JavaOp.mul(outputLhs, drhs)))); 168 } 169 case CoreOp.ConstantOp _ -> { 170 // Copy input operation 171 block.op(op); 172 // Differential of constant is zero 173 yield zero; 174 } 175 case JavaOp.InvokeOp c -> { 176 MethodRef md = c.invokeDescriptor(); 177 String operationName = null; 178 if (md.refType().equals(J_L_MATH)) { 179 operationName = md.name(); 180 } 181 // Differentiate sin(x) 182 if ("sin".equals(operationName)) { 183 // Copy input operation 184 block.op(op); 185 186 // Chain rule 187 // cos(expr) * diff(expr) 188 Value a = op.operands().get(0); 189 Value da = diffValueMapping.getOrDefault(a, zero); 190 Value outputA = block.context().getValue(a); 191 Op.Result cosx = block.op(JavaOp.invoke(J_L_MATH_COS, outputA)); 192 yield block.op(JavaOp.mul(cosx, da)); 193 } else { 194 throw new UnsupportedOperationException("Operation not supported: " + op.opName()); 195 } 196 } 197 case CoreOp.ReturnOp _ -> { 198 // Replace with return of differentiated value 199 Value a = op.operands().get(0); 200 Value da = diffValueMapping.getOrDefault(a, zero); 201 yield block.op(return_(da)); 202 } 203 case Op.BlockTerminating _ -> { 204 // Update with differentiated block arguments 205 op.successors().forEach(s -> adaptSuccessor(block.context(), s)); 206 yield block.op(op); 207 } 208 default -> throw new UnsupportedOperationException("Operation not supported: " + op.opName()); 209 }; 210 } 211 212 void adaptSuccessor(CopyContext cc, Block.Reference from) { 213 List<Value> as = from.arguments().stream() 214 .filter(activeSet::contains) 215 .toList(); 216 if (!as.isEmpty()) { 217 // Get the successor arguments 218 List<Value> outputArgs = cc.getValues(from.arguments()); 219 // Append the differential arguments, if any 220 for (Value a : as) { 221 Value da = diffValueMapping.get(a); 222 outputArgs.add(da); 223 } 224 225 // Map successor with appended arguments 226 Block.Reference to = cc.getBlock(from.targetBlock()).successor(outputArgs); 227 cc.mapSuccessor(from, to); 228 } 229 } 230 231 }