1 /* 2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. 8 * 9 * This code is distributed in the hope that it will be useful, but WITHOUT 10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 12 * version 2 for more details (a copy is included in the LICENSE file that 13 * accompanied this code). 14 * 15 * You should have received a copy of the GNU General Public License version 16 * 2 along with this work; if not, write to the Free Software Foundation, 17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 18 * 19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 20 * or visit www.oracle.com if you need additional information or have any 21 * questions. 
*/

import org.testng.Assert;
import org.testng.annotations.Test;

import jdk.incubator.code.Block;
import jdk.incubator.code.CodeReflection;
import jdk.incubator.code.Op;
import jdk.incubator.code.OpTransformer;
import jdk.incubator.code.analysis.SSA;
import jdk.incubator.code.bytecode.BytecodeGenerator;
import jdk.incubator.code.dialect.core.CoreOp;
import jdk.incubator.code.interpreter.Interpreter;

import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.reflect.Method;
import java.util.NoSuchElementException;
import java.util.Optional;
import java.util.stream.Stream;

/*
 * @test
 * @modules jdk.incubator.code
 * @enablePreview
 * @run testng TestForwardAutoDiff
 * @run testng/othervm -Dbabylon.ssa=cytron TestForwardAutoDiff
 */

/**
 * Tests forward-mode automatic differentiation of code models.
 * <p>
 * Each test obtains the code model of a {@code @CodeReflection} method, derives a
 * partial derivative with {@code ForwardDifferentiation.partialDiff}, compiles the
 * result with {@link BytecodeGenerator}, and compares it against a hand-written
 * analytic derivative.
 */
public class TestForwardAutoDiff {
    static final double PI_4 = Math.PI / 4;

    /**
     * Differentiates the straight-line expression {@link #f(double, double)} with
     * respect to each of its two parameters and checks the generated code against
     * {@link #df_dx(double, double)} and {@link #df_dy(double, double)}.
     */
    @Test
    public void testExpression() throws Throwable {
        CoreOp.FuncOp f = getFuncOp("f");
        System.out.println(f.toText());

        f = SSA.transform(f);
        System.out.println(f.toText());

        // Sanity check: the SSA-form model must still compute what the Java method computes.
        Assert.assertEquals(Interpreter.invoke(MethodHandles.lookup(), f, 0.0, 1.0), f(0.0, 1.0));
        Assert.assertEquals(Interpreter.invoke(MethodHandles.lookup(), f, PI_4, PI_4), f(PI_4, PI_4));

        // The entry block parameters of the transformed model are the variables
        // we differentiate with respect to.
        Block.Parameter x = f.body().entryBlock().parameters().get(0);
        Block.Parameter y = f.body().entryBlock().parameters().get(1);

        CoreOp.FuncOp dff_dx = ExpressionElimination.eliminate(ForwardDifferentiation.partialDiff(f, x));
        System.out.println(dff_dx.toText());
        MethodHandle dff_dx_mh = generate(dff_dx);
        Assert.assertEquals((double) dff_dx_mh.invoke(0.0, 1.0), df_dx(0.0, 1.0));
        Assert.assertEquals((double) dff_dx_mh.invoke(PI_4, PI_4), df_dx(PI_4, PI_4));

        CoreOp.FuncOp dff_dy = ExpressionElimination.eliminate(ForwardDifferentiation.partialDiff(f, y));
        System.out.println(dff_dy.toText());
        MethodHandle dff_dy_mh = generate(dff_dy);
        Assert.assertEquals((double) dff_dy_mh.invoke(0.0, 1.0), df_dy(0.0, 1.0));
        Assert.assertEquals((double) dff_dy_mh.invoke(PI_4, PI_4), df_dy(PI_4, PI_4));
    }

    /** Function under test: {@code x * (-sin(x*y) + y) * 4}. */
    @CodeReflection
    static double f(double x, double y) {
        return x * (-Math.sin(x * y) + y) * 4.0d;
    }

    /** Analytic partial derivative of {@link #f(double, double)} with respect to {@code x}. */
    static double df_dx(double x, double y) {
        return (-Math.sin(x * y) + y - x * Math.cos(x * y) * y) * 4.0d;
    }

