1 /*
2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation.
8 *
9 * This code is distributed in the hope that it will be useful, but WITHOUT
10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
12 * version 2 for more details (a copy is included in the LICENSE file that
13 * accompanied this code).
14 *
15 * You should have received a copy of the GNU General Public License version
16 * 2 along with this work; if not, write to the Free Software Foundation,
17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18 *
19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20 * or visit www.oracle.com if you need additional information or have any
21 * questions.
22 */
23
24 import jdk.incubator.code.Block;
25 import jdk.incubator.code.CodeReflection;
26 import jdk.incubator.code.Op;
27 import jdk.incubator.code.OpTransformer;
28 import jdk.incubator.code.analysis.SSA;
29 import jdk.incubator.code.bytecode.BytecodeGenerator;
30 import jdk.incubator.code.dialect.core.CoreOp;
31 import jdk.incubator.code.interpreter.Interpreter;
32 import org.junit.jupiter.api.Assertions;
33 import org.junit.jupiter.api.Test;
34
35 import java.lang.invoke.MethodHandle;
36 import java.lang.invoke.MethodHandles;
37 import java.lang.reflect.Method;
38 import java.util.Optional;
39 import java.util.stream.Stream;
40
41 /*
42 * @test
43 * @modules jdk.incubator.code
44 * @enablePreview
45 * @run junit TestForwardAutoDiff
46 * @run junit/othervm -Dbabylon.ssa=cytron TestForwardAutoDiff
47 */
48
49 public class TestForwardAutoDiff {
50 static final double PI_4 = Math.PI / 4;
51
52 @Test
53 public void testExpression() throws Throwable {
54 CoreOp.FuncOp f = getFuncOp("f");
55 System.out.println(f.toText());
56
57 f = SSA.transform(f);
58 System.out.println(f.toText());
59
60 Assertions.assertEquals(f(0.0, 1.0), Interpreter.invoke(MethodHandles.lookup(), f, 0.0, 1.0));
61 Assertions.assertEquals(f(PI_4, PI_4), Interpreter.invoke(MethodHandles.lookup(), f, PI_4, PI_4));
62
63 Block.Parameter x = f.body().entryBlock().parameters().get(0);
64 Block.Parameter y = f.body().entryBlock().parameters().get(1);
65
66 CoreOp.FuncOp dff_dx = ExpressionElimination.eliminate(ForwardDifferentiation.partialDiff(f, x));
67 System.out.println(dff_dx.toText());
68 MethodHandle dff_dx_mh = generate(dff_dx);
69 Assertions.assertEquals(df_dx(0.0, 1.0), (double) dff_dx_mh.invoke(0.0, 1.0));
70 Assertions.assertEquals(df_dx(PI_4, PI_4), (double) dff_dx_mh.invoke(PI_4, PI_4));
71
72 CoreOp.FuncOp dff_dy = ExpressionElimination.eliminate(ForwardDifferentiation.partialDiff(f, y));
73 System.out.println(dff_dy.toText());
74 MethodHandle dff_dy_mh = generate(dff_dy);
75 Assertions.assertEquals(df_dy(0.0, 1.0), (double) dff_dy_mh.invoke(0.0, 1.0));
76 Assertions.assertEquals(df_dy(PI_4, PI_4), (double) dff_dy_mh.invoke(PI_4, PI_4));
77 }
78
79 @CodeReflection
80 static double f(double x, double y) {
81 return x * (-Math.sin(x * y) + y) * 4.0d;
82 }
83
84 static double df_dx(double x, double y) {
85 return (-Math.sin(x * y) + y - x * Math.cos(x * y) * y) * 4.0d;
86 }
87
88 static double df_dy(double x, double y) {
89 return x * (1 - Math.cos(x * y) * x) * 4.0d;
90 }
91
92 @Test
93 public void testControlFlow() throws Throwable {
94 CoreOp.FuncOp f = getFuncOp("fcf");
95 System.out.println(f.toText());
96
97 f = f.transform(OpTransformer.LOWERING_TRANSFORMER);
98 System.out.println(f.toText());
99
100 f = SSA.transform(f);
101 System.out.println(f.toText());
102
103 Assertions.assertEquals(fcf(2.0, 6), Interpreter.invoke(MethodHandles.lookup(), f, 2.0, 6));
104 Assertions.assertEquals(fcf(2.0, 5), Interpreter.invoke(MethodHandles.lookup(), f, 2.0, 5));
105 Assertions.assertEquals(fcf(2.0, 4), Interpreter.invoke(MethodHandles.lookup(), f, 2.0, 4));
106
107 Block.Parameter x = f.body().entryBlock().parameters().get(0);
108
109 CoreOp.FuncOp df_dx = ForwardDifferentiation.partialDiff(f, x);
110 System.out.println(df_dx.toText());
111 MethodHandle df_dx_mh = generate(df_dx);
112
113 Assertions.assertEquals(dfcf_dx(2.0, 6), (double) df_dx_mh.invoke(2.0, 6));
114 Assertions.assertEquals(dfcf_dx(2.0, 5), (double) df_dx_mh.invoke(2.0, 5));
115 Assertions.assertEquals(dfcf_dx(2.0, 4), (double) df_dx_mh.invoke(2.0, 4));
116 }
117
118 @CodeReflection
119 static double fcf(/* independent */ double x, int y) {
120 /* dependent */
121 double o = 1.0;
122 for (int i = 0; i < y; i = i + 1) {
123 if (i > 1) {
124 if (i < 5) {
125 o = o * x;
126 }
127 }
128 }
129 return o;
130 }
131
132 static double dfcf_dx(/* independent */ double x, int y) {
133 double d_o = 0.0;
134 double o = 1.0;
135 for (int i = 0; i < y; i = i + 1) {
136 if (i > 1) {
137 if (i < 5) {
138 d_o = d_o * x + o * 1.0;
139 o = o * x;
140 }
141 }
142 }
143 return d_o;
144 }
145
146 static MethodHandle generate(CoreOp.FuncOp f) {
147 return BytecodeGenerator.generate(MethodHandles.lookup(), f);
148 }
149
150 static CoreOp.FuncOp getFuncOp(String name) {
151 Optional<Method> om = Stream.of(TestForwardAutoDiff.class.getDeclaredMethods())
152 .filter(m -> m.getName().equals(name))
153 .findFirst();
154
155 Method m = om.get();
156 return Op.ofMethod(m).get();
157 }
158 }