1 /*
  2  * Copyright (c) 2024-2026, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package hat.codebuilders;
 26 
 27 import hat.buffer.Uniforms;
 28 import hat.types.F32;
 29 import hat.types.ivec2;
 30 import hat.types.mat2;
 31 import hat.types.mat3;
 32 import hat.types.vec2;
 33 import hat.types.vec3;
 34 import hat.types.vec4;
 35 import jdk.incubator.code.Op;
 36 import jdk.incubator.code.dialect.java.ClassType;
 37 import jdk.incubator.code.dialect.java.JavaType;
 38 import jdk.incubator.code.dialect.java.PrimitiveType;
 39 import optkl.IfaceValue;
 40 import optkl.OpHelper;
 41 import optkl.ifacemapper.MappableIface;
 42 
 43 import java.lang.invoke.MethodHandles;
 44 import java.lang.reflect.Field;
 45 import java.util.List;
 46 import java.util.function.Consumer;
 47 import java.util.stream.Stream;
 48 
 49 public class C99VecAndMatHandler {
 50 
 51     static final String SHADER_MAIN_IMAGE = "mainImage";
 52     static final String SHAPE_FIELD_NAME = "shape";
 53 
 54     static IfaceValue.vec.Shape getVecShape(java.lang.reflect.Type vecClass) {
 55         try {
 56             Field field = ((Class<?>) vecClass).getField(SHAPE_FIELD_NAME);
 57             return (IfaceValue.vec.Shape) field.get(null);
 58         } catch (Throwable t) {
 59             throw new RuntimeException(t);
 60         }
 61     }
 62 
 63     static IfaceValue.vec.Shape getVecShape(MethodHandles.Lookup lookup, JavaType javaType) {
 64         var resolved = OpHelper.classTypeToTypeOrThrow(lookup, (ClassType) javaType);
 65         return getVecShape(resolved);
 66 
 67     }
 68 
 69     static String clName(MethodHandles.Lookup lookup, JavaType javaType) {
 70 
 71         if (OpHelper.isAssignable(lookup, javaType, vec4.class)) {
 72             return "vec4";
 73         } else if (OpHelper.isAssignable(lookup, javaType, vec3.class)) {
 74             return "vec3";
 75         } else if (OpHelper.isAssignable(lookup, javaType, vec2.class)) {
 76             return "vec2";
 77         } else if (OpHelper.isAssignable(lookup, javaType, mat2.class)) {
 78             return "mat2";
 79         } else if (OpHelper.isAssignable(lookup, javaType, mat3.class)) {
 80             return "mat3";
 81         } else if (OpHelper.isAssignable(lookup, javaType, ivec2.class)) {
 82             return "ivec2";
 83         } else if (javaType.equals(PrimitiveType.FLOAT)) {
 84             return "float";
 85         } else if (javaType.equals(PrimitiveType.INT)) {
 86             return "int";
 87         } else {
 88             throw new RuntimeException("no cl name mapping for " + javaType);
 89         }
 90     }
 91 
 92     public static boolean isVecInvoke(OpHelper.Invoke invoke) {
 93         return (invoke.named(SHADER_MAIN_IMAGE) || invoke.refIs(F32.class, Uniforms.class, IfaceValue.vec.class, IfaceValue.mat.class));
 94     }
 95 
 96     public static String mangledName(OpHelper.Invoke invoke) {
 97         // So invoke  float mod(float lhs, float rhs) -> f32_mod_f32_f32"
 98         return clName(invoke.lookup(), (JavaType) invoke.returnType()) + "_"
 99                 + invoke.name() + "_"
100                 + clName(invoke.lookup(), (JavaType) invoke.resultFromOperandNOrNull(0).type()) + "_"
101                 + clName(invoke.lookup(), (JavaType) invoke.resultFromOperandNOrNull(1).type());
102     }
103 
104     public static <T extends C99HATKernelBuilder<T>> void args(C99HATKernelBuilder<T> bldr, Stream<Op.Result> argStream) {
105         bldr.paren(_ -> bldr.commaSpaceSeparated(argStream, operand -> bldr.recurse(operand.op())));
106     }
107 
108     public static <T extends C99HATKernelBuilder<T>> void nameAndArgs(String name, C99HATKernelBuilder<T> bldr, Stream<Op.Result> argStream) {
109         bldr.funcName(name).paren(_ -> args(bldr, argStream));
110     }
111 
112     public static <T extends C99HATKernelBuilder<T>> void nameAndArgs(OpHelper.Invoke invoke, String newName, C99HATKernelBuilder<T> bldr) {
113         bldr.funcName(newName).paren(_ ->
114                 bldr.commaSpaceSeparated(invoke.operandsAsResults(), operand -> bldr.recurse(operand.op()))
115         );
116     }
117 
118     public static <T extends C99HATKernelBuilder<T>> void nameAndArgs(OpHelper.Invoke invoke, C99HATKernelBuilder<T> bldr) {
119         nameAndArgs(invoke, invoke.name(), bldr);
120     }
121 
122     public static <T extends C99HATKernelBuilder<T>> void mangledNameAndArgs(OpHelper.Invoke invoke, C99HATKernelBuilder<T> bldr) {
123         nameAndArgs(invoke, mangledName(invoke), bldr);
124     }
125 
126     public static boolean hasVecOperand(OpHelper.Invoke invoke) {
127         return invoke.operandsAsResults().anyMatch(r ->
128                 r.type() instanceof ClassType classType
129                         && OpHelper.classTypeToTypeOrThrow(invoke.lookup(), classType) instanceof Class<?> clazz
130                         && IfaceValue.mat.class.isAssignableFrom(clazz));
131 
132     }
133 
134     public static <T extends C99HATKernelBuilder<T>> void handleF32Invoke(C99HATKernelBuilder<T> bldr, OpHelper.Invoke invoke) {
135         switch (invoke.name()) {
136             case "fract" -> //opencl 's fract wants a ptr as the second arg where it returns fract value.  So we implement our own
137                     bldr.paren(_ -> bldr.recurse(invoke.opFromFirstOperandOrNull()).sub().funcName("floor")
138                             .paren(_ -> bldr.recurse(invoke.opFromFirstOperandOrNull())));
139             case "atan" -> {
140                 if (invoke.operandCount() > 1) { // atan -> atan2
141                     nameAndArgs(invoke, invoke.name() + "2", bldr);  // atan(l,r) ->atan2(l,r)
142                 } else {
143                     nameAndArgs(invoke, bldr); // atan(v)->atan(v)
144                 }
145             }
146             case "inversesqrt" -> nameAndArgs(invoke, "rsqrt", bldr);// inversesqrt(...)->rsqrt(...)
147             case "cos", "sqrt", "sin", "exp", "pow", "min", "max", "log", "smoothstep", "clamp", "floor", "step",
148                  "mix" -> nameAndArgs(invoke, bldr); // asis!
149             case "abs" -> nameAndArgs(invoke, "fabs", bldr);// abs(...)->fabs(...)
150             case "mod" -> mangledNameAndArgs(invoke, bldr); //mod(...) -> float_mod_float_float(...)
151             default -> throw new RuntimeException("unmapped F32 call " + invoke.name());
152         }
153     }
154 
155     public static <T extends C99HATKernelBuilder<T>> void handleUniformsInvoke(C99HATKernelBuilder<T> bldr, OpHelper.Invoke invoke) {
156         bldr.cast(_ -> bldr.type((JavaType) invoke.returnType()));
157         switch (invoke.name()) {
158             case "iResolution" -> bldr.paren(_ ->
159                     bldr.sep(List.of("x", "y", "z"), _ -> bldr.csp(), lane -> bldr.id("uniforms").rarrow().id(invoke.name()).dot().id(lane))
160             );
161             case "iMouse" ->
162                     bldr.paren(_ -> bldr.sep(List.of("x", "y"), _ -> bldr.csp(), lane -> bldr.id("uniforms").rarrow().id(invoke.name()).dot().id(lane))
163                     );
164             case "iTime" -> bldr.id("uniforms").rarrow().id(invoke.name());
165             default -> throw new RuntimeException("some other uniform" + invoke.name());
166         }
167     }
168 
169     public static <T extends C99HATKernelBuilder<T>> void handleVecInvoke(C99HATKernelBuilder<T> bldr, OpHelper.Invoke invoke) {
170         if (invoke.refIs(IfaceValue.Struct.class)) {
171             // A call on a field of the uniforms
172             // so uniforms_t.iTime -> uniforms.iTime;
173             bldr.recurse(invoke.opFromFirstOperandOrNull()).dot().id(invoke.name());
174         } else if (invoke instanceof OpHelper.Invoke.Virtual && invoke.operandCount() == 1 && invoke.refIs(vec3.class, vec2.class, vec4.class)) {
175             // an accessor on a vec say v.x() -> v.x
176             bldr.recurse(invoke.opFromFirstOperandOrNull()).dot().id(invoke.name());
177         } else if (invoke.nameMatchesRegex("vec[234]")) { //  a psuedo vec2 constructor vec2(....) -> (vec2)(....)
178             bldr.paren(_ -> bldr.type(invoke.name()));
179             args(bldr, invoke.operandsAsResults());
180         } else if (invoke.nameMatchesRegex("(mod|reflect)")) {
181             mangledNameAndArgs(invoke, bldr);
182         } else if (invoke.named("mul") && invoke.operandCount() == 2 && hasVecOperand(invoke)) {
183             mangledNameAndArgs(invoke, bldr);
184         } else if (invoke.nameMatchesRegex("(mul|add|sub|div)")) {
185             // for opencl we can turn these into expressions. So vec3.mul(l,r) -> (l * r)
186             bldr.paren(_ -> bldr.recurse(invoke.opFromFirstOperandOrNull()).symbol(switch (invoke.name()) {
187                 case "mul" -> "*";
188                 case "add" -> "+";
189                 case "div" -> "/";
190                 case "sub" -> "-";
191                 default -> throw new IllegalStateException("oh my");
192             }).recurse(invoke.opFromOperandNOrNull(1)));
193         } else if (invoke.named("abs") && invoke.refIs(IfaceValue.vec.class)) {
194             bldr.funcName("f" + invoke.name()).paren(_ ->
195                     bldr.sep(invoke.op().operands(), _ -> bldr.csp(), v -> bldr.recurse(v.asResult().op()))
196             );
197         } else if (invoke.named("fract") && invoke.operandCount() == 1) {
198             // return x - floor(x);
199             bldr.paren(_ -> bldr.recurse(invoke.opFromFirstOperandOrNull()).sub().funcName("floor")
200                     .paren(_ -> bldr.recurse(invoke.opFromFirstOperandOrNull())));
201         } else if (invoke.nameMatchesRegex("(dot|length|max|mix|min|smoothstep|step|normalize|clamp|pow|cross|distance|floor|fract|round|sin|cos|abs)")) {
202             bldr.funcName(invoke.op()).paren(_ ->
203                     bldr.sep(invoke.op().operands(), _ -> bldr.csp(), v -> bldr.recurse(v.asResult().op()))
204             );
205         } else {
206             StringBuilder stringBuilder = new StringBuilder("For vec types we need to IMPLEMENT " + invoke.refType() + ":" + invoke.name() + "(");
207             invoke.op().operands().forEach(o -> stringBuilder.append(" " + o.asResult().type()));
208             stringBuilder.append(")");
209             throw new RuntimeException(stringBuilder.toString());
210         }
211     }
212 
213     public static <T extends C99HATKernelBuilder<T>> void handleMatInvoke(C99HATKernelBuilder<T> bldr, OpHelper.Invoke invoke) {
214         if (invoke.nameMatchesRegex("mat[234]")) {
215             bldr.paren(_ -> bldr.type(invoke.name())).brace(_ ->
216                     bldr.commaSpaceSeparated(invoke.operandsAsResults(), operand -> bldr.recurse(operand.op()))
217             );
218         } else if (invoke.named("mul")) {
219             mangledNameAndArgs(invoke, bldr);
220         } else {
221             StringBuilder stringBuilder = new StringBuilder("For mat types we need to IMPLEMENT " + invoke.refType() + ":" + invoke.name() + "(");
222             invoke.op().operands().forEach(o -> stringBuilder.append(" " + o.asResult().type()));
223             stringBuilder.append(")");
224             throw new RuntimeException(stringBuilder.toString());
225         }
226     }
227 
228     public static <T extends C99HATKernelBuilder<T>> void handleInvoke(C99HATKernelBuilder<T> bldr, OpHelper.Invoke invoke) {
229         if (invoke.named(SHADER_MAIN_IMAGE)) {
230             nameAndArgs(invoke, bldr);
231         } else if (invoke.refIs(F32.class)) {
232             handleF32Invoke(bldr, invoke);
233         } else if (invoke.refIs(Uniforms.class)) {
234             handleUniformsInvoke(bldr, invoke);
235         } else if (invoke.refIs(IfaceValue.vec.class)) {
236             handleVecInvoke(bldr, invoke);
237         } else if (invoke.refIs(IfaceValue.mat.class)) {
238             handleMatInvoke(bldr, invoke);
239         }
240     }
241 
242     public static boolean isVecOrMatType(MethodHandles.Lookup lookup, JavaType javaType) {
243         return OpHelper.isAssignable(lookup, javaType, MappableIface.vec.class, IfaceValue.mat.class);
244     }
245 
246     public static <T extends C99HATKernelBuilder<T>> void handleType(C99HATKernelBuilder<T> bldr, JavaType javaType) {
247         bldr.type(clName(bldr.scopedCodeBuilderContext().lookup(), javaType));
248     }
249 
250     public static <T extends C99HATKernelBuilder<T>> void genFunc(C99HATKernelBuilder<T> bldr, String ret, String op, String lhs, String rhs, Consumer<C99HATKernelBuilder<T>> consumer) {
251         bldr.func(
252                 _ -> bldr.type(ret),
253                 ret + "_" + op + "_" + lhs + "_" + rhs,
254                 _ -> bldr.type(lhs).sp().id("l").csp().type(rhs).sp().id("r"),
255                 _ -> consumer.accept(bldr)
256         ).semicolon();
257 
258     }
259 
260     public static <T extends C99HATKernelBuilder<T>> void createVecFunctions(C99HATKernelBuilder<T> builder) {
261         record NamedMatShape(String name, IfaceValue.mat.Shape shape) {
262         }
263         List.of(new NamedMatShape("mat2", mat2.shape), new NamedMatShape("mat3", mat3.shape)).forEach(ns ->
264                 builder.typedefKeyword().sp().structKeyword().sp().id(ns.name + "_s").braceNlIndented(_ -> {
265                     builder.sep(ns.shape.rowColNames(), _ -> builder.nl(), n ->
266                             builder.type((JavaType) ns.shape.codeType()).sp().id(n).semicolon()
267                     );
268                 }).sp().id(ns.name).snl()
269         );
270 
271 
272         record NamedVecShape(String name, IfaceValue.vec.Shape shape) {
273         }
274 
275         List.of(new NamedVecShape("vec2", vec2.shape), new NamedVecShape("vec3", vec3.shape), new NamedVecShape("vec4", vec4.shape)).forEach(ns ->
276                 builder.typedefKeyword().sp().type("float" + ns.shape.lanes()).sp().type(ns.name).snl()
277         );
278         List.of(new NamedVecShape("ivec2", vec2.shape)).forEach(ns ->
279                 builder.typedefKeyword().sp().type("int" + ns.shape.lanes()).sp().type(ns.name).snl()
280         );
281 /*
282 2. Vector * Matrix (vec2 * mat2)This treats the vector as a row. Mathematically, this is equivalent to multiplying the transpose of the matrix by the vector.$$\text{result}.x = (v.x \cdot m_{0}) + (v.y \cdot m_{1})$$$$\text{result}.y = (v.x \cdot m_{2}) + (v.y \cdot m_{3})$$Javapublic static float[] multiplyVec2Mat2(float[] v, float[] m) {
283     float x = v[0] * m[0] + v[1] * m[1];
284     float y = v[0] * m[2] + v[1] * m[3];
285      l.x * r.00 + l.y * r.01,l.x * r.10 + l.y * r.11;
286     return new float[]{x, y};
287 }
288         */
289         genFunc(builder, "vec2", "mul", "vec2", "mat2", _ ->
290                 builder.returnKeyword().sp().paren(_ -> builder.type("vec2"))
291                         .paren(_ -> builder.preformatted("l.x*r._00+l.y*r._01,l.x*r._10+l.y*r._11"))
292                         .semicolon()
293         );
294 
295         /*
296         public static float[] multiplyMat2Vec2(float[] m, float[] v) {
297     float x = m[0] * v[0] + m[2] * v[1];
298     float y = m[1] * v[0] + m[3] * v[1];
299      l.00 * r.x + l.10 * r.y,
300      l.01 * r.x + l.11 * r.y
301     return new float[]{x, y};
302 }
303 
304 
305          */
306         genFunc(builder, "vec2", "mul", "mat2", "vec2", _ ->
307                 builder.returnKeyword().sp().paren(_ -> builder.type("vec2")).paren(_ ->
308                         builder.preformatted(" l._00*r.x+l._10*r.y,l._01*r.x+l._11*r.y")
309                 ).semicolon());
310 
311         /*
312           public static float[] multiplyVec3Mat3(float[] v, float[] m) {
313             float x = v[0] * m[0] + v[1] * m[1] + v[2] * m[2];
314             float y = v[0] * m[3] + v[1] * m[4] + v[2] * m[5];
315             float z = v[0] * m[6] + v[1] * m[7] + v[2] * m[8];
316            l.x * r._00 + l.y * r._01 + l.z * r._02,
317            l.x * r._10 + l.y * r._11 + l.z * r._12,
318            l.x * r._20 + l.y * r._21 + l.z * r._22,
319             return new float[]{x, y, z};
320         }
321          */
322         genFunc(builder, "vec3", "mul", "vec3", "mat3", _ ->
323                 builder.returnKeyword().sp().paren(_ -> builder.type("vec3")).paren(_ ->
324                 builder.preformatted("l.x*r._00+l.y*r._01+l.z*r._02,l.x*r._10+l.y*r._11+l.z*r._12,l.x*r._20+l.y*r._21+l.z*r._22")).semicolon()
325         );
326         genFunc(builder, "vec3", "mul", "mat3", "vec3", _ ->
327                 builder.returnKeyword().sp().paren(_ -> builder.type("vec3")).paren(_ ->
328                         builder.preformatted(
329                                 "       l._00 * r.x + l._01 * r.y + l._02 * r.z," +
330                                 "        l._10 * r.x + l._11 * r.y + l._12 * r.z," +
331                                 "        l._20 * r.x + l._21 * r.y + l._22 * r.z")).semicolon()
332         );
333 
334         /*
335           l._00() * r.x() + l._01() * r.y() + l._02() * r.z(),
336         l._10() * r.x() + l._11() * r.y() + l._12() * r.z(),
337         l._20() * r.x() + l._21() * r.y() + l._22() * r.z()
338          */
339 
340 
341         /*
342         public static float[] multiplyMat3Vec3(float[] m, float[] v) {
343           //  float x = m[0] * v[0] + m[3] * v[1] + m[6] * v[2];
344           //  float y = m[1] * v[0] + m[4] * v[1] + m[7] * v[2];
345           //  float z = m[2] * v[0] + m[5] * v[1] + m[8] * v[2];
346             float x = m[0] * v[0] + m[3] * v[1] + m[6] * v[2];
347             float y = m[1] * v[0] + m[4] * v[1] + m[7] * v[2];
348             float z = m[2] * v[0] + m[5] * v[1] + m[8] * v[2];
349             return new float[]{x, y, z};
350         }
351         2. Vector * Matrix (vec3 * mat3)
352         This treats the vector as a row vector on the left. Effectively, you are calculating the dot product of the vector with each column of the matrix.
353 
354         Java
355         public static float[] multiplyVec3Mat3(float[] v, float[] m) {
356             float x = v[0] * m[0] + v[1] * m[1] + v[2] * m[2];
357             float y = v[0] * m[3] + v[1] * m[4] + v[2] * m[5];
358             float z = v[0] * m[6] + v[1] * m[7] + v[2] * m[8];
359             return new float[]{x, y, z};
360         }
361 */
362         genFunc(builder, "mat3", "mul", "mat3", "mat3", _ ->
363                 builder.returnKeyword().sp().paren(_ -> builder.type("mat3")).brace(_ ->
364                         builder.preformatted("""
365                                     l._00*r._00,l._01*r._01,l._02*r._02,
366                                     l._10*r._10,l._11*r._11,l._12*r._12,
367                                     l._20*r._20,l._21*r._21,l._22*r._22
368                                 """)
369                 ).semicolon()
370         );
371         genFunc(builder, "mat2", "mul", "mat2", "mat2", _ ->
372                 builder.returnKeyword().sp().paren(_ -> builder.type("mat2")).brace(_ ->
373                         builder.preformatted("l._00*r._00,l._01*r._01, l._10*r._10,l._11*r._11")
374                 ).semicolon()
375         );
376 
377         genFunc(builder, "float", "mod", "float", "float", _ ->
378                 builder.returnKeyword().sp().id("l").sp().minus().id("r").sp().mul().sp().id("floor").paren(_ ->
379                         builder.id("l").div().id("r")).semicolon());
380 
381         genFunc(builder, "vec2", "mod", "vec2", "float", _ ->
382                 builder.returnKeyword().sp().id("l").sp().minus().id("r").sp().mul().sp().id("floor").paren(_ ->
383                         builder.id("l").div().id("r")).semicolon());
384 
385         genFunc(builder, "vec3", "reflect", "vec3", "vec3", _ ->
386                 builder.returnKeyword().sp().id("l").sp().minus().id("r").sp().mul().sp().id("l").sp().mul().floatConst(2).semicolon());
387     }
388 }