1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 import org.testng.Assert;
 25 import org.testng.annotations.Test;
 26 
 27 import jdk.incubator.code.Block;
 28 import jdk.incubator.code.OpTransformer;
 29 import jdk.incubator.code.op.CoreOp;
 30 import jdk.incubator.code.Op;
 31 import jdk.incubator.code.analysis.SSA;
 32 import jdk.incubator.code.bytecode.BytecodeGenerator;
 33 import jdk.incubator.code.interpreter.Interpreter;
 34 import java.lang.invoke.MethodHandle;
 35 import java.lang.invoke.MethodHandles;
 36 import java.lang.reflect.Method;
 37 import jdk.incubator.code.CodeReflection;
 38 import java.util.Optional;
 39 import java.util.stream.Stream;
 40 
 41 /*
 42  * @test
 43  * @modules jdk.incubator.code
 44  * @enablePreview
 45  * @run testng TestForwardAutoDiff
 46  * @run testng/othervm -Dbabylon.ssa=cytron TestForwardAutoDiff
 47  */
 48 
 49 public class TestForwardAutoDiff {
    // Sample input of pi/4, passed as both the x and the y argument in testExpression.
    static final double PI_4 = Math.PI / 4;
 51 
 52     @Test
 53     public void testExpression() throws Throwable {
 54         CoreOp.FuncOp f = getFuncOp("f");
 55         f.writeTo(System.out);
 56 
 57         f = SSA.transform(f);
 58         f.writeTo(System.out);
 59 
 60         Assert.assertEquals(Interpreter.invoke(MethodHandles.lookup(), f, 0.0, 1.0), f(0.0, 1.0));
 61         Assert.assertEquals(Interpreter.invoke(MethodHandles.lookup(), f, PI_4, PI_4), f(PI_4, PI_4));
 62 
 63         Block.Parameter x = f.body().entryBlock().parameters().get(0);
 64         Block.Parameter y = f.body().entryBlock().parameters().get(1);
 65 
 66         CoreOp.FuncOp dff_dx = ExpressionElimination.eliminate(ForwardDifferentiation.partialDiff(f, x));
 67         dff_dx.writeTo(System.out);
 68         MethodHandle dff_dx_mh = generate(dff_dx);
 69         Assert.assertEquals((double) dff_dx_mh.invoke(0.0, 1.0), df_dx(0.0, 1.0));
 70         Assert.assertEquals((double) dff_dx_mh.invoke(PI_4, PI_4), df_dx(PI_4, PI_4));
 71 
 72         CoreOp.FuncOp dff_dy = ExpressionElimination.eliminate(ForwardDifferentiation.partialDiff(f, y));
 73         dff_dy.writeTo(System.out);
 74         MethodHandle dff_dy_mh = generate(dff_dy);
 75         Assert.assertEquals((double) dff_dy_mh.invoke(0.0, 1.0), df_dy(0.0, 1.0));
 76         Assert.assertEquals((double) dff_dy_mh.invoke(PI_4, PI_4), df_dy(PI_4, PI_4));
 77     }
 78 
    /**
     * Function under test for {@code testExpression}:
     * {@code f(x, y) = x * (-sin(x * y) + y) * 4}.
     *
     * <p>The method is looked up by name through {@code getFuncOp("f")}, so the code
     * model derived from this body is what the test transforms and differentiates.
     * The same method is also called directly to produce the expected values.
     *
     * @param x first argument
     * @param y second argument
     * @return the value of the expression at {@code (x, y)}
     */
    @CodeReflection
    static double f(double x, double y) {
        return x * (-Math.sin(x * y) + y) * 4.0d;
    }
 83 
 84     static double df_dx(double x, double y) {
 85         return (-Math.sin(x * y) + y - x * Math.cos(x * y) * y) * 4.0d;
 86     }
 87 
 88     static double df_dy(double x, double y) {
 89         return x * (1 - Math.cos(x * y) * x) * 4.0d;
 90     }
 91 
 92     @Test
 93     public void testControlFlow() throws Throwable {
 94         CoreOp.FuncOp f = getFuncOp("fcf");
 95         f.writeTo(System.out);
 96 
 97         f = f.transform(OpTransformer.LOWERING_TRANSFORMER);
 98         f.writeTo(System.out);
 99 
100         f = SSA.transform(f);
101         f.writeTo(System.out);
102 
103         Assert.assertEquals(Interpreter.invoke(MethodHandles.lookup(), f, 2.0, 6), fcf(2.0, 6));
104         Assert.assertEquals(Interpreter.invoke(MethodHandles.lookup(), f, 2.0, 5), fcf(2.0, 5));
105         Assert.assertEquals(Interpreter.invoke(MethodHandles.lookup(), f, 2.0, 4), fcf(2.0, 4));
106 
107         Block.Parameter x = f.body().entryBlock().parameters().get(0);
108 
109         CoreOp.FuncOp df_dx = ForwardDifferentiation.partialDiff(f, x);
110         df_dx.writeTo(System.out);
111         MethodHandle df_dx_mh = generate(df_dx);
112 
113         Assert.assertEquals((double) df_dx_mh.invoke(2.0, 6), dfcf_dx(2.0, 6));
114         Assert.assertEquals((double) df_dx_mh.invoke(2.0, 5), dfcf_dx(2.0, 5));
115         Assert.assertEquals((double) df_dx_mh.invoke(2.0, 4), dfcf_dx(2.0, 4));
116     }
117 
    /**
     * Function under test for {@code testControlFlow}: starts from {@code 1.0} and
     * multiplies by {@code x} once for every loop index {@code i} in {@code [0, y)}
     * that satisfies {@code 1 < i < 5}.
     *
     * <p>The method is looked up by name through {@code getFuncOp("fcf")}, so the
     * code model derived from this body (loop plus nested {@code if}s) is what the
     * test lowers, converts to SSA form and differentiates.
     *
     * @param x the independent variable the test differentiates with respect to
     * @param y the exclusive upper bound of the loop index
     * @return the accumulated product {@code o}
     */
    @CodeReflection
    static double fcf(/* independent */ double x, int y) {
        /* dependent */
        double o = 1.0;
        for (int i = 0; i < y; i = i + 1) {
            if (i > 1) {
                if (i < 5) {
                    o = o * x;
                }
            }
        }
        return o;
    }
131 
132     static double dfcf_dx(/* independent */ double x, int y) {
133         double d_o = 0.0;
134         double o = 1.0;
135         for (int i = 0; i < y; i = i + 1) {
136             if (i > 1) {
137                 if (i < 5) {
138                     d_o = d_o * x + o * 1.0;
139                     o = o * x;
140                 }
141             }
142         }
143         return d_o;
144     }
145 
146     static MethodHandle generate(CoreOp.FuncOp f) {
147         return BytecodeGenerator.generate(MethodHandles.lookup(), f);
148     }
149 
150     static CoreOp.FuncOp getFuncOp(String name) {
151         Optional<Method> om = Stream.of(TestForwardAutoDiff.class.getDeclaredMethods())
152                 .filter(m -> m.getName().equals(name))
153                 .findFirst();
154 
155         Method m = om.get();
156         return Op.ofMethod(m).get();
157     }
158 }