1 /* 2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. 8 * 9 * This code is distributed in the hope that it will be useful, but WITHOUT 10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 12 * version 2 for more details (a copy is included in the LICENSE file that 13 * accompanied this code). 14 * 15 * You should have received a copy of the GNU General Public License version 16 * 2 along with this work; if not, write to the Free Software Foundation, 17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 18 * 19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 20 * or visit www.oracle.com if you need additional information or have any 21 * questions. 
*/

import org.testng.Assert;
import org.testng.annotations.Test;

import java.lang.reflect.code.Block;
import java.lang.reflect.code.OpTransformer;
import java.lang.reflect.code.op.CoreOp;
import java.lang.reflect.code.Op;
import java.lang.reflect.code.analysis.SSA;
import java.lang.reflect.code.bytecode.BytecodeGenerator;
import java.lang.reflect.code.interpreter.Interpreter;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.reflect.Method;
import java.lang.runtime.CodeReflection;
import java.util.NoSuchElementException;
import java.util.Optional;
import java.util.stream.Stream;

/*
 * @test
 * @enablePreview
 * @run testng TestForwardAutoDiff
 */

/**
 * Tests forward-mode automatic differentiation of code models.
 *
 * <p>Each test obtains the code model of a {@code @CodeReflection}-annotated method and
 * converts it to SSA form (lowering it first when it contains structured control flow).
 * It then checks via the {@link Interpreter} that the transformed model still computes the
 * original function. Next it differentiates the model with respect to one of its entry-block
 * parameters using {@code ForwardDifferentiation.partialDiff}, and compiles the result with
 * {@link BytecodeGenerator}. Finally it compares the compiled derivative against a
 * hand-written reference derivative.
 */
public class TestForwardAutoDiff {
    static final double PI_4 = Math.PI / 4;

    /**
     * Differentiates the straight-line expression {@link #f(double, double)} with respect to
     * each of its two parameters. Each resulting model is simplified with
     * {@code ExpressionElimination.eliminate} before being compiled and compared against
     * {@link #df_dx(double, double)} and {@link #df_dy(double, double)}.
     */
    @Test
    public void testExpression() throws Throwable {
        CoreOp.FuncOp f = getFuncOp("f");
        f.writeTo(System.out);

        f = SSA.transform(f);
        f.writeTo(System.out);

        // The SSA transform must not change what the function computes.
        Assert.assertEquals(Interpreter.invoke(f, 0.0, 1.0), f(0.0, 1.0));
        Assert.assertEquals(Interpreter.invoke(f, PI_4, PI_4), f(PI_4, PI_4));

        // Parameters of the entry block correspond to the method parameters (x, y).
        Block.Parameter x = f.body().entryBlock().parameters().get(0);
        Block.Parameter y = f.body().entryBlock().parameters().get(1);

        CoreOp.FuncOp dff_dx = ExpressionElimination.eliminate(ForwardDifferentiation.partialDiff(f, x));
        dff_dx.writeTo(System.out);
        MethodHandle dff_dx_mh = generate(dff_dx);
        Assert.assertEquals((double) dff_dx_mh.invoke(0.0, 1.0), df_dx(0.0, 1.0));
        Assert.assertEquals((double) dff_dx_mh.invoke(PI_4, PI_4), df_dx(PI_4, PI_4));

        CoreOp.FuncOp dff_dy = ExpressionElimination.eliminate(ForwardDifferentiation.partialDiff(f, y));
        dff_dy.writeTo(System.out);
        MethodHandle dff_dy_mh = generate(dff_dy);
        Assert.assertEquals((double) dff_dy_mh.invoke(0.0, 1.0), df_dy(0.0, 1.0));
        Assert.assertEquals((double) dff_dy_mh.invoke(PI_4, PI_4), df_dy(PI_4, PI_4));
    }

    /** Function under differentiation: {@code f(x, y) = x * (-sin(x*y) + y) * 4}. */
    @CodeReflection
    static double f(double x, double y) {
        return x * (-Math.sin(x * y) + y) * 4.0d;
    }

    /**
     * Reference partial derivative of {@link #f(double, double)} with respect to {@code x}.
     * By the product rule: {@code (-sin(x*y) + y + x * (-cos(x*y) * y)) * 4}.
     */
    static double df_dx(double x, double y) {
        return (-Math.sin(x * y) + y - x * Math.cos(x * y) * y) * 4.0d;
    }

