#include "exception.h"

#include <alloca.h>
#include <cuda.h>

#include <cstdint>

#include <tvm/ffi/container/map.h>
#include <tvm/ffi/container/tensor.h>
#include <tvm/ffi/container/tuple.h>
#include <tvm/ffi/error.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ffi/string.h>

// Evaluates a CUDA driver API call exactly once and throws
// triton_tvm_ffi::CUDAException carrying its CUresult if it did not succeed.
// The result is captured in a local so the call is never re-executed just to
// build the exception (re-running e.g. cuModuleLoadData would repeat its side
// effects and could report a different error code, or even CUDA_SUCCESS).
#define CUDA_CHECK(call)                                                       \
  do {                                                                         \
    const CUresult cudaCheckResult_ = (call);                                  \
    if (cudaCheckResult_ != CUDA_SUCCESS) {                                    \
      throw triton_tvm_ffi::CUDAException(cudaCheckResult_);                   \
    }                                                                          \
  } while (false)

namespace {

// Reads one integer attribute of `device` via cuDeviceGetAttribute.
// Throws triton_tvm_ffi::CUDAException if the query fails.
int QueryDeviceAttribute(CUdevice_attribute attribute, CUdevice device) {
  int value = 0;
  CUDA_CHECK(cuDeviceGetAttribute(&value, attribute, device));
  return value;
}

}  // namespace

// Returns a String -> int map describing CUDA device `device_id`.
// Keys and the driver attribute each one is read from:
//   "max_shared_mem"       CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN
//   "max_num_regs"         CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_BLOCK
//   "multiprocessor_count" CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT
//   "warpSize"             CU_DEVICE_ATTRIBUTE_WARP_SIZE
//   "sm_clock_rate"        CU_DEVICE_ATTRIBUTE_CLOCK_RATE
//   "mem_clock_rate"       CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE
//   "mem_bus_width"        CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH
// Throws triton_tvm_ffi::CUDAException if any driver call fails.
tvm::ffi::Map<tvm::ffi::String, int> GetDeviceProperties(int device_id) {
  // cuDeviceGet writes a driver-API CUdevice, the same handle type LoadBinary
  // receives.
  CUdevice device;
  CUDA_CHECK(cuDeviceGet(&device, device_id));

  const int maxSharedMem = QueryDeviceAttribute(
      CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, device);
  const int maxNumRegs = QueryDeviceAttribute(
      CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_BLOCK, device);
  const int multiprocessorCount = QueryDeviceAttribute(
      CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, device);
  const int warpSize =
      QueryDeviceAttribute(CU_DEVICE_ATTRIBUTE_WARP_SIZE, device);
  const int smClockRate =
      QueryDeviceAttribute(CU_DEVICE_ATTRIBUTE_CLOCK_RATE, device);
  const int memClockRate =
      QueryDeviceAttribute(CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE, device);
  const int memBusWidth = QueryDeviceAttribute(
      CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH, device);

  return {{"max_shared_mem", maxSharedMem},
          {"max_num_regs", maxNumRegs},
          {"multiprocessor_count", multiprocessorCount},
          {"warpSize", warpSize},
          {"sm_clock_rate", smClockRate},
          {"mem_clock_rate", memClockRate},
          {"mem_bus_width", memBusWidth}};
}

