mirror of
https://github.com/sgjzfzzf/triton-tvm-ffi.git
synced 2026-07-01 08:51:56 +08:00
@@ -37,7 +37,7 @@ class TVMFFILauncherImpl(_ffi_Object):
|
||||
if TYPE_CHECKING:
|
||||
@staticmethod
|
||||
def __c_ffi_init__(_0: Sequence[int], _1: bool, _2: bool, /) -> Object: ...
|
||||
def launch(self, _1: int, _2: int, _3: int, _4: int, _5: int, _6: tuple[int, int, int], _7: Object, _8: Object, _9: Object, _10: Object, _11: Object, _12: Sequence[Any], /) -> None: ...
|
||||
def launch(self, _1: int, _2: int, _3: int, _4: int, _5: int, _6: int, _7: int, _8: int, _9: int, _10: int, _11: Sequence[Any], /) -> None: ...
|
||||
# fmt: on
|
||||
# tvm-ffi-stubgen(end)
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from functools import cached_property
|
||||
import os
|
||||
from typing import Any, Callable, Final, List, Sequence, Type
|
||||
from typing import Any, Final, List, Type
|
||||
|
||||
import jinja2
|
||||
from triton.backends.nvidia.driver import CudaDriver
|
||||
@@ -24,19 +24,16 @@ class TVMLauncher(object):
|
||||
self.profile_scratch_align: Final[int] = metadata.profile_scratch_align
|
||||
self.launch_cooperative_grid: Final[bool] = metadata.launch_cooperative_grid
|
||||
self.launch_pdl: Final[bool] = metadata.launch_pdl
|
||||
self.enable_jit: Final[bool] = (
|
||||
os.getenv("TRITON_TVM_FFI_ENABLE_JIT", None) is not None
|
||||
)
|
||||
if self.enable_jit:
|
||||
mod = tvm_ffi.cpp.load_inline(
|
||||
if os.getenv("TRITON_TVM_FFI_ENABLE_JIT", "off").lower() in {"1", "true", "on"}:
|
||||
mod: tvm_ffi.Module = tvm_ffi.cpp.load_inline(
|
||||
"launch",
|
||||
cpp_sources=self.codegen,
|
||||
cpp_sources=[self.codegen],
|
||||
extra_ldflags=["-Wl,--no-as-needed", "-lcuda"],
|
||||
extra_include_paths=[
|
||||
f"{tvm_ffi.cpp.extension._find_cuda_home()}/include"
|
||||
],
|
||||
)
|
||||
launch = mod.get_function("launch")
|
||||
launch: tvm_ffi.Function = mod.get_function("launch")
|
||||
self.launch = launch
|
||||
else:
|
||||
self.impl: TVMFFILauncherImpl = TVMFFILauncherImpl(
|
||||
@@ -50,10 +47,9 @@ class TVMLauncher(object):
|
||||
grid_z,
|
||||
stream,
|
||||
function,
|
||||
kernel_metadata,
|
||||
launch_metadata,
|
||||
launch_enter_hook,
|
||||
launch_exit_hook,
|
||||
num_warps,
|
||||
num_ctas,
|
||||
shared_memory,
|
||||
global_scratch,
|
||||
profile_scratch,
|
||||
*args: self.impl.launch(
|
||||
@@ -62,10 +58,9 @@ class TVMLauncher(object):
|
||||
grid_z,
|
||||
stream,
|
||||
function,
|
||||
kernel_metadata,
|
||||
launch_metadata,
|
||||
launch_enter_hook,
|
||||
launch_exit_hook,
|
||||
num_warps,
|
||||
num_ctas,
|
||||
shared_memory,
|
||||
global_scratch,
|
||||
profile_scratch,
|
||||
args,
|
||||
@@ -101,37 +96,36 @@ class TVMLauncher(object):
|
||||
self.profile_scratch_align,
|
||||
_allocation._profile_allocator,
|
||||
)
|
||||
assert not self.launch_cooperative_grid
|
||||
assert not self.launch_pdl
|
||||
|
||||
if self.enable_jit:
|
||||
(num_warps, num_ctas, shared_memory) = kernel_metadata
|
||||
return self.launch(
|
||||
gridX,
|
||||
gridY,
|
||||
gridZ,
|
||||
stream,
|
||||
function,
|
||||
num_warps,
|
||||
num_ctas,
|
||||
shared_memory,
|
||||
*args,
|
||||
)
|
||||
else:
|
||||
return self.launch(
|
||||
gridX,
|
||||
gridY,
|
||||
gridZ,
|
||||
stream,
|
||||
function,
|
||||
kernel_metadata,
|
||||
launch_metadata,
|
||||
launch_enter_hook,
|
||||
launch_exit_hook,
|
||||
global_scratch,
|
||||
profile_scratch,
|
||||
*args,
|
||||
)
|
||||
def canonicalize(obj: Any) -> int:
|
||||
if obj is None:
|
||||
return 0
|
||||
elif isinstance(obj, int):
|
||||
return obj
|
||||
elif get_ptr := getattr(obj, "data_ptr", None):
|
||||
return get_ptr()
|
||||
else:
|
||||
raise TypeError(f"cannot canonicalize object of type {type(obj)}")
|
||||
|
||||
(num_warps, num_ctas, shared_memory) = kernel_metadata
|
||||
if launch_enter_hook:
|
||||
launch_enter_hook(launch_metadata)
|
||||
ret = self.launch(
|
||||
gridX,
|
||||
gridY,
|
||||
gridZ,
|
||||
stream,
|
||||
function,
|
||||
num_warps,
|
||||
num_ctas,
|
||||
shared_memory,
|
||||
canonicalize(global_scratch),
|
||||
canonicalize(profile_scratch),
|
||||
*args,
|
||||
)
|
||||
if launch_exit_hook:
|
||||
launch_exit_hook(launch_metadata)
|
||||
return ret
|
||||
|
||||
@cached_property
|
||||
def codegen(self) -> str:
|
||||
|
||||
@@ -14,10 +14,8 @@ TVM_FFI_DLL_EXPORT void __tvm_ffi_launch(void *handle, const TVMFFIAny *args, in
|
||||
int32_t numWarps = args[5].v_int64;
|
||||
int32_t numCtas = args[6].v_int64;
|
||||
int32_t sharedMemory = args[7].v_int64;
|
||||
// TODO: Implement the launch logic
|
||||
CUdeviceptr globalScratch = 0;
|
||||
// TODO: check `profileScratchObject`
|
||||
CUdeviceptr profileScratch = 0;
|
||||
uint64_t globalScratch = args[8].v_uint64;
|
||||
uint64_t profileScratch = args[9].v_uint64;
|
||||
if (gridX * gridY * gridZ > 0) {
|
||||
CUlaunchAttribute launchAttr[4];
|
||||
CUlaunchConfig config;
|
||||
@@ -51,9 +49,9 @@ TVM_FFI_DLL_EXPORT void __tvm_ffi_launch(void *handle, const TVMFFIAny *args, in
|
||||
config.numAttrs = numAttrs;
|
||||
{% for type in signature %}
|
||||
{% if type == "void *" %}
|
||||
{{ type }} arg{{ loop.index0 }} = ((DLTensor*)(args[{{ loop.index0 + 8 }}].v_c_str + sizeof(TVMFFIObject)))->data;
|
||||
{{ type }} arg{{ loop.index0 }} = ((DLTensor*)(args[{{ loop.index0 + 10 }}].v_c_str + sizeof(TVMFFIObject)))->data;
|
||||
{% elif type == "int32_t" %}
|
||||
{{ type }} arg{{ loop.index0 }} = args[{{ loop.index0 + 8 }}].v_int64;
|
||||
{{ type }} arg{{ loop.index0 }} = args[{{ loop.index0 + 10 }}].v_int64;
|
||||
{% else %}
|
||||
assert(false, "unsupported type yet {{ type }}");
|
||||
{% endif %}
|
||||
|
||||
Reference in New Issue
Block a user