    /**
     * Reference partial derivative of {@link #f(double, double)} with respect to {@code y}:
     * {@code x * (-cos(x*y) * x + 1) * 4}.
     */
    static double df_dy(double x, double y) {
        return x * (1 - Math.cos(x * y) * x) * 4.0d;
    }

    /**
     * Differentiates {@link #fcf(double, int)}, which contains a loop and nested conditionals,
     * with respect to {@code x}. The structured control flow is lowered before the SSA
     * transform, and the compiled derivative is compared against
     * {@link #dfcf_dx(double, int)} for loop bounds that do and do not reach every branch.
     */
    @Test
    public void testControlFlow() throws Throwable {
        CoreOp.FuncOp f = getFuncOp("fcf");
        f.writeTo(System.out);

        f = f.transform(OpTransformer.LOWERING_TRANSFORMER);
        f.writeTo(System.out);

        f = SSA.transform(f);
        f.writeTo(System.out);

        // Lowering + SSA must not change what the function computes.
        Assert.assertEquals(Interpreter.invoke(f, 2.0, 6), fcf(2.0, 6));
        Assert.assertEquals(Interpreter.invoke(f, 2.0, 5), fcf(2.0, 5));
        Assert.assertEquals(Interpreter.invoke(f, 2.0, 4), fcf(2.0, 4));

        // Differentiate only with respect to the double parameter x; y is a loop bound.
        Block.Parameter x = f.body().entryBlock().parameters().get(0);

        CoreOp.FuncOp df_dx = ForwardDifferentiation.partialDiff(f, x);
        df_dx.writeTo(System.out);
        MethodHandle df_dx_mh = generate(df_dx);

        Assert.assertEquals((double) df_dx_mh.invoke(2.0, 6), dfcf_dx(2.0, 6));
        Assert.assertEquals((double) df_dx_mh.invoke(2.0, 5), dfcf_dx(2.0, 5));
        Assert.assertEquals((double) df_dx_mh.invoke(2.0, 4), dfcf_dx(2.0, 4));
    }

    /**
     * Function with control flow under differentiation.
     * {@code o} is multiplied by {@code x} once for each {@code i} in {@code [0, y)} with
     * {@code 1 < i < 5}.
     */
    @CodeReflection
    static double fcf(/* independent */ double x, int y) {
        /* dependent */
        double o = 1.0;
        for (int i = 0; i < y; i = i + 1) {
            if (i > 1) {
                if (i < 5) {
                    o = o * x;
                }
            }
        }
        return o;
    }

    /**
     * Hand-written forward-mode derivative of {@link #fcf(double, int)} with respect to
     * {@code x}. It carries the tangent {@code d_o} alongside {@code o}, and applies the
     * product rule {@code d(o*x) = d_o*x + o*dx} (with {@code dx = 1}) before {@code o} is
     * updated.
     */
    static double dfcf_dx(/* independent */ double x, int y) {
        double d_o = 0.0;
        double o = 1.0;
        for (int i = 0; i < y; i = i + 1) {
            if (i > 1) {
                if (i < 5) {
                    // Must use the value of o before this iteration's update.
                    d_o = d_o * x + o * 1.0;
                    o = o * x;
                }
            }
        }
        return d_o;
    }

    /** Compiles a code model to bytecode and returns a handle to the generated method. */
    static MethodHandle generate(CoreOp.FuncOp f) {
        return BytecodeGenerator.generate(MethodHandles.lookup(), f);
    }

    /**
     * Returns the code model of the first method declared in this class with the given name.
     *
     * @param name the simple name of a {@code @CodeReflection}-annotated method of this class
     * @return the method's code model
     * @throws NoSuchElementException if no such method is declared, or if the method has no
     *         code model
     */
    static CoreOp.FuncOp getFuncOp(String name) {
        Optional<Method> om = Stream.of(TestForwardAutoDiff.class.getDeclaredMethods())
                .filter(m -> m.getName().equals(name))
                .findFirst();

        // Fail with a descriptive message instead of Optional.get()'s "No value present".
        Method method = om.orElseThrow(() -> new NoSuchElementException(
                "No method named '" + name + "' is declared in TestForwardAutoDiff"));
        return method.getCodeModel().orElseThrow(() -> new NoSuchElementException(
                "Method '" + name + "' has no code model; is it annotated with @CodeReflection?"));
    }
}