// Loads a CUDA module image and looks up the kernel `name` inside it.
//
// If no context is current on the calling thread, the primary context of
// `device` is retained and made current first (it is intentionally left
// retained and current afterwards).
//
// Parameters:
//   name   - kernel symbol to look up with cuModuleGetFunction.
//   data   - module image passed verbatim to cuModuleLoadData.
//   shared - shared memory (bytes) the kernel requests. When it exceeds the
//            49152-byte no-opt-in limit and the device's opt-in maximum is
//            also larger, the kernel is configured to prefer shared memory
//            and its max dynamic shared size is raised to
//            (opt-in maximum - static shared size).
//   device - driver-API device handle.
// Returns (module, function, numRegs, localSizeBytes, maxThreadsPerBlock),
// where module/function are the raw CUmodule/CUfunction handles.
// NOTE(review): the 4th element is CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES in
// bytes; upstream Triton's loader divides it by 4 to report spilled 32-bit
// registers - confirm which unit the caller expects.
// Throws triton_tvm_ffi::CUDAException on any driver error. If the failure
// happens after the module was loaded, the module is unloaded first so it
// does not leak.
tvm::ffi::Tuple<void *, void *, int32_t, int32_t, int32_t>
LoadBinary(const tvm::ffi::String &name, const tvm::ffi::Bytes &data,
           int32_t shared, CUdevice device) {
  // Largest dynamic shared memory a kernel may use without opting in.
  static constexpr int32_t kDefaultMaxDynamicSharedMemory = 49152;

  CUcontext pctx = nullptr;
  CUDA_CHECK(cuCtxGetCurrent(&pctx));
  if (!pctx) {
    CUDA_CHECK(cuDevicePrimaryCtxRetain(&pctx, device));
    CUDA_CHECK(cuCtxSetCurrent(pctx));
  }

  CUmodule mod = nullptr;
  CUDA_CHECK(cuModuleLoadData(&mod, data.data()));

  CUfunction fun = nullptr;
  int32_t nRegs = 0;
  int32_t nSpills = 0;
  int32_t nMaxThreads = 0;
  try {
    CUDA_CHECK(cuModuleGetFunction(&fun, mod, name.data()));
    CUDA_CHECK(cuFuncGetAttribute(&nRegs, CU_FUNC_ATTRIBUTE_NUM_REGS, fun));
    CUDA_CHECK(
        cuFuncGetAttribute(&nSpills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun));
    CUDA_CHECK(cuFuncGetAttribute(
        &nMaxThreads, CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK, fun));

    const int sharedOptin = QueryDeviceAttribute(
        CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, device);
    if (shared > kDefaultMaxDynamicSharedMemory &&
        sharedOptin > kDefaultMaxDynamicSharedMemory) {
      CUDA_CHECK(cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED));
      int32_t sharedStatic = 0;
      CUDA_CHECK(cuFuncGetAttribute(&sharedStatic,
                                    CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun));
      // Raise the dynamic limit to whatever the opt-in maximum leaves after
      // the kernel's statically declared shared memory.
      CUDA_CHECK(cuFuncSetAttribute(
          fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
          sharedOptin - sharedStatic));
    }
  } catch (...) {
    // Best-effort cleanup; the original error is the one worth reporting.
    static_cast<void>(cuModuleUnload(mod));
    throw;
  }

  return tvm::ffi::Tuple<void *, void *, int32_t, int32_t, int32_t>{
      static_cast<void *>(mod), static_cast<void *>(fun), nRegs, nSpills,
      nMaxThreads};
}

