1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 import org.testng.Assert;
 25 import org.testng.annotations.Test;
 26 
 27 import java.lang.reflect.code.Block;
 28 import java.lang.reflect.code.OpTransformer;
 29 import java.lang.reflect.code.op.CoreOp;
 30 import java.lang.reflect.code.Op;
 31 import java.lang.reflect.code.analysis.SSA;
 32 import java.lang.reflect.code.bytecode.BytecodeGenerator;
 33 import java.lang.reflect.code.interpreter.Interpreter;
 34 import java.lang.invoke.MethodHandle;
 35 import java.lang.invoke.MethodHandles;
 36 import java.lang.reflect.Method;
 37 import java.lang.runtime.CodeReflection;
 38 import java.util.Optional;
 39 import java.util.stream.Stream;
 40 
 41 /*
 42  * @test
 43  * @enablePreview
 44  * @run testng TestForwardAutoDiff
 45  */
 46 
 47 public class TestForwardAutoDiff {
 48     static final double PI_4 = Math.PI / 4;
 49 
 50     @Test
 51     public void testExpression() throws Throwable {
 52         CoreOp.FuncOp f = getFuncOp("f");
 53         f.writeTo(System.out);
 54 
 55         f = SSA.transform(f);
 56         f.writeTo(System.out);
 57 
 58         Assert.assertEquals(Interpreter.invoke(f, 0.0, 1.0), f(0.0, 1.0));
 59         Assert.assertEquals(Interpreter.invoke(f, PI_4, PI_4), f(PI_4, PI_4));
 60 
 61         Block.Parameter x = f.body().entryBlock().parameters().get(0);
 62         Block.Parameter y = f.body().entryBlock().parameters().get(1);
 63 
 64         CoreOp.FuncOp dff_dx = ExpressionElimination.eliminate(ForwardDifferentiation.partialDiff(f, x));
 65         dff_dx.writeTo(System.out);
 66         MethodHandle dff_dx_mh = generate(dff_dx);
 67         Assert.assertEquals((double) dff_dx_mh.invoke(0.0, 1.0), df_dx(0.0, 1.0));
 68         Assert.assertEquals((double) dff_dx_mh.invoke(PI_4, PI_4), df_dx(PI_4, PI_4));
 69 
 70         CoreOp.FuncOp dff_dy = ExpressionElimination.eliminate(ForwardDifferentiation.partialDiff(f, y));
 71         dff_dy.writeTo(System.out);
 72         MethodHandle dff_dy_mh = generate(dff_dy);
 73         Assert.assertEquals((double) dff_dy_mh.invoke(0.0, 1.0), df_dy(0.0, 1.0));
 74         Assert.assertEquals((double) dff_dy_mh.invoke(PI_4, PI_4), df_dy(PI_4, PI_4));
 75     }
 76 
 77     @CodeReflection
 78     static double f(double x, double y) {
 79         return x * (-Math.sin(x * y) + y) * 4.0d;
 80     }
 81 
 82     static double df_dx(double x, double y) {
 83         return (-Math.sin(x * y) + y - x * Math.cos(x * y) * y) * 4.0d;
 84     }
 85 
 86     static double df_dy(double x, double y) {
 87         return x * (1 - Math.cos(x * y) * x) * 4.0d;
 88     }
 89 
 90     @Test
 91     public void testControlFlow() throws Throwable {
 92         CoreOp.FuncOp f = getFuncOp("fcf");
 93         f.writeTo(System.out);
 94 
 95         f = f.transform(OpTransformer.LOWERING_TRANSFORMER);
 96         f.writeTo(System.out);
 97 
 98         f = SSA.transform(f);
 99         f.writeTo(System.out);
100 
101         Assert.assertEquals(Interpreter.invoke(f, 2.0, 6), fcf(2.0, 6));
102         Assert.assertEquals(Interpreter.invoke(f, 2.0, 5), fcf(2.0, 5));
103         Assert.assertEquals(Interpreter.invoke(f, 2.0, 4), fcf(2.0, 4));
104 
105         Block.Parameter x = f.body().entryBlock().parameters().get(0);
106 
107         CoreOp.FuncOp df_dx = ForwardDifferentiation.partialDiff(f, x);
108         df_dx.writeTo(System.out);
109         MethodHandle df_dx_mh = generate(df_dx);
110 
111         Assert.assertEquals((double) df_dx_mh.invoke(2.0, 6), dfcf_dx(2.0, 6));
112         Assert.assertEquals((double) df_dx_mh.invoke(2.0, 5), dfcf_dx(2.0, 5));
113         Assert.assertEquals((double) df_dx_mh.invoke(2.0, 4), dfcf_dx(2.0, 4));
114     }
115 
116     @CodeReflection
117     static double fcf(/* independent */ double x, int y) {
118         /* dependent */
119         double o = 1.0;
120         for (int i = 0; i < y; i = i + 1) {
121             if (i > 1) {
122                 if (i < 5) {
123                     o = o * x;
124                 }
125             }
126         }
127         return o;
128     }
129 
130     static double dfcf_dx(/* independent */ double x, int y) {
131         double d_o = 0.0;
132         double o = 1.0;
133         for (int i = 0; i < y; i = i + 1) {
134             if (i > 1) {
135                 if (i < 5) {
136                     d_o = d_o * x + o * 1.0;
137                     o = o * x;
138                 }
139             }
140         }
141         return d_o;
142     }
143 
144     static MethodHandle generate(CoreOp.FuncOp f) {
145         return BytecodeGenerator.generate(MethodHandles.lookup(), f);
146     }
147 
148     static CoreOp.FuncOp getFuncOp(String name) {
149         Optional<Method> om = Stream.of(TestForwardAutoDiff.class.getDeclaredMethods())
150                 .filter(m -> m.getName().equals(name))
151                 .findFirst();
152 
153         Method m = om.get();
154         return m.getCodeModel().get();
155     }
156 }