diff --git a/flaggems/fused_moe.md b/flaggems/fused_moe.md new file mode 100644 index 0000000..4c5eade --- /dev/null +++ b/flaggems/fused_moe.md @@ -0,0 +1,695 @@ +# fused_moe.py 分析 + +本文整理 `src/flag_gems/fused/fused_moe.py` 的实现结构、调用路径、数据类型分支和计算原理,方便从源码角度理解 FlagGems 的 fused MoE 前向。 + +## 第一章:MoE 与 fused MoE 的数学原理与计算原理 + +### 1.0 MoE 在大模型中的背景 + +MoE,Mixture of Experts,是大模型中一种常见的稀疏前馈结构。它的核心思想是: + +- 不再让每个 token 都经过同一组固定的前馈层 +- 而是准备多个 expert 子网络 +- 再由 router 为每个 token 动态选择少量 expert 参与计算 + +在 Transformer 大模型中,MoE 通常替代原来的 dense FFN 层。也就是说,在注意力层之后,原本是: + +```text +Attention -> Dense FFN +``` + +引入 MoE 后会变成: + +```text +Attention -> Router + Experts +``` + +这样做的目的,是在尽量不线性增加单 token 计算量的前提下,大幅增加模型总参数量和表示能力。 + +可以把它理解为: + +- dense FFN 是“所有 token 共用同一套前馈参数” +- MoE 是“不同 token 只激活一小部分 expert 参数” + +因此,MoE 在大模型中的价值主要体现在: + +- 提升模型容量 +- 控制单 token 计算成本 +- 让不同 token 能使用更有针对性的 expert 子网络 + +从运行时角度看,MoE 算子负责的就是这一层稀疏前馈计算。它的输入通常是 attention 之后的 hidden states,输出则是经过 router 选择 expert、完成两层 expert MLP 并聚合后的新 hidden states。 + +更具体地说,这个算子主要做三件事: + +1. 根据 router 给出的 `topk_ids` 和 `topk_weights`,确定每个 token 应该送到哪些 expert +2. 对被选中的 expert 执行前馈计算 +3. 把多个 expert 的输出按路由权重聚合回 token 维度 + +因此,从模型结构上看,MoE 算子本质上是在实现: + +```text +token -> route to experts -> expert MLP -> weighted combine +``` + +而 `fused_moe.py` 做的不是重新定义这个算子,而是为这一层 MoE 前向提供一个更高性能的实现,重点优化: + +- token 到 expert 的重排 +- expert GEMM 的执行方式 +- 量化路径 +- 最终聚合过程 + +### 1.1 MoE 的数学原理 + +![MoE Principle](./fused_moe_image/moe.png) + +结合上图,标准 MoE 层的计算过程可以形式化为“路由选择 -> 分发到专家 -> 专家前馈计算 -> 加权聚合输出”四个阶段。 + +设第 $t$ 个 token 的隐藏状态表示为: + +```math +x_t \in \mathbb{R}^{d} +``` + +其中 $d$ 为隐藏维度。MoE 层包含 $E$ 个 expert,记第 $e$ 个 expert 对应的前馈函数为 $f_e$。 + +第一步是路由计算。Router 对输入 $x_t$ 生成对所有 expert 的路由打分。最常见的 router 形式是一个线性映射: + +```math +s_t = W_g^{\top} x_t + b_g +``` + +其中 $W_g \in \mathbb{R}^{d \times E}$、$b_g \in \mathbb{R}^{E}$,因此 router 输出的打分向量满足: + +```math +s_t = \operatorname{Router}(x_t) \in \mathbb{R}^{E} +``` + +通常可进一步写成 softmax 后的路由概率: + +```math +p_t = \operatorname{softmax}(s_t), \qquad p_t \in \mathbb{R}^{E} +``` + +第二步是 Top-k 选择。Router 选择概率最高的 $k$ 个 expert,记其索引集合为 $S_t$: + +```math +\mathcal{S}_t = \operatorname{TopK}(p_t, k) +``` + +对每个被选中的 expert $e \in S_t$,记其对应的路由权重为 $\alpha_{t,e}$。这一步对应图中的 `Top-k Selection`。 + +第三步是 dispatch。根据集合 $S_t$,token $x_t$ 会被逻辑上分发到这些被选中的 expert;未被选中的 expert 不参与当前 token 的计算。 + +第四步是 expert 计算。若每个 expert 采用图中所示的两层 MLP 结构 `Linear1 -> Activation(SiLU) -> Linear2`,则第 $e$ 个 expert 的输出可以写为: + +```math +f_e(x) = W_{2,e}^{\top}\,\mathrm{SiLU}(W_{1,e}^{\top}x + b_{1,e}) + b_{2,e} +``` + +其中: + +- $W_{1,e}$、$W_{2,e}$ 分别为第 $e$ 个 expert 的两层线性变换参数 +- $b_{1,e}$、$b_{2,e}$ 为对应偏置 +- $\mathrm{SiLU}(\cdot)$ 为图中 expert 内部使用的激活函数 + +最后一步是聚合输出。对所有被选中的 expert 输出按对应路由权重进行加权求和,得到当前 token 的最终输出: + +```math +y_t = \sum_{e \in \mathcal{S}_t} \alpha_{t,e}\, f_e(x_t) +``` + +因此,上图中的完整计算链条可以概括为: + +```math +x_t +\xrightarrow{\text{Router}} +s_t +\xrightarrow{\text{Softmax}} +p_t +\xrightarrow{\text{TopK}} +\mathcal{S}_t +\xrightarrow{\text{Dispatch}} +\{f_e(x_t)\}_{e \in \mathcal{S}_t} +\xrightarrow{\text{Weighted Sum}} +y_t +``` + +也就是说,MoE 的本质是:先由 router 从全部 expert 中选择一小部分参与当前 token 的计算,再将这些被选中 expert 的输出按路由权重加权求和,从而得到最终输出 $y_t$。 + +### 1.2 fused_moe 的数学原理 + +把每个 `(token, topk-slot)` 看作一个 routed item $r$,则: + +- 对应 token 为 $t(r)$ +- 对应 expert 为 $e(r)$ +- 对应路由权重为 $a(r)$ + +那么 `fused_moe` 实际仍然是在算同一个数学对象。 + +第一层: + +```math +h_r = W_{1,e(r)}^{\top}x_{t(r)} + b_{1,e(r)} +``` + +激活后: + +```math +z_r = \phi(h_r) +``` + +第二层: + +```math +o_r = W_{2,e(r)}^{\top}z_r + b_{2,e(r)} +``` + +最后对同一个 token 的 routed 输出求和: + +```math +y_t = \sum_{r: t(r)=t} a(r)\,o_r +``` + +代码中有一个开关 `apply_router_weight_on_input`,决定把 router weight 乘在第一层输出之后还是第二层输出之后。由于 `a(r)` 是标量,线性层满足: + +```math +W^{\top}(a z) = a\,W^{\top}z +``` + +因此两种写法数学上等价。 + +### 1.3 fused_moe 的计算原理 + +`fused_moe` 和普通 MoE 在数学上等价,主要差别在工程组织方式。 + +普通 MoE 更像是按逻辑步骤分开执行: + +1. router 选 top-k experts +2. dispatch / gather token +3. 对每个 expert 分别做 `w1` +4. 激活 +5. 对每个 expert 分别做 `w2` +6. 乘 router weight 并聚合 + +而 `fused_moe` 的核心思想是: + +1. 先把 routed token 按 expert 重排 +2. 把每个 expert 的 token 数量按 `BLOCK_SIZE_M` 补齐 +3. 让一个 Triton program 处理“同一 expert 的一块 token × 一块输出列” +4. 用统一 kernel 两次完成 `w1` 和 `w2` +5. 在 kernel 内尽量融合量化、bias、router weight 等附加逻辑 + +从 routed token 视角看,kernel 做的是: + +```math +C[r, n] = \sum_{k=1}^{K} A[t(r), k] \cdot B[e(r), n, k] +``` + +其中: + +- $t(r)$ 是 routed item $r$ 对应的原始 token +- $e(r)$ 是 routed item $r$ 对应的 expert +- $n$ 是输出通道 + +在实现上,这套组织方式依赖三个关键步骤: + +- `moe_align_block_size(...)` + 把 routed token 按 expert 排列成规则 block +- `fused_moe_kernel(...)` + 按 block 执行 expert GEMM +- `moe_sum(...)` + 对每个 token 的 top-k 输出做最终求和 + +### 1.4 和普通 MoE 实现的区别 + +从数学上看,`fused_moe` 和普通 MoE 没有本质差别,仍然是在做: + +```math +y_t = \sum_{i=1}^{k} \alpha_{t,i}\, f_{e_{t,i}}(x_t) +``` + +差别主要在工程组织方式: + +- 普通 MoE 更像是按逻辑步骤分开执行 + - dispatch + - expert matmul + - activation + - expert matmul + - combine +- fused MoE 则更强调: + - 按 expert 重排 token + - 让一个 Triton program 处理规则 tile + - 把量化、bias、router weight 等附加逻辑尽量揉进 kernel + +这样做的收益是: + +- 减少 kernel launch +- 改善权重访问局部性 +- 提高大规模 expert 场景下的吞吐 + +## 文件定位 + +`fused_moe.py` 是一个面向 Triton 的 MoE 前向实现文件,核心职责包括: + +- 选择或生成 MoE kernel 配置 +- 对输入激活做量化预处理 +- 启动 Triton MoE kernel +- 执行两层 expert MLP +- 在 top-k expert 输出之间做聚合 + +它依赖两个关键辅助模块: + +- `flag_gems.fused.moe_align_block_size` + 负责把 routed token 按 expert 重排并按 block 对齐 +- `flag_gems.fused.moe_sum` + 负责把每个 token 的 top-k expert 输出求和 + +## 对外入口 + +对外可直接调用的入口有两个: + +- `inplace_fused_experts(...)` +- `outplace_fused_experts(...)` + +两者都只是薄封装,最终都调用: + +- `fused_experts_impl(...)` + +其中: + +- `inplace_fused_experts` 把结果直接写回 `hidden_states` +- `outplace_fused_experts` 返回新分配的输出张量 + +## 文件内主要函数分组 + +### 1. 配置选择 + +- `get_embedded_moe_configs()` + 从 `fused_moe_config.yaml` 读取内嵌调优配置 +- `_get_device_name()` + 获取并规范化设备名 +- `get_moe_configs(...)` + 按设备名、expert 数、输出维度、dtype、block shape 查询配置 +- `try_get_optimal_moe_config(...)` + 优先使用内嵌配置,否则回退到启发式配置 +- `get_default_config(...)` + 默认配置生成逻辑 +- `get_moe_wna16_block_config(...)` + 针对 WNA16 路径生成 block 配置 +- `_ensure_block_size_k_divisible(...)` + 修正 `BLOCK_SIZE_K` + +### 2. 数据类型与量化辅助 + +- `_get_config_dtype_str(...)` + 把当前 dtype 和量化模式映射成配置查表 key +- `_get_config_quant_dtype(...)` + 把量化模式映射成激活量化 dtype +- `_fp8_quantize(...)` + 激活 FP8 量化 +- `_int8_quantize(...)` + 激活 INT8 量化 +- `moe_kernel_quantize_input(...)` + 在进入 GEMM 前对输入激活做量化 +- `dequant_mxfp4(...)` +- `dequant_mxfp6(...)` + +### 3. 激活函数 + +- `MoEActivation` + 定义支持的激活枚举 +- `apply_moe_activation(...)` + 执行激活 +- `_silu_and_mul_kernel(...)` + Triton 版 SiLU-and-mul + +### 4. Triton kernel + +- `write_zeros_to_output(...)` + expert 不在当前 rank 时写零 +- `fused_moe_kernel(...)` + 通用 MoE kernel,覆盖普通浮点、FP8、INT8 路径 +- `fused_moe_kernel_gptq_awq(...)` + WNA16 专用 kernel,面向 GPTQ/AWQ 风格的量化权重 + +### 5. kernel 启动与分发 + +- `invoke_fused_moe_triton_kernel(...)` + 启动通用 kernel +- `invoke_fused_moe_wna16_triton_kernel(...)` + 启动 WNA16 kernel +- `dispatch_fused_moe_kernel(...)` + 根据量化模式分发到具体启动函数 + +## 主调用路径 + +主调用链可以概括为: + +```text +inplace_fused_experts / outplace_fused_experts + -> fused_experts_impl + -> _get_config_dtype_str + -> _get_config_quant_dtype + -> try_get_optimal_moe_config + -> chunk loop + -> moe_kernel_quantize_input (for w1 input) + -> moe_align_block_size 或 naive assignment + -> dispatch_fused_moe_kernel (for w1) + -> invoke_fused_moe_triton_kernel + -> fused_moe_kernel + -> apply_moe_activation + -> moe_kernel_quantize_input (for w2 input) + -> dispatch_fused_moe_kernel (for w2) + -> invoke_fused_moe_triton_kernel + -> fused_moe_kernel + -> moe_sum +``` + +如果走 WNA16 特化路径,则中间会改为: + +```text +dispatch_fused_moe_kernel + -> invoke_fused_moe_wna16_triton_kernel + -> fused_moe_kernel_gptq_awq +``` + +## fused_experts_impl 的执行流程 + +`fused_experts_impl(...)` 是整个文件的调度核心。 + +### 1. 参数校验 + +它会检查: + +- `hidden_states`、`w1`、`w2` 的 shape 是否匹配 +- `topk_weights` 和 `topk_ids` 形状是否一致 +- `hidden_states` 是否连续 +- `w1/w2` 最后一维 stride 是否为 1 +- 输入 dtype 是否属于: + - `torch.float32` + - `torch.float16` + - `torch.bfloat16` + +当前实现还显式限制: + +- `activation == "silu"` + +虽然 `MoEActivation` 定义了多种激活,但当前入口只允许 SiLU 路径。 + +### 2. 配置选择 + +它先生成两个关键变量: + +- `config_dtype` + 用于 kernel 配置查表 +- `quant_dtype` + 用于激活量化逻辑 + +之后用 `try_get_optimal_moe_config(...)` 选择当前 chunk 对应的 Triton 配置。 + +### 3. 中间缓存分配 + +会分配三块中间缓存: + +- `intermediate_cache1` + 保存 `w1` 输出,shape 为 `[M, top_k, N]` +- `intermediate_cache2` + 保存激活后的结果,shape 为 `[M * top_k, activation_out_dim]` +- `intermediate_cache3` + 保存 `w2` 输出,shape 为 `[M, top_k, K]` + +其中 `cache13` 复用了 `cache1` 和 `cache3` 的存储,因为两者生命周期不重叠。 + +### 4. 计算 dtype 选择 + +输入 dtype 与 kernel 中的 `compute_type` 对应关系: + +- `torch.bfloat16 -> tl.bfloat16` +- `torch.float16 -> tl.float16` +- `torch.float32 -> tl.float32` + +### 5. 特殊量化格式处理 + +如果启用了 `ocp_mx_scheme`,会先把 MX 权重反量化成普通浮点权重,再走后续通用路径。 + +对于 `use_int8_w8a16` 和 `use_int4_w4a16`,当前公开入口中也会先把权重反量化成 `hidden_states.dtype`,随后把: + +- `use_int8_w8a16 = False` +- `use_int4_w4a16 = False` + +这意味着虽然文件内保留了 WNA16 特化 kernel,但从 `fused_experts_impl(...)` 这条入口实际运行时,常见路径仍然是通用 `fused_moe_kernel(...)`。 + +### 6. 分 chunk 执行 + +实现按 `CHUNK_SIZE = 16 * 1024` 分块处理 token,原因是: + +- 限制中间缓存大小 +- 让 kernel 配置更容易适配当前 chunk 规模 + +每个 chunk 内部执行: + +1. 量化第一层输入 +2. 把 routed token 按 expert 重排 +3. 跑 `w1` +4. 激活 +5. 量化第二层输入 +6. 跑 `w2` +7. 对 top-k expert 输出做求和 + +## moe_align_block_size 的作用 + +MoE 的普通数学形式里,每个 token 会被路由到若干 expert。直接按 token 顺序计算时,expert 权重访问会非常离散。 + +`moe_align_block_size(...)` 的作用是: + +- 把所有 `(token, topk-slot)` 展平为 routed token +- 按 expert 把 routed token 排序 +- 对每个 expert 的 token 数量按 `BLOCK_SIZE_M` 向上补齐 + +它输出三个核心量: + +- `sorted_token_ids` + 重排后的 routed token 下标 +- `expert_ids` + 每个 M 方向 block 对应的 expert id +- `num_tokens_post_padded` + 补齐后的 routed token 总数 + +这样后续 Triton kernel 就可以让一个 program 只处理: + +- 一块 routed token +- 一个固定 expert +- 一块输出列 + +这就是 fused MoE 能把很多零散小计算组织成规则 block GEMM 的关键。 + +## fused_moe_kernel 的核心原理 + +`fused_moe_kernel(...)` 是通用 kernel,也是阅读此文件时最值得重点关注的函数。 + +### 1. program 到 tile 的映射 + +kernel 先根据: + +- `BLOCK_SIZE_M` +- `BLOCK_SIZE_N` +- `GROUP_SIZE_M` + +把 `pid` 映射到: + +- `pid_m` +- `pid_n` + +这里使用 grouped ordering 来改善 L2 reuse。 + +### 2. 当前 block 对应哪些 token 和 expert + +对每个 `pid_m`: + +- 如果启用了重排,先从 `sorted_token_ids` 取出当前 block 的 routed token +- 从 `expert_ids[pid_m]` 取出当前 block 对应的 expert + +如果 `expert_id == -1`,则说明该 expert 不在当前 expert parallel rank,直接调用 `write_zeros_to_output(...)`。 + +### 3. 构造 A / B 指针 + +输入 A 的访问方式为: + +```text +offs_token // top_k +``` + +因为 routed token 是 `(token, slot)` 的展平形式,所以要除以 `top_k` 才能回到原始 token 下标。 + +权重 B 则按照: + +- 当前 expert +- 当前 K block +- 当前 N block + +来取对应子块。 + +### 4. 主循环 + +沿 K 维做 block 累加: + +- 读取 A 子块 +- 读取 B 子块 +- 走普通浮点或量化分支 +- 累加到 `accumulator` + +`accumulator` 始终先用 `fp32` 保存,再在输出前转成 `compute_type`。 + +### 5. 后处理 + +主循环结束后依次执行: + +- 可选 dequant +- 可选 bias +- 可选乘 router weight +- cast 到 `compute_type` +- 写回 C + +## 普通浮点与 FP8 路径 + +### 普通浮点 + +当: + +- `use_fp8_w8a8 == False` +- `use_int8_w8a8 == False` +- `A_scale is None` +- `B_scale is None` + +时,kernel 直接执行普通浮点路径: + +```text +accumulator += tl.dot(a, b) +``` + +这是 `bfloat16`、`float16`、`float32` 共享的主要逻辑。 + +### FP8 W8A8 + +当启用 `use_fp8_w8a8` 时: + +- Python 侧先调用 `moe_kernel_quantize_input(...)` +- `moe_kernel_quantize_input(...)` 会进一步调用 `_fp8_quantize(...)` +- kernel 内根据量化粒度决定如何加载 scale + +支持的粒度包括: + +- tensor-wise +- per-channel +- block-wise + +block-wise 路径下,会在每个 K block 内读取本 block 对应的: + +- `a_scale` +- `b_scale` + +然后执行近似: + +```text +C_block ~= dot(qA_block, qB_block) * a_scale * b_scale +``` + +### bfloat16 支持 + +文件明确支持: + +- `torch.bfloat16` + +其执行方式不是把整个 kernel 改成 BF16 累加,而是: + +- 输入 / 输出使用 BF16 +- `accumulator` 先用 FP32 +- 最后再 cast 回 `tl.bfloat16` + +## WNA16 路径 + +文件内还保留了一条 WNA16 特化路径: + +- `invoke_fused_moe_wna16_triton_kernel(...)` +- `fused_moe_kernel_gptq_awq(...)` + +它主要用于: + +- `int8_w8a16` +- `int4_w4a16` + +并支持: + +- group scale +- zero point +- GPTQ/AWQ 风格权重布局 + +但需要注意: + +- 当前 `fused_experts_impl(...)` 中会先把 INT8/INT4 权重反量化成普通浮点 +- 因此从当前公开入口来看,这条分支通常不会成为主路径 + +也就是说,代码结构上支持 WNA16 专用 kernel,但入口行为上更偏向统一走通用 kernel。 + +## 调用图 + +```text +inplace_fused_experts + -> fused_experts_impl + +outplace_fused_experts + -> fused_experts_impl + +fused_experts_impl + -> _get_config_dtype_str + -> _get_config_quant_dtype + -> try_get_optimal_moe_config + -> get_moe_configs + -> _get_device_name + -> get_embedded_moe_configs + -> get_default_config + -> moe_kernel_quantize_input + -> _fp8_quantize / _int8_quantize + -> moe_align_block_size + -> dispatch_fused_moe_kernel + -> invoke_fused_moe_triton_kernel + -> fused_moe_kernel + -> write_zeros_to_output + -> invoke_fused_moe_wna16_triton_kernel + -> get_moe_wna16_block_config + -> _ensure_block_size_k_divisible + -> fused_moe_kernel_gptq_awq + -> write_zeros_to_output + -> apply_moe_activation + -> _silu_and_mul_kernel + -> moe_sum +``` + +## 当前实现的几个注意点 + +- 当前入口只允许 `activation == "silu"` +- 明确支持输入 dtype: + - `float32` + - `float16` + - `bfloat16` +- 通用主路径是 `fused_moe_kernel(...)` +- 文件内虽然保留了 WNA16 特化 kernel,但当前公开入口通常会先把 INT8/INT4 权重反量化,实际常走通用 kernel +- `moe_align_block_size(...)` 是性能关键,它决定了 routed token 是否能被整理成规则 block + +## 小结 + +`fused_moe.py` 的核心思想可以概括为三点: + +1. 把 routed token 按 expert 排序并按 block 对齐 +2. 用统一的 Triton kernel 两次完成 expert 的两层 GEMM +3. 在 kernel 内尽量融合量化、bias、router weight 等附加逻辑 + +从源码阅读角度,推荐优先关注以下函数: + +- `fused_experts_impl(...)` +- `moe_kernel_quantize_input(...)` +- `dispatch_fused_moe_kernel(...)` +- `fused_moe_kernel(...)` +- `moe_align_block_size(...)` + +这几处连起来,基本就构成了整个 fused MoE 前向的主干。 diff --git a/flaggems/fused_moe_image/fused_moe.png b/flaggems/fused_moe_image/fused_moe.png new file mode 100644 index 0000000..ce1396e Binary files /dev/null and b/flaggems/fused_moe_image/fused_moe.png differ diff --git a/flaggems/fused_moe_image/moe.png b/flaggems/fused_moe_image/moe.png new file mode 100644 index 0000000..fb9ade6 Binary files /dev/null and b/flaggems/fused_moe_image/moe.png differ