From ac7497b2c86bdcef645a2b6dc64d77c2eddec943 Mon Sep 17 00:00:00 2001 From: Jinjie Liu Date: Sat, 31 Jan 2026 10:36:41 +0800 Subject: [PATCH] enable cjit launcher Signed-off-by: Jinjie Liu --- include/type.h | 1 + pyproject.toml | 1 + python/triton_tvm_ffi/_ffi_api.py | 2 + python/triton_tvm_ffi/driver.py | 133 ++++++++++++++------ python/triton_tvm_ffi/templates/launch.c.j2 | 65 ++++++++++ src/type.cc | 15 ++- src/utils.cc | 1 - 7 files changed, 180 insertions(+), 38 deletions(-) create mode 100644 python/triton_tvm_ffi/templates/launch.c.j2 diff --git a/include/type.h b/include/type.h index 68d10cf..da70716 100644 --- a/include/type.h +++ b/include/type.h @@ -38,6 +38,7 @@ enum class Type : int64_t { const char *TypeToString(Type type); tvm::ffi::Optional StringToType(const tvm::ffi::String &name); +const char *TypeToCType(Type type); template struct type_to_ctype; #define DEFINE_TYPE_TO_CTYPE(type, str, ctype) \ diff --git a/pyproject.toml b/pyproject.toml index 9b51f9f..451a3cb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,6 +5,7 @@ description = "Add your description here" readme = "README.md" dependencies = [ "apache-tvm-ffi", + "jinja2", ] [build-system] diff --git a/python/triton_tvm_ffi/_ffi_api.py b/python/triton_tvm_ffi/_ffi_api.py index 344a060..a872302 100644 --- a/python/triton_tvm_ffi/_ffi_api.py +++ b/python/triton_tvm_ffi/_ffi_api.py @@ -20,6 +20,7 @@ LIB = _FFI_LOAD_LIB("triton_tvm_ffi", "triton_tvm_ffi") _FFI_INIT_FUNC("triton_tvm_ffi", __name__) if TYPE_CHECKING: def string_to_type(_0: str, /) -> int | None: ... + def type_to_ctype(_0: int, /) -> str: ... def type_to_string(_0: int, /) -> str: ... # fmt: on # tvm-ffi-stubgen(end) @@ -46,6 +47,7 @@ __all__ = [ "LIB", "TVMFFILauncherImpl", "string_to_type", + "type_to_ctype", "type_to_string", # tvm-ffi-stubgen(end) ] diff --git a/python/triton_tvm_ffi/driver.py b/python/triton_tvm_ffi/driver.py index 1619624..c34eaab 100644 --- a/python/triton_tvm_ffi/driver.py +++ b/python/triton_tvm_ffi/driver.py @@ -1,9 +1,15 @@ from __future__ import annotations -from typing import Any, Callable, Final, List, Sequence, Type, Union +from functools import cached_property +import os +from typing import Any, Callable, Final, List, Sequence, Type + +import jinja2 from triton.backends.nvidia.driver import CudaDriver from triton.runtime import _allocation -from . import TVMFFILauncherImpl, utils, string_to_type +import tvm_ffi + +from . import TVMFFILauncherImpl, utils, string_to_type, type_to_ctype class TVMLauncher(object): @@ -18,27 +24,53 @@ class TVMLauncher(object): self.profile_scratch_align: Final[int] = metadata.profile_scratch_align self.launch_cooperative_grid: Final[bool] = metadata.launch_cooperative_grid self.launch_pdl: Final[bool] = metadata.launch_pdl - self.impl: TVMFFILauncherImpl = TVMFFILauncherImpl( - [string_to_type(t) for t in self.signature], - self.launch_cooperative_grid, - self.launch_pdl, + self.enable_jit: Final[bool] = ( + os.getenv("TRITON_TVM_FFI_ENABLE_JIT", None) is not None ) - self.launch: Callable[ - [ - int, - int, - int, - int, - int, - tuple[int, int, int], - object, - object, - object, - object, - object, - Sequence[Union[Any]], - ] - ] = self.impl.launch + if self.enable_jit: + mod = tvm_ffi.cpp.load_inline( + "launch", + cpp_sources=self.codegen, + extra_ldflags=["-Wl,--no-as-needed", "-lcuda"], + extra_include_paths=[ + f"{tvm_ffi.cpp.extension._find_cuda_home()}/include" + ], + ) + launch = mod.get_function("launch") + self.launch = launch + else: + self.impl: TVMFFILauncherImpl = TVMFFILauncherImpl( + [string_to_type(t) for t in self.signature], + self.launch_cooperative_grid, + self.launch_pdl, + ) + self.launch = ( + lambda grid_x, + grid_y, + grid_z, + stream, + function, + kernel_metadata, + launch_metadata, + launch_enter_hook, + launch_exit_hook, + global_scratch, + profile_scratch, + *args: self.impl.launch( + grid_x, + grid_y, + grid_z, + stream, + function, + kernel_metadata, + launch_metadata, + launch_enter_hook, + launch_exit_hook, + global_scratch, + profile_scratch, + args, + ) + ) def __call__( self, @@ -72,22 +104,51 @@ class TVMLauncher(object): assert not self.launch_cooperative_grid assert not self.launch_pdl - # args: Sequence[TypedValue] = TypedValue.make_typed_values(self.signature, args) + if self.enable_jit: + (num_warps, num_ctas, shared_memory) = kernel_metadata + return self.launch( + gridX, + gridY, + gridZ, + stream, + function, + num_warps, + num_ctas, + shared_memory, + *args, + ) + else: + return self.launch( + gridX, + gridY, + gridZ, + stream, + function, + kernel_metadata, + launch_metadata, + launch_enter_hook, + launch_exit_hook, + global_scratch, + profile_scratch, + *args, + ) - return self.impl.launch( - gridX, - gridY, - gridZ, - stream, - function, - kernel_metadata, - launch_metadata, - launch_enter_hook, - launch_exit_hook, - global_scratch, - profile_scratch, - args, + @cached_property + def codegen(self) -> str: + env: Final[jinja2.Environment] = jinja2.Environment( + loader=jinja2.PackageLoader("triton_tvm_ffi", "templates"), + trim_blocks=True, + lstrip_blocks=True, ) + template = env.get_template("launch.c.j2") + signature = list( + filter( + lambda t: t != "void", + map(lambda t: type_to_ctype(string_to_type(t)), self.signature), + ) + ) + html = template.render(signature=signature) + return html class TVMFFIDriver(CudaDriver): diff --git a/python/triton_tvm_ffi/templates/launch.c.j2 b/python/triton_tvm_ffi/templates/launch.c.j2 new file mode 100644 index 0000000..94209ca --- /dev/null +++ b/python/triton_tvm_ffi/templates/launch.c.j2 @@ -0,0 +1,65 @@ +#include +#include +#include + +#ifdef __cplusplus +extern "C" +#endif +TVM_FFI_DLL_EXPORT void __tvm_ffi_launch(void *handle, const TVMFFIAny *args, int32_t num_args, TVMFFIAny *result) { + int32_t gridX = args[0].v_int64; + int32_t gridY = args[1].v_int64; + int32_t gridZ = args[2].v_int64; + CUstream stream = (CUstream)args[3].v_uint64; + CUfunction function = (CUfunction)args[4].v_uint64; + int32_t numWarps = args[5].v_int64; + int32_t numCtas = args[6].v_int64; + int32_t sharedMemory = args[7].v_int64; + // TODO: Implement the launch logic + CUdeviceptr globalScratch = 0; + // TODO: check `profileScratchObject` + CUdeviceptr profileScratch = 0; + if (gridX * gridY * gridZ > 0) { + CUlaunchAttribute launchAttr[4]; + CUlaunchConfig config; + config.gridDimX = gridX * numCtas; + config.gridDimY = gridY; + config.gridDimZ = gridZ; + static constexpr int32_t kThreadsPerWarp = 32; + config.blockDimX = kThreadsPerWarp * numWarps; + config.blockDimY = 1; + config.blockDimZ = 1; + config.sharedMemBytes = sharedMemory; + config.hStream = stream; + config.attrs = launchAttr; + int32_t numAttrs = 0; + // TODO: check `launchPdl` + // TODO: check `launchCooperativeGrid` + if (numCtas != 1) { + CUlaunchAttribute clusterAttr; + clusterAttr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION; + clusterAttr.value.clusterDim.x = numCtas; + clusterAttr.value.clusterDim.y = 1; + clusterAttr.value.clusterDim.z = 1; + launchAttr[numAttrs++] = clusterAttr; + CUlaunchAttribute clusterSchedulingAttr; + clusterSchedulingAttr.id = + CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE; + clusterSchedulingAttr.value.clusterSchedulingPolicyPreference = + CU_CLUSTER_SCHEDULING_POLICY_SPREAD; + launchAttr[numAttrs++] = clusterSchedulingAttr; + } + config.numAttrs = numAttrs; +{% for type in signature %} + {% if type == "void *" %} + {{ type }} arg{{ loop.index0 }} = ((DLTensor*)(args[{{ loop.index0 + 8 }}].v_c_str + sizeof(TVMFFIObject)))->data; + {% elif type == "int32_t" %} + {{ type }} arg{{ loop.index0 }} = args[{{ loop.index0 + 8 }}].v_int64; + {% else %} + assert(false, "unsupported type yet {{ type }}"); + {% endif %} +{% endfor %} + void *foo = NULL, *bar = NULL; + void *params[] = { {% for type in signature %} &arg{{ loop.index0 }}, {% endfor %}&foo, &bar }; + cuLaunchKernelEx(&config, function, params, NULL); + } +} diff --git a/src/type.cc b/src/type.cc index c5762ce..279588d 100644 --- a/src/type.cc +++ b/src/type.cc @@ -38,11 +38,24 @@ tvm::ffi::Optional StringToType(const tvm::ffi::String &name) { return std::nullopt; } +const char *TypeToCType(Type type) { + switch (type) { +#define CASE_ENUM(type, str, ctype) \ + case Type::type: \ + return #ctype; + TYPE_TABLE(CASE_ENUM) +#undef CASE_ENUM + default: + throw UnknownTypeException(type); + } +} + TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("triton_tvm_ffi.type_to_string", TypeToString) - .def("triton_tvm_ffi.string_to_type", StringToType); + .def("triton_tvm_ffi.string_to_type", StringToType) + .def("triton_tvm_ffi.type_to_ctype", TypeToCType); } } // namespace triton_tvm_ffi diff --git a/src/utils.cc b/src/utils.cc index 65aab79..4a51305 100644 --- a/src/utils.cc +++ b/src/utils.cc @@ -1,5 +1,4 @@ #include "macro.h" -#include "type.h" #include #include #include