    /** Analytic partial derivative of {@link #f(double, double)} with respect to {@code y}. */
    static double df_dy(double x, double y) {
        return x * (1 - Math.cos(x * y) * x) * 4.0d;
    }

    /**
     * Differentiates {@link #fcf(double, int)}, which contains a loop and nested
     * conditionals, with respect to {@code x}.
     * <p>
     * The model is lowered and converted to SSA form before differentiation.
     * Unlike {@link #testExpression()}, the derivative is compiled directly without
     * {@code ExpressionElimination}.
     */
    @Test
    public void testControlFlow() throws Throwable {
        CoreOp.FuncOp f = getFuncOp("fcf");
        System.out.println(f.toText());

        f = f.transform(OpTransformer.LOWERING_TRANSFORMER);
        System.out.println(f.toText());

        f = SSA.transform(f);
        System.out.println(f.toText());

        // Sanity check: lowering + SSA must preserve the method's results.
        Assert.assertEquals(Interpreter.invoke(MethodHandles.lookup(), f, 2.0, 6), fcf(2.0, 6));
        Assert.assertEquals(Interpreter.invoke(MethodHandles.lookup(), f, 2.0, 5), fcf(2.0, 5));
        Assert.assertEquals(Interpreter.invoke(MethodHandles.lookup(), f, 2.0, 4), fcf(2.0, 4));

        Block.Parameter x = f.body().entryBlock().parameters().get(0);

        CoreOp.FuncOp df_dx = ForwardDifferentiation.partialDiff(f, x);
        System.out.println(df_dx.toText());
        MethodHandle df_dx_mh = generate(df_dx);

        Assert.assertEquals((double) df_dx_mh.invoke(2.0, 6), dfcf_dx(2.0, 6));
        Assert.assertEquals((double) df_dx_mh.invoke(2.0, 5), dfcf_dx(2.0, 5));
        Assert.assertEquals((double) df_dx_mh.invoke(2.0, 4), dfcf_dx(2.0, 4));
    }

    /**
     * Function under test with control flow: multiplies {@code o} by {@code x} on every
     * loop iteration {@code i} with {@code 1 < i < 5}, for {@code i} in {@code [0, y)}.
     */
    @CodeReflection
    static double fcf(/* independent */ double x, int y) {
        /* dependent */
        double o = 1.0;
        for (int i = 0; i < y; i = i + 1) {
            if (i > 1) {
                if (i < 5) {
                    o = o * x;
                }
            }
        }
        return o;
    }

    /**
     * Hand-written forward-mode derivative of {@link #fcf(double, int)} with respect
     * to {@code x}, carrying the tangent {@code d_o} alongside {@code o}.
     */
    static double dfcf_dx(/* independent */ double x, int y) {
        double d_o = 0.0;
        double o = 1.0;
        for (int i = 0; i < y; i = i + 1) {
            if (i > 1) {
                if (i < 5) {
                    // Product rule for o * x; d_o is updated first so it uses the
                    // previous value of o (d(x)/dx == 1.0).
                    d_o = d_o * x + o * 1.0;
                    o = o * x;
                }
            }
        }
        return d_o;
    }

    /** Compiles a code model to a method handle using this class's lookup. */
    static MethodHandle generate(CoreOp.FuncOp f) {
        return BytecodeGenerator.generate(MethodHandles.lookup(), f);
    }

    /**
     * Returns the code model of the method declared in this class with the given name.
     *
     * @param name simple name of a {@code @CodeReflection}-annotated method of this class
     * @return the method's code model
     * @throws NoSuchElementException if no such method is declared, or if the method
     *         has no code model (e.g. it lacks {@code @CodeReflection})
     */
    static CoreOp.FuncOp getFuncOp(String name) {
        Optional<Method> om = Stream.of(TestForwardAutoDiff.class.getDeclaredMethods())
                .filter(m -> m.getName().equals(name))
                .findFirst();

        Method m = om.orElseThrow(() -> new NoSuchElementException(
                "No method named '" + name + "' is declared in TestForwardAutoDiff"));
        return Op.ofMethod(m).orElseThrow(() -> new NoSuchElementException(
                "Method '" + name + "' has no code model; is it annotated with @CodeReflection?"));
    }
}