// Packed entry point that launches a kernel with cuLaunchKernelEx.
//
// Positional arguments:
//   0..2  gridX, gridY, gridZ (int32)
//   3     stream, reinterpreted as CUstream
//   4     function, reinterpreted as CUfunction
//   5     (numWarps, numCtas, sharedMemory) tuple of int32
//   6..8  launch metadata, enter hook, exit hook (not used yet, see TODOs)
//   9     launchCooperativeGrid (bool): adds CU_LAUNCH_ATTRIBUTE_COOPERATIVE
//   10    launchPdl (bool): adds
//         CU_LAUNCH_ATTRIBUTE_PROGRAMMATIC_STREAM_SERIALIZATION
//   11,12 global / profile scratch objects (not used yet; null device
//         pointers are passed in their place)
//   13..  kernel arguments: Tensor (its data pointer is passed), int32, float
// Two trailing parameters (global scratch, profile scratch) are appended after
// the kernel arguments. Nothing is launched when the grid has no programs.
// The block is 32 * numWarps threads in x; the grid's x extent is gridX *
// numCtas. When numCtas != 1 a cluster of numCtas CTAs along x is requested.
// Throws TypeError for an unsupported kernel argument type and
// triton_tvm_ffi::CUDAException on driver errors.
void Launch(tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
  const int32_t gridX = args[0].cast<int32_t>();
  const int32_t gridY = args[1].cast<int32_t>();
  const int32_t gridZ = args[2].cast<int32_t>();
  // NOTE(review): stream and function are read as opaque pointers - confirm
  // this matches how the Python caller passes them.
  CUstream stream = reinterpret_cast<CUstream>(args[3].cast<void *>());
  CUfunction function = reinterpret_cast<CUfunction>(args[4].cast<void *>());
  tvm::ffi::Tuple<int32_t, int32_t, int32_t> kernelMetadata =
      args[5].cast<tvm::ffi::Tuple<int32_t, int32_t, int32_t>>();
  const int32_t numWarps = kernelMetadata.get<0>();
  const int32_t numCtas = kernelMetadata.get<1>();
  const int32_t sharedMemory = kernelMetadata.get<2>();
  tvm::ffi::ObjectRef launchMetadata = args[6].cast<tvm::ffi::ObjectRef>();
  tvm::ffi::ObjectRef launchEnterHook = args[7].cast<tvm::ffi::ObjectRef>();
  tvm::ffi::ObjectRef launchExitHook = args[8].cast<tvm::ffi::ObjectRef>();
  const bool launchCooperativeGrid = args[9].cast<bool>();
  const bool launchPdl = args[10].cast<bool>();
  tvm::ffi::ObjectRef globalScratchObject =
      args[11].cast<tvm::ffi::ObjectRef>();
  tvm::ffi::ObjectRef profileScratchObject =
      args[12].cast<tvm::ffi::ObjectRef>();
  tvm::ffi::PackedArgs kernelArgs = args.Slice(13);

  // TODO: call `launchEnterHook`
  // TODO: check `globalScratchObject`
  CUdeviceptr globalScratch = 0;
  // TODO: check `profileScratchObject`
  CUdeviceptr profileScratch = 0;

  // Widen before multiplying so a large grid cannot overflow int32_t.
  const int64_t numPrograms = static_cast<int64_t>(gridX) * gridY * gridZ;
  if (numPrograms > 0) {
    // At most: PDL, cooperative, cluster dimension, cluster scheduling.
    static constexpr int kMaxLaunchAttrs = 4;
    CUlaunchAttribute launchAttr[kMaxLaunchAttrs] = {};
    int numAttrs = 0;

    if (launchPdl) {
      CUlaunchAttribute &pdlAttr = launchAttr[numAttrs++];
      pdlAttr.id = CU_LAUNCH_ATTRIBUTE_PROGRAMMATIC_STREAM_SERIALIZATION;
      pdlAttr.value.programmaticStreamSerializationAllowed = 1;
    }

    if (launchCooperativeGrid) {
      CUlaunchAttribute &coopAttr = launchAttr[numAttrs++];
      coopAttr.id = CU_LAUNCH_ATTRIBUTE_COOPERATIVE;
      coopAttr.value.cooperative = 1;
    }

    if (numCtas != 1) {
      CUlaunchAttribute &clusterAttr = launchAttr[numAttrs++];
      clusterAttr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
      clusterAttr.value.clusterDim.x = numCtas;
      clusterAttr.value.clusterDim.y = 1;
      clusterAttr.value.clusterDim.z = 1;

      CUlaunchAttribute &clusterSchedulingAttr = launchAttr[numAttrs++];
      clusterSchedulingAttr.id =
          CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE;
      clusterSchedulingAttr.value.clusterSchedulingPolicyPreference =
          CU_CLUSTER_SCHEDULING_POLICY_SPREAD;
    }

    // A cluster of 16 CTAs exceeds the portable cluster size, so the kernel
    // must explicitly allow non-portable sizes.
    if (numCtas == 16) {
      CUDA_CHECK(cuFuncSetAttribute(
          function, CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, 1));
    }

    static constexpr int32_t kThreadsPerWarp = 32;
    CUlaunchConfig config = {};
    config.gridDimX = gridX * numCtas;
    config.gridDimY = gridY;
    config.gridDimZ = gridZ;
    config.blockDimX = kThreadsPerWarp * numWarps;
    config.blockDimY = 1;
    config.blockDimZ = 1;
    config.sharedMemBytes = sharedMemory;
    config.hStream = stream;
    config.attrs = launchAttr;
    config.numAttrs = numAttrs;

    // Backing storage for scalar / pointer kernel arguments: cuLaunchKernelEx
    // takes an array of pointers to each argument's value. Stack storage
    // lives until this function returns, i.e. past the launch call.
    union KernelArgSlot {
      void *ptr;
      int32_t i32;
      float f32;
    };
    const int32_t numKernelArgs = kernelArgs.size();
    auto *slots = static_cast<KernelArgSlot *>(
        alloca(sizeof(KernelArgSlot) * numKernelArgs));
    void **params =
        static_cast<void **>(alloca(sizeof(void *) * (numKernelArgs + 2)));

    for (int32_t i = 0; i < numKernelArgs; ++i) {
      tvm::ffi::AnyView arg = kernelArgs[i];
      if (auto tensor = arg.try_cast<tvm::ffi::Tensor>()) {
        slots[i].ptr = tensor->data_ptr();
        params[i] = &slots[i].ptr;
      } else if (auto i32 = arg.try_cast<int32_t>()) {
        slots[i].i32 = *i32;
        params[i] = &slots[i].i32;
      } else if (auto f32 = arg.try_cast<float>()) {
        slots[i].f32 = *f32;
        params[i] = &slots[i].f32;
      } else {
        // Never launch with an uninitialized parameter slot: this must fail
        // in release builds too, so it cannot be an assert.
        TVM_FFI_THROW(TypeError)
            << "triton_tvm_ffi.utils.launch: unsupported type for kernel "
               "argument #"
            << i << " (expected Tensor, int32 or float)";
      }
    }
    params[numKernelArgs] = &globalScratch;
    params[numKernelArgs + 1] = &profileScratch;

    CUDA_CHECK(cuLaunchKernelEx(&config, function, params, nullptr));
  }

  // TODO: call `launchExitHook`
}

TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef()
      .def("triton_tvm_ffi.utils.get_device_properties", GetDeviceProperties)
      .def_packed("triton_tvm_ffi.utils.launch", Launch)
      .def("triton_tvm_ffi.utils.load_binary", LoadBinary);
}