1 /* 2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. 8 * 9 * This code is distributed in the hope that it will be useful, but WITHOUT 10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 12 * version 2 for more details (a copy is included in the LICENSE file that 13 * accompanied this code). 14 * 15 * You should have received a copy of the GNU General Public License version 16 * 2 along with this work; if not, write to the Free Software Foundation, 17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 18 * 19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 20 * or visit www.oracle.com if you need additional information or have any 21 * questions. 22 */ 23 24 import org.testng.Assert; 25 import org.testng.annotations.Test; 26 27 import jdk.incubator.code.Block; 28 import jdk.incubator.code.OpTransformer; 29 import jdk.incubator.code.op.CoreOp; 30 import jdk.incubator.code.Op; 31 import jdk.incubator.code.analysis.SSA; 32 import jdk.incubator.code.bytecode.BytecodeGenerator; 33 import jdk.incubator.code.interpreter.Interpreter; 34 import java.lang.invoke.MethodHandle; 35 import java.lang.invoke.MethodHandles; 36 import java.lang.reflect.Method; 37 import jdk.incubator.code.CodeReflection; 38 import java.util.Optional; 39 import java.util.stream.Stream; 40 41 /* 42 * @test 43 * @modules jdk.incubator.code 44 * @enablePreview 45 * @run testng TestForwardAutoDiff 46 * @run testng/othervm -Dbabylon.ssa=cytron TestForwardAutoDiff 47 */ 48 49 public class TestForwardAutoDiff { 50 static final double PI_4 = Math.PI / 4; 51 52 @Test 53 public void testExpression() throws Throwable { 54 CoreOp.FuncOp f = getFuncOp("f"); 55 f.writeTo(System.out); 56 57 f = SSA.transform(f); 58 f.writeTo(System.out); 59 60 Assert.assertEquals(Interpreter.invoke(MethodHandles.lookup(), f, 0.0, 1.0), f(0.0, 1.0)); 61 Assert.assertEquals(Interpreter.invoke(MethodHandles.lookup(), f, PI_4, PI_4), f(PI_4, PI_4)); 62 63 Block.Parameter x = f.body().entryBlock().parameters().get(0); 64 Block.Parameter y = f.body().entryBlock().parameters().get(1); 65 66 CoreOp.FuncOp dff_dx = ExpressionElimination.eliminate(ForwardDifferentiation.partialDiff(f, x)); 67 dff_dx.writeTo(System.out); 68 MethodHandle dff_dx_mh = generate(dff_dx); 69 Assert.assertEquals((double) dff_dx_mh.invoke(0.0, 1.0), df_dx(0.0, 1.0)); 70 Assert.assertEquals((double) dff_dx_mh.invoke(PI_4, PI_4), df_dx(PI_4, PI_4)); 71 72 CoreOp.FuncOp dff_dy = ExpressionElimination.eliminate(ForwardDifferentiation.partialDiff(f, y)); 73 dff_dy.writeTo(System.out); 74 MethodHandle dff_dy_mh = generate(dff_dy); 75 Assert.assertEquals((double) dff_dy_mh.invoke(0.0, 1.0), df_dy(0.0, 1.0)); 76 Assert.assertEquals((double) dff_dy_mh.invoke(PI_4, PI_4), df_dy(PI_4, PI_4)); 77 } 78 79 @CodeReflection 80 static double f(double x, double y) { 81 return x * (-Math.sin(x * y) + y) * 4.0d; 82 } 83 84 static double df_dx(double x, double y) { 85 return (-Math.sin(x * y) + y - x * Math.cos(x * y) * y) * 4.0d; 86 } 87 88 static double df_dy(double x, double y) { 89 return x * (1 - Math.cos(x * y) * x) * 4.0d; 90 } 91 92 @Test 93 public void testControlFlow() throws Throwable { 94 CoreOp.FuncOp f = getFuncOp("fcf"); 95 f.writeTo(System.out); 96 97 f = f.transform(OpTransformer.LOWERING_TRANSFORMER); 98 f.writeTo(System.out); 99 100 f = SSA.transform(f); 101 f.writeTo(System.out); 102 103 Assert.assertEquals(Interpreter.invoke(MethodHandles.lookup(), f, 2.0, 6), fcf(2.0, 6)); 104 Assert.assertEquals(Interpreter.invoke(MethodHandles.lookup(), f, 2.0, 5), fcf(2.0, 5)); 105 Assert.assertEquals(Interpreter.invoke(MethodHandles.lookup(), f, 2.0, 4), fcf(2.0, 4)); 106 107 Block.Parameter x = f.body().entryBlock().parameters().get(0); 108 109 CoreOp.FuncOp df_dx = ForwardDifferentiation.partialDiff(f, x); 110 df_dx.writeTo(System.out); 111 MethodHandle df_dx_mh = generate(df_dx); 112 113 Assert.assertEquals((double) df_dx_mh.invoke(2.0, 6), dfcf_dx(2.0, 6)); 114 Assert.assertEquals((double) df_dx_mh.invoke(2.0, 5), dfcf_dx(2.0, 5)); 115 Assert.assertEquals((double) df_dx_mh.invoke(2.0, 4), dfcf_dx(2.0, 4)); 116 } 117 118 @CodeReflection 119 static double fcf(/* independent */ double x, int y) { 120 /* dependent */ 121 double o = 1.0; 122 for (int i = 0; i < y; i = i + 1) { 123 if (i > 1) { 124 if (i < 5) { 125 o = o * x; 126 } 127 } 128 } 129 return o; 130 } 131 132 static double dfcf_dx(/* independent */ double x, int y) { 133 double d_o = 0.0; 134 double o = 1.0; 135 for (int i = 0; i < y; i = i + 1) { 136 if (i > 1) { 137 if (i < 5) { 138 d_o = d_o * x + o * 1.0; 139 o = o * x; 140 } 141 } 142 } 143 return d_o; 144 } 145 146 static MethodHandle generate(CoreOp.FuncOp f) { 147 return BytecodeGenerator.generate(MethodHandles.lookup(), f); 148 } 149 150 static CoreOp.FuncOp getFuncOp(String name) { 151 Optional<Method> om = Stream.of(TestForwardAutoDiff.class.getDeclaredMethods()) 152 .filter(m -> m.getName().equals(name)) 153 .findFirst(); 154 155 Method m = om.get(); 156 return Op.ofMethod(m).get(); 157 } 158 }