1 import torch
2 import intel_extension_for_pytorch # type: ignore # noqa: F401
3
4 import triton
5 import triton.language as tl
6 import shutil
7 import os
8
# Resolve the user's home directory. os.environ['HOME'] raises KeyError when
# $HOME is unset (e.g. Windows, some CI/daemon environments); fall back to
# os.path.expanduser, which knows the platform-appropriate lookups.
HOME = os.environ.get('HOME', os.path.expanduser('~'))
# Root of the babylon checkout, expected at ~/babylon.
BABYLON_PATH = os.path.join(HOME, 'babylon')
11
@triton.jit
def add_kernel():
    # Intentionally empty placeholder body: this script only feeds the kernel
    # to triton.compile below to emit MLIR, so no computation is defined here.
    # NOTE(review): presumably the real kernel body lives elsewhere or is not
    # needed for MLIR emission — confirm.
    pass
15
@triton.jit
def matmul_kernel():
    # Intentionally empty placeholder body: this script only feeds the kernel
    # to triton.compile below to emit MLIR, so no computation is defined here.
    # NOTE(review): presumably the real kernel body lives elsewhere or is not
    # needed for MLIR emission — confirm.
    pass
19
@triton.jit
def softmax_kernel():
    # Intentionally empty placeholder body: this script only feeds the kernel
    # to triton.compile below to emit MLIR, so no computation is defined here.
    # NOTE(review): presumably the real kernel body lives elsewhere or is not
    # needed for MLIR emission — confirm.
    pass
23
# Output locations for the emitted MLIR, one file per kernel. Built with
# os.path.join for portability and for consistency with how BABYLON_PATH
# itself is constructed (rather than hard-coded '/' in f-strings).
_TRITON_MLIR_DIR = os.path.join(BABYLON_PATH, 'cr-examples', 'triton', 'target', 'mlir')
ADD_KERNEL_MLIR = os.path.join(_TRITON_MLIR_DIR, 'add_kernel.mlir')
MATMUL_MLIR = os.path.join(_TRITON_MLIR_DIR, 'matmul_kernel.mlir')
SOFTMAX_MLIR = os.path.join(_TRITON_MLIR_DIR, 'softmax_kernel.mlir')
27
# Remove any stale Triton kernel cache so the compilations below regenerate
# everything from scratch. ignore_errors=True both handles a missing cache
# directory and avoids the check-then-remove (TOCTOU) race that a separate
# os.path.isdir guard introduces.
shutil.rmtree(os.path.join(HOME, '.triton', 'cache'), ignore_errors=True)
30
# Compile each stub kernel to MLIR at its configured output path, in the same
# order and with the same arguments as before: add (default options),
# softmax (32 warps), matmul (16 threads/warp, 64 warps).
for _kernel, _mlir_path, _opts in (
    (add_kernel, ADD_KERNEL_MLIR, None),
    (softmax_kernel, SOFTMAX_MLIR, {"num_warps": 32}),
    (matmul_kernel, MATMUL_MLIR, {"threads_per_warp": 16, "num_warps": 64}),
):
    _src = triton.compiler.ASTSource(fn=_kernel, signature={}, constants={})
    if _opts is None:
        # First kernel originally passed no options kwarg; preserve that exactly.
        triton.compile(_src, target_mlir=_mlir_path)
    else:
        triton.compile(_src, target_mlir=_mlir_path, options=_opts)