1 /*
2 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation. Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25
26 package experiments;
27
28
29 import optkl.VarTable;
30 import optkl.codebuilders.JavaCodeBuilder;
31 import optkl.Trxfmr;
32 import static optkl.OpHelper.Invoke;
33 import static optkl.OpHelper.Invoke.invoke;
34 import jdk.incubator.code.Op;
35 import jdk.incubator.code.bytecode.BytecodeGenerator;
36 import jdk.incubator.code.dialect.core.CoreOp;
37 import jdk.incubator.code.dialect.core.CoreType;
38 import jdk.incubator.code.dialect.java.JavaOp;
39 import jdk.incubator.code.dialect.java.JavaOp.InvokeOp.InvokeKind;
40 import jdk.incubator.code.dialect.java.JavaType;
41 import jdk.incubator.code.dialect.java.MethodRef;
42 import optkl.util.Regex;
43
44 import java.lang.invoke.MethodHandles;
45
46 public class SwapMath {
47 public static void main(String[] args) throws Throwable {
48 var lookup = MethodHandles.lookup();
49 MethodRef MathSqrt = MethodRef.method(Math.class, "sqrt", double.class, double.class);
50 MethodRef MathAbs = MethodRef.method(Math.class, "abs", double.class, double.class);
51
52 CoreOp.FuncOp rsqrt= CoreOp.func("rsqrt", CoreType.functionType(JavaType.DOUBLE, JavaType.DOUBLE))
53 .body(builder -> {// double rsqrt(double arg){return 1 / Math.sqrt(qrg)}
54 // var arg = builder.parameters().getFirst();
55 var argOp = CoreOp.var("arg", builder.parameters().getFirst());
56 var arg = builder.add(argOp);
57
58 // We can pass builder.parameters().getFirst() directly as arg below. But then we don't know the name
59 var sqrtInvoke = JavaOp.invoke(InvokeKind.STATIC, false, JavaType.DOUBLE, MathSqrt, arg);
60 var _1f = builder.add(CoreOp.constant(JavaType.DOUBLE, 1.0));
61
62 Op.Result invokeResult = builder.add(sqrtInvoke);
63 Op.Result divResult = builder.add(
64 JavaOp.div(_1f, invokeResult)
65 );
66 builder.add(CoreOp.return_(divResult));
67 });
68 var javaCodeBuilder = new JavaCodeBuilder<>(lookup,rsqrt);
69 System.out.println(rsqrt.toText());
70 System.out.println(javaCodeBuilder.toText());
71 System.out.println(" 1/sqrt(100) = " + BytecodeGenerator.generate(lookup, rsqrt).invoke(100));
72
73
74
75 System.out.println("--------------------------");
76 var abs = rsqrt.transform("usingAbs", (builder,op)->{
77 if (invoke(lookup,op) instanceof Invoke.Static ih
78 && ih.named("sqrt") && ih.returns(double.class) && ih.receives(double.class)){
79 var absStaticMethod = MethodRef.method(Math.class, "abs", double.class, double.class);
80 var absInvoke = JavaOp.invoke(InvokeKind.STATIC, false, absStaticMethod.signature().returnType(), absStaticMethod,
81 builder.context().getValue(op.operands().get(0)));
82 var absResult= builder.add(absInvoke);
83 builder.context().mapValue(op.result(), absResult);
84 }else{
85 builder.add(op);
86 }
87 return builder;
88 });
89
90 System.out.println(abs.toText());
91 javaCodeBuilder = new JavaCodeBuilder<>(lookup,abs);
92 System.out.println(" 1/abs(100) = " + BytecodeGenerator.generate(MethodHandles.lookup(), abs).invoke(100));
93
94
95 System.out.println("Now using txfmr--------------------------");
96 VarTable varTable = new VarTable(rsqrt.funcName());
97 var newAbs =Trxfmr.of(lookup, rsqrt)
98 .transform("usingAbs", varTable, ce-> invoke(lookup,ce) instanceof Invoke.Static $
99 && $.named("sqrt")
100 && $.returns(double.class)
101 && $.receives(double.class)
102 , c->
103 c.replace(JavaOp.invoke(InvokeKind.STATIC, false, JavaType.DOUBLE, MathAbs, c.mappedOperand( 0)))
104 )
105 .funcOp();
106
107
108 System.out.println(newAbs.toText());
109 javaCodeBuilder = new JavaCodeBuilder<>(lookup,newAbs);
110 System.out.println(javaCodeBuilder.toText());
111 System.out.println(" 1/abs(100) = " + BytecodeGenerator.generate(MethodHandles.lookup(), newAbs).invoke(100));
112
113
114 }
115 }
116