1 /*
  2  * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 package oracle.code.triton;
 25 
 26 import org.junit.jupiter.api.Assertions;
 27 import org.junit.jupiter.api.extension.ExtensionContext;
 28 import org.junit.jupiter.api.extension.ParameterContext;
 29 import org.junit.jupiter.api.extension.ParameterResolver;
 30 
 31 import java.lang.annotation.ElementType;
 32 import java.lang.annotation.Retention;
 33 import java.lang.annotation.RetentionPolicy;
 34 import java.lang.annotation.Target;
 35 import java.lang.reflect.Method;
 36 import java.lang.reflect.code.TypeElement;
 37 import java.lang.reflect.code.op.CoreOp;
 38 import java.lang.reflect.code.parser.OpParser;
 39 import java.lang.reflect.code.type.JavaType;
 40 import java.lang.runtime.CodeReflection;
 41 import java.util.List;
 42 import java.util.Optional;
 43 import java.util.stream.Stream;
 44 
 45 public class TritonTestExtension implements ParameterResolver {
 46 
 47     @Target({ElementType.METHOD, ElementType.FIELD})
 48     @Retention(RetentionPolicy.RUNTIME)
 49     public @interface Kernel {
 50         String value();
 51     }
 52 
 53     @Override
 54     public boolean supportsParameter(ParameterContext pc, ExtensionContext ec) {
 55         return pc.getParameter().getType() == TritonTestData.class;
 56     }
 57 
 58     @Override
 59     public Object resolveParameter(ParameterContext pc, ExtensionContext ec) {
 60         Kernel k = ec.getRequiredTestMethod().getAnnotation(Kernel.class);
 61         String kernelName = (k != null)
 62             ? k.value()
 63             : ec.getRequiredTestMethod().getName();
 64 
 65         return new TritonTestData(ec.getRequiredTestClass(), kernelName);
 66     }
 67 
 68     public static class TritonTestData {
 69         final Class<?> testClass;
 70         final String javaKernelName;
 71 
 72         public TritonTestData(Class<?> testClass, String javaKernelName) {
 73             this.testClass = testClass;
 74             this.javaKernelName = javaKernelName;
 75         }
 76 
 77         public void test(List<? extends TypeElement> argTypes) {
 78             Optional<Method> om = Stream.of(testClass.getDeclaredMethods())
 79                     .filter(m -> m.getName().equals(javaKernelName))
 80                     .filter(m -> m.getAnnotation(CodeReflection.class) != null)
 81                     .findFirst();
 82             Method m = om.get();
 83             TritonCodeModel tcm = m.getAnnotation(TritonCodeModel.class);
 84             boolean doSSA = tcm != null ? tcm.SSA() : true;
 85             test(m.getCodeModel().get(), argTypes, expectedTritonKernel(tcm), doSSA);
 86         }
 87 
 88         public TritonOps.ModuleOp expectedTritonKernel(TritonCodeModel tcm) {
 89             if (tcm == null || tcm.value().isEmpty()) {
 90                 return null;
 91             }
 92 
 93             return (TritonOps.ModuleOp) OpParser.fromString(
 94                     TritonOps.FACTORY.andThen(ArithMathOps.FACTORY)
 95                             .andThen(TritonTestOps.FACTORY)
 96                             .andThen(SCFOps.FACTORY)
 97                             .andThen(CoreOp.FACTORY),
 98                     TritonOps.TYPE_FACTORY,
 99                     tcm.value()).get(0);
100         }
101 
102         void test(CoreOp.FuncOp javaKernel,
103                   List<? extends TypeElement> argTypes,
104                   TritonOps.ModuleOp expectedTritonKernel,
105                   boolean doSSA) {
106             TritonOps.ModuleOp actualTritonKernel = ScopedValue.getWhere(TritonTransformer.SV_SSA, doSSA,() -> {
107                 return TritonTransformer.tritonModule(javaKernel, JavaType.VOID, argTypes);
108             });
109 
110             Assertions.assertEquals(
111                     expectedTritonKernel == null ? "NO @TritonCodeModel" : expectedTritonKernel.toText(),
112                     actualTritonKernel.toText());
113         }
114     }
115 
116 }