1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
import org.testng.Assert;
import org.testng.annotations.Test;

import jdk.incubator.code.Block;
import jdk.incubator.code.CodeReflection;
import jdk.incubator.code.Op;
import jdk.incubator.code.OpTransformer;
import jdk.incubator.code.analysis.SSA;
import jdk.incubator.code.bytecode.BytecodeGenerator;
import jdk.incubator.code.dialect.core.CoreOp;
import jdk.incubator.code.interpreter.Interpreter;

import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.reflect.Method;
import java.util.NoSuchElementException;
import java.util.Optional;
import java.util.stream.Stream;
 41 
 42 /*
 43  * @test
 44  * @modules jdk.incubator.code
 45  * @enablePreview
 46  * @run testng TestForwardAutoDiff
 47  * @run testng/othervm -Dbabylon.ssa=cytron TestForwardAutoDiff
 48  */
 49 
 50 public class TestForwardAutoDiff {
 51     static final double PI_4 = Math.PI / 4;
 52 
 53     @Test
 54     public void testExpression() throws Throwable {
 55         CoreOp.FuncOp f = getFuncOp("f");
 56         System.out.println(f.toText());
 57 
 58         f = SSA.transform(f);
 59         System.out.println(f.toText());
 60 
 61         Assert.assertEquals(Interpreter.invoke(MethodHandles.lookup(), f, 0.0, 1.0), f(0.0, 1.0));
 62         Assert.assertEquals(Interpreter.invoke(MethodHandles.lookup(), f, PI_4, PI_4), f(PI_4, PI_4));
 63 
 64         Block.Parameter x = f.body().entryBlock().parameters().get(0);
 65         Block.Parameter y = f.body().entryBlock().parameters().get(1);
 66 
 67         CoreOp.FuncOp dff_dx = ExpressionElimination.eliminate(ForwardDifferentiation.partialDiff(f, x));
 68         System.out.println(dff_dx.toText());
 69         MethodHandle dff_dx_mh = generate(dff_dx);
 70         Assert.assertEquals((double) dff_dx_mh.invoke(0.0, 1.0), df_dx(0.0, 1.0));
 71         Assert.assertEquals((double) dff_dx_mh.invoke(PI_4, PI_4), df_dx(PI_4, PI_4));
 72 
 73         CoreOp.FuncOp dff_dy = ExpressionElimination.eliminate(ForwardDifferentiation.partialDiff(f, y));
 74         System.out.println(dff_dy.toText());
 75         MethodHandle dff_dy_mh = generate(dff_dy);
 76         Assert.assertEquals((double) dff_dy_mh.invoke(0.0, 1.0), df_dy(0.0, 1.0));
 77         Assert.assertEquals((double) dff_dy_mh.invoke(PI_4, PI_4), df_dy(PI_4, PI_4));
 78     }
 79 
 80     @CodeReflection
 81     static double f(double x, double y) {
 82         return x * (-Math.sin(x * y) + y) * 4.0d;
 83     }
 84 
 85     static double df_dx(double x, double y) {
 86         return (-Math.sin(x * y) + y - x * Math.cos(x * y) * y) * 4.0d;
 87     }
 88 
 89     static double df_dy(double x, double y) {
 90         return x * (1 - Math.cos(x * y) * x) * 4.0d;
 91     }
 92 
 93     @Test
 94     public void testControlFlow() throws Throwable {
 95         CoreOp.FuncOp f = getFuncOp("fcf");
 96         System.out.println(f.toText());
 97 
 98         f = f.transform(OpTransformer.LOWERING_TRANSFORMER);
 99         System.out.println(f.toText());
100 
101         f = SSA.transform(f);
102         System.out.println(f.toText());
103 
104         Assert.assertEquals(Interpreter.invoke(MethodHandles.lookup(), f, 2.0, 6), fcf(2.0, 6));
105         Assert.assertEquals(Interpreter.invoke(MethodHandles.lookup(), f, 2.0, 5), fcf(2.0, 5));
106         Assert.assertEquals(Interpreter.invoke(MethodHandles.lookup(), f, 2.0, 4), fcf(2.0, 4));
107 
108         Block.Parameter x = f.body().entryBlock().parameters().get(0);
109 
110         CoreOp.FuncOp df_dx = ForwardDifferentiation.partialDiff(f, x);
111         System.out.println(df_dx.toText());
112         MethodHandle df_dx_mh = generate(df_dx);
113 
114         Assert.assertEquals((double) df_dx_mh.invoke(2.0, 6), dfcf_dx(2.0, 6));
115         Assert.assertEquals((double) df_dx_mh.invoke(2.0, 5), dfcf_dx(2.0, 5));
116         Assert.assertEquals((double) df_dx_mh.invoke(2.0, 4), dfcf_dx(2.0, 4));
117     }
118 
119     @CodeReflection
120     static double fcf(/* independent */ double x, int y) {
121         /* dependent */
122         double o = 1.0;
123         for (int i = 0; i < y; i = i + 1) {
124             if (i > 1) {
125                 if (i < 5) {
126                     o = o * x;
127                 }
128             }
129         }
130         return o;
131     }
132 
133     static double dfcf_dx(/* independent */ double x, int y) {
134         double d_o = 0.0;
135         double o = 1.0;
136         for (int i = 0; i < y; i = i + 1) {
137             if (i > 1) {
138                 if (i < 5) {
139                     d_o = d_o * x + o * 1.0;
140                     o = o * x;
141                 }
142             }
143         }
144         return d_o;
145     }
146 
147     static MethodHandle generate(CoreOp.FuncOp f) {
148         return BytecodeGenerator.generate(MethodHandles.lookup(), f);
149     }
150 
151     static CoreOp.FuncOp getFuncOp(String name) {
152         Optional<Method> om = Stream.of(TestForwardAutoDiff.class.getDeclaredMethods())
153                 .filter(m -> m.getName().equals(name))
154                 .findFirst();
155 
156         Method m = om.get();
157         return Op.ofMethod(m).get();
158     }
159 }