# fused_moe.py 分析 本文整理 `src/flag_gems/fused/fused_moe.py` 的实现结构、调用路径、数据类型分支和计算原理,方便从源码角度理解 FlagGems 的 fused MoE 前向。 ## 第一章:MoE 与 fused MoE 的数学原理与计算原理 ### 1.0 MoE 在大模型中的背景 MoE,Mixture of Experts,是大模型中一种常见的稀疏前馈结构。它的核心思想是: - 不再让每个 token 都经过同一组固定的前馈层 - 而是准备多个 expert 子网络 - 再由 router 为每个 token 动态选择少量 expert 参与计算 在 Transformer 大模型中,MoE 通常替代原来的 dense FFN 层。也就是说,在注意力层之后,原本是: ```text Attention -> Dense FFN ``` 引入 MoE 后会变成: ```text Attention -> Router + Experts ``` 这样做的目的,是在尽量不线性增加单 token 计算量的前提下,大幅增加模型总参数量和表示能力。 可以把它理解为: - dense FFN 是“所有 token 共用同一套前馈参数” - MoE 是“不同 token 只激活一小部分 expert 参数” 因此,MoE 在大模型中的价值主要体现在: - 提升模型容量 - 控制单 token 计算成本 - 让不同 token 能使用更有针对性的 expert 子网络 从运行时角度看,MoE 算子负责的就是这一层稀疏前馈计算。它的输入通常是 attention 之后的 hidden states,输出则是经过 router 选择 expert、完成两层 expert MLP 并聚合后的新 hidden states。 更具体地说,这个算子主要做三件事: 1. 根据 router 给出的 `topk_ids` 和 `topk_weights`,确定每个 token 应该送到哪些 expert 2. 对被选中的 expert 执行前馈计算 3. 把多个 expert 的输出按路由权重聚合回 token 维度 因此,从模型结构上看,MoE 算子本质上是在实现: ```text token -> route to experts -> expert MLP -> weighted combine ``` 而 `fused_moe.py` 做的不是重新定义这个算子,而是为这一层 MoE 前向提供一个更高性能的实现,重点优化: - token 到 expert 的重排 - expert GEMM 的执行方式 - 量化路径 - 最终聚合过程 ### 1.1 MoE 的数学原理 ![MoE Principle](./fused_moe_image/moe.png) 结合上图,标准 MoE 层的计算过程可以形式化为“路由选择 -> 分发到专家 -> 专家前馈计算 -> 加权聚合输出”四个阶段。 设第 $t$ 个 token 的隐藏状态表示为: ```math x_t \in \mathbb{R}^{d} ``` 其中 $d$ 为隐藏维度。MoE 层包含 $E$ 个 expert,记第 $e$ 个 expert 对应的前馈函数为 $f_e$。 第一步是路由计算。Router 对输入 $x_t$ 生成对所有 expert 的路由打分。最常见的 router 形式是一个线性映射: ```math s_t = W_g^{\top} x_t + b_g ``` 其中 $W_g \in \mathbb{R}^{d \times E}$、$b_g \in \mathbb{R}^{E}$,因此 router 输出的打分向量满足: ```math s_t = \operatorname{Router}(x_t) \in \mathbb{R}^{E} ``` 通常可进一步写成 softmax 后的路由概率: ```math p_t = \operatorname{softmax}(s_t), \qquad p_t \in \mathbb{R}^{E} ``` 第二步是 Top-k 选择。Router 选择概率最高的 $k$ 个 expert,记其索引集合为 $S_t$: ```math \mathcal{S}_t = \operatorname{TopK}(p_t, k) ``` 对每个被选中的 expert $e \in S_t$,记其对应的路由权重为 $\alpha_{t,e}$。这一步对应图中的 `Top-k Selection`。 第三步是 dispatch。根据集合 $S_t$,token $x_t$ 会被逻辑上分发到这些被选中的 expert;未被选中的 expert 不参与当前 token 的计算。 第四步是 expert 计算。若每个 expert 采用图中所示的两层 MLP 结构 `Linear1 -> Activation(SiLU) -> Linear2`,则第 $e$ 个 expert 的输出可以写为: ```math f_e(x) = W_{2,e}^{\top}\,\mathrm{SiLU}(W_{1,e}^{\top}x + b_{1,e}) + b_{2,e} ``` 其中: - $W_{1,e}$、$W_{2,e}$ 分别为第 $e$ 个 expert 的两层线性变换参数 - $b_{1,e}$、$b_{2,e}$ 为对应偏置 - $\mathrm{SiLU}(\cdot)$ 为图中 expert 内部使用的激活函数 最后一步是聚合输出。对所有被选中的 expert 输出按对应路由权重进行加权求和,得到当前 token 的最终输出: ```math y_t = \sum_{e \in \mathcal{S}_t} \alpha_{t,e}\, f_e(x_t) ``` 因此,上图中的完整计算链条可以概括为: ```math x_t \xrightarrow{\text{Router}} s_t \xrightarrow{\text{Softmax}} p_t \xrightarrow{\text{TopK}} \mathcal{S}_t \xrightarrow{\text{Dispatch}} \{f_e(x_t)\}_{e \in \mathcal{S}_t} \xrightarrow{\text{Weighted Sum}} y_t ``` 也就是说,MoE 的本质是:先由 router 从全部 expert 中选择一小部分参与当前 token 的计算,再将这些被选中 expert 的输出按路由权重加权求和,从而得到最终输出 $y_t$。 ### 1.2 fused_moe 的数学原理 把每个 `(token, topk-slot)` 看作一个 routed item $r$,则: - 对应 token 为 $t(r)$ - 对应 expert 为 $e(r)$ - 对应路由权重为 $a(r)$ 那么 `fused_moe` 实际仍然是在算同一个数学对象。 第一层: ```math h_r = W_{1,e(r)}^{\top}x_{t(r)} + b_{1,e(r)} ``` 激活后: ```math z_r = \phi(h_r) ``` 第二层: ```math o_r = W_{2,e(r)}^{\top}z_r + b_{2,e(r)} ``` 最后对同一个 token 的 routed 输出求和: ```math y_t = \sum_{r: t(r)=t} a(r)\,o_r ``` 代码中有一个开关 `apply_router_weight_on_input`,决定把 router weight 乘在第一层输出之后还是第二层输出之后。由于 `a(r)` 是标量,线性层满足: ```math W^{\top}(a z) = a\,W^{\top}z ``` 因此两种写法数学上等价。 ### 1.3 fused_moe 的计算原理 `fused_moe` 和普通 MoE 在数学上等价,主要差别在工程组织方式。 普通 MoE 更像是按逻辑步骤分开执行: 1. router 选 top-k experts 2. dispatch / gather token 3. 对每个 expert 分别做 `w1` 4. 激活 5. 对每个 expert 分别做 `w2` 6. 乘 router weight 并聚合 而 `fused_moe` 的核心思想是: 1. 先把 routed token 按 expert 重排 2. 把每个 expert 的 token 数量按 `BLOCK_SIZE_M` 补齐 3. 让一个 Triton program 处理“同一 expert 的一块 token × 一块输出列” 4. 用统一 kernel 两次完成 `w1` 和 `w2` 5. 在 kernel 内尽量融合量化、bias、router weight 等附加逻辑 从 routed token 视角看,kernel 做的是: ```math C[r, n] = \sum_{k=1}^{K} A[t(r), k] \cdot B[e(r), n, k] ``` 其中: - $t(r)$ 是 routed item $r$ 对应的原始 token - $e(r)$ 是 routed item $r$ 对应的 expert - $n$ 是输出通道 在实现上,这套组织方式依赖三个关键步骤: - `moe_align_block_size(...)` 把 routed token 按 expert 排列成规则 block - `fused_moe_kernel(...)` 按 block 执行 expert GEMM - `moe_sum(...)` 对每个 token 的 top-k 输出做最终求和 ### 1.4 和普通 MoE 实现的区别 从数学上看,`fused_moe` 和普通 MoE 没有本质差别,仍然是在做: ```math y_t = \sum_{i=1}^{k} \alpha_{t,i}\, f_{e_{t,i}}(x_t) ``` 差别主要在工程组织方式: - 普通 MoE 更像是按逻辑步骤分开执行 - dispatch - expert matmul - activation - expert matmul - combine - fused MoE 则更强调: - 按 expert 重排 token - 让一个 Triton program 处理规则 tile - 把量化、bias、router weight 等附加逻辑尽量揉进 kernel 这样做的收益是: - 减少 kernel launch - 改善权重访问局部性 - 提高大规模 expert 场景下的吞吐 ## 文件定位 `fused_moe.py` 是一个面向 Triton 的 MoE 前向实现文件,核心职责包括: - 选择或生成 MoE kernel 配置 - 对输入激活做量化预处理 - 启动 Triton MoE kernel - 执行两层 expert MLP - 在 top-k expert 输出之间做聚合 它依赖两个关键辅助模块: - `flag_gems.fused.moe_align_block_size` 负责把 routed token 按 expert 重排并按 block 对齐 - `flag_gems.fused.moe_sum` 负责把每个 token 的 top-k expert 输出求和 ## 对外入口 对外可直接调用的入口有两个: - `inplace_fused_experts(...)` - `outplace_fused_experts(...)` 两者都只是薄封装,最终都调用: - `fused_experts_impl(...)` 其中: - `inplace_fused_experts` 把结果直接写回 `hidden_states` - `outplace_fused_experts` 返回新分配的输出张量 ## 文件内主要函数分组 ### 1. 配置选择 - `get_embedded_moe_configs()` 从 `fused_moe_config.yaml` 读取内嵌调优配置 - `_get_device_name()` 获取并规范化设备名 - `get_moe_configs(...)` 按设备名、expert 数、输出维度、dtype、block shape 查询配置 - `try_get_optimal_moe_config(...)` 优先使用内嵌配置,否则回退到启发式配置 - `get_default_config(...)` 默认配置生成逻辑 - `get_moe_wna16_block_config(...)` 针对 WNA16 路径生成 block 配置 - `_ensure_block_size_k_divisible(...)` 修正 `BLOCK_SIZE_K` ### 2. 数据类型与量化辅助 - `_get_config_dtype_str(...)` 把当前 dtype 和量化模式映射成配置查表 key - `_get_config_quant_dtype(...)` 把量化模式映射成激活量化 dtype - `_fp8_quantize(...)` 激活 FP8 量化 - `_int8_quantize(...)` 激活 INT8 量化 - `moe_kernel_quantize_input(...)` 在进入 GEMM 前对输入激活做量化 - `dequant_mxfp4(...)` - `dequant_mxfp6(...)` ### 3. 激活函数 - `MoEActivation` 定义支持的激活枚举 - `apply_moe_activation(...)` 执行激活 - `_silu_and_mul_kernel(...)` Triton 版 SiLU-and-mul ### 4. Triton kernel - `write_zeros_to_output(...)` expert 不在当前 rank 时写零 - `fused_moe_kernel(...)` 通用 MoE kernel,覆盖普通浮点、FP8、INT8 路径 - `fused_moe_kernel_gptq_awq(...)` WNA16 专用 kernel,面向 GPTQ/AWQ 风格的量化权重 ### 5. kernel 启动与分发 - `invoke_fused_moe_triton_kernel(...)` 启动通用 kernel - `invoke_fused_moe_wna16_triton_kernel(...)` 启动 WNA16 kernel - `dispatch_fused_moe_kernel(...)` 根据量化模式分发到具体启动函数 ## 主调用路径 主调用链可以概括为: ```text inplace_fused_experts / outplace_fused_experts -> fused_experts_impl -> _get_config_dtype_str -> _get_config_quant_dtype -> try_get_optimal_moe_config -> chunk loop -> moe_kernel_quantize_input (for w1 input) -> moe_align_block_size 或 naive assignment -> dispatch_fused_moe_kernel (for w1) -> invoke_fused_moe_triton_kernel -> fused_moe_kernel -> apply_moe_activation -> moe_kernel_quantize_input (for w2 input) -> dispatch_fused_moe_kernel (for w2) -> invoke_fused_moe_triton_kernel -> fused_moe_kernel -> moe_sum ``` 如果走 WNA16 特化路径,则中间会改为: ```text dispatch_fused_moe_kernel -> invoke_fused_moe_wna16_triton_kernel -> fused_moe_kernel_gptq_awq ``` ## fused_experts_impl 的执行流程 `fused_experts_impl(...)` 是整个文件的调度核心。 ### 1. 参数校验 它会检查: - `hidden_states`、`w1`、`w2` 的 shape 是否匹配 - `topk_weights` 和 `topk_ids` 形状是否一致 - `hidden_states` 是否连续 - `w1/w2` 最后一维 stride 是否为 1 - 输入 dtype 是否属于: - `torch.float32` - `torch.float16` - `torch.bfloat16` 当前实现还显式限制: - `activation == "silu"` 虽然 `MoEActivation` 定义了多种激活,但当前入口只允许 SiLU 路径。 ### 2. 配置选择 它先生成两个关键变量: - `config_dtype` 用于 kernel 配置查表 - `quant_dtype` 用于激活量化逻辑 之后用 `try_get_optimal_moe_config(...)` 选择当前 chunk 对应的 Triton 配置。 ### 3. 中间缓存分配 会分配三块中间缓存: - `intermediate_cache1` 保存 `w1` 输出,shape 为 `[M, top_k, N]` - `intermediate_cache2` 保存激活后的结果,shape 为 `[M * top_k, activation_out_dim]` - `intermediate_cache3` 保存 `w2` 输出,shape 为 `[M, top_k, K]` 其中 `cache13` 复用了 `cache1` 和 `cache3` 的存储,因为两者生命周期不重叠。 ### 4. 计算 dtype 选择 输入 dtype 与 kernel 中的 `compute_type` 对应关系: - `torch.bfloat16 -> tl.bfloat16` - `torch.float16 -> tl.float16` - `torch.float32 -> tl.float32` ### 5. 特殊量化格式处理 如果启用了 `ocp_mx_scheme`,会先把 MX 权重反量化成普通浮点权重,再走后续通用路径。 对于 `use_int8_w8a16` 和 `use_int4_w4a16`,当前公开入口中也会先把权重反量化成 `hidden_states.dtype`,随后把: - `use_int8_w8a16 = False` - `use_int4_w4a16 = False` 这意味着虽然文件内保留了 WNA16 特化 kernel,但从 `fused_experts_impl(...)` 这条入口实际运行时,常见路径仍然是通用 `fused_moe_kernel(...)`。 ### 6. 分 chunk 执行 实现按 `CHUNK_SIZE = 16 * 1024` 分块处理 token,原因是: - 限制中间缓存大小 - 让 kernel 配置更容易适配当前 chunk 规模 每个 chunk 内部执行: 1. 量化第一层输入 2. 把 routed token 按 expert 重排 3. 跑 `w1` 4. 激活 5. 量化第二层输入 6. 跑 `w2` 7. 对 top-k expert 输出做求和 ## moe_align_block_size 的作用 MoE 的普通数学形式里,每个 token 会被路由到若干 expert。直接按 token 顺序计算时,expert 权重访问会非常离散。 `moe_align_block_size(...)` 的作用是: - 把所有 `(token, topk-slot)` 展平为 routed token - 按 expert 把 routed token 排序 - 对每个 expert 的 token 数量按 `BLOCK_SIZE_M` 向上补齐 它输出三个核心量: - `sorted_token_ids` 重排后的 routed token 下标 - `expert_ids` 每个 M 方向 block 对应的 expert id - `num_tokens_post_padded` 补齐后的 routed token 总数 这样后续 Triton kernel 就可以让一个 program 只处理: - 一块 routed token - 一个固定 expert - 一块输出列 这就是 fused MoE 能把很多零散小计算组织成规则 block GEMM 的关键。 ## fused_moe_kernel 的核心原理 `fused_moe_kernel(...)` 是通用 kernel,也是阅读此文件时最值得重点关注的函数。 ### 1. program 到 tile 的映射 kernel 先根据: - `BLOCK_SIZE_M` - `BLOCK_SIZE_N` - `GROUP_SIZE_M` 把 `pid` 映射到: - `pid_m` - `pid_n` 这里使用 grouped ordering 来改善 L2 reuse。 ### 2. 当前 block 对应哪些 token 和 expert 对每个 `pid_m`: - 如果启用了重排,先从 `sorted_token_ids` 取出当前 block 的 routed token - 从 `expert_ids[pid_m]` 取出当前 block 对应的 expert 如果 `expert_id == -1`,则说明该 expert 不在当前 expert parallel rank,直接调用 `write_zeros_to_output(...)`。 ### 3. 构造 A / B 指针 输入 A 的访问方式为: ```text offs_token // top_k ``` 因为 routed token 是 `(token, slot)` 的展平形式,所以要除以 `top_k` 才能回到原始 token 下标。 权重 B 则按照: - 当前 expert - 当前 K block - 当前 N block 来取对应子块。 ### 4. 主循环 沿 K 维做 block 累加: - 读取 A 子块 - 读取 B 子块 - 走普通浮点或量化分支 - 累加到 `accumulator` `accumulator` 始终先用 `fp32` 保存,再在输出前转成 `compute_type`。 ### 5. 后处理 主循环结束后依次执行: - 可选 dequant - 可选 bias - 可选乘 router weight - cast 到 `compute_type` - 写回 C ## 普通浮点与 FP8 路径 ### 普通浮点 当: - `use_fp8_w8a8 == False` - `use_int8_w8a8 == False` - `A_scale is None` - `B_scale is None` 时,kernel 直接执行普通浮点路径: ```text accumulator += tl.dot(a, b) ``` 这是 `bfloat16`、`float16`、`float32` 共享的主要逻辑。 ### FP8 W8A8 当启用 `use_fp8_w8a8` 时: - Python 侧先调用 `moe_kernel_quantize_input(...)` - `moe_kernel_quantize_input(...)` 会进一步调用 `_fp8_quantize(...)` - kernel 内根据量化粒度决定如何加载 scale 支持的粒度包括: - tensor-wise - per-channel - block-wise block-wise 路径下,会在每个 K block 内读取本 block 对应的: - `a_scale` - `b_scale` 然后执行近似: ```text C_block ~= dot(qA_block, qB_block) * a_scale * b_scale ``` ### bfloat16 支持 文件明确支持: - `torch.bfloat16` 其执行方式不是把整个 kernel 改成 BF16 累加,而是: - 输入 / 输出使用 BF16 - `accumulator` 先用 FP32 - 最后再 cast 回 `tl.bfloat16` ## WNA16 路径 文件内还保留了一条 WNA16 特化路径: - `invoke_fused_moe_wna16_triton_kernel(...)` - `fused_moe_kernel_gptq_awq(...)` 它主要用于: - `int8_w8a16` - `int4_w4a16` 并支持: - group scale - zero point - GPTQ/AWQ 风格权重布局 但需要注意: - 当前 `fused_experts_impl(...)` 中会先把 INT8/INT4 权重反量化成普通浮点 - 因此从当前公开入口来看,这条分支通常不会成为主路径 也就是说,代码结构上支持 WNA16 专用 kernel,但入口行为上更偏向统一走通用 kernel。 ## 调用图 ```text inplace_fused_experts -> fused_experts_impl outplace_fused_experts -> fused_experts_impl fused_experts_impl -> _get_config_dtype_str -> _get_config_quant_dtype -> try_get_optimal_moe_config -> get_moe_configs -> _get_device_name -> get_embedded_moe_configs -> get_default_config -> moe_kernel_quantize_input -> _fp8_quantize / _int8_quantize -> moe_align_block_size -> dispatch_fused_moe_kernel -> invoke_fused_moe_triton_kernel -> fused_moe_kernel -> write_zeros_to_output -> invoke_fused_moe_wna16_triton_kernel -> get_moe_wna16_block_config -> _ensure_block_size_k_divisible -> fused_moe_kernel_gptq_awq -> write_zeros_to_output -> apply_moe_activation -> _silu_and_mul_kernel -> moe_sum ``` ## 当前实现的几个注意点 - 当前入口只允许 `activation == "silu"` - 明确支持输入 dtype: - `float32` - `float16` - `bfloat16` - 通用主路径是 `fused_moe_kernel(...)` - 文件内虽然保留了 WNA16 特化 kernel,但当前公开入口通常会先把 INT8/INT4 权重反量化,实际常走通用 kernel - `moe_align_block_size(...)` 是性能关键,它决定了 routed token 是否能被整理成规则 block ## 小结 `fused_moe.py` 的核心思想可以概括为三点: 1. 把 routed token 按 expert 排序并按 block 对齐 2. 用统一的 Triton kernel 两次完成 expert 的两层 GEMM 3. 在 kernel 内尽量融合量化、bias、router weight 等附加逻辑 从源码阅读角度,推荐优先关注以下函数: - `fused_experts_impl(...)` - `moe_kernel_quantize_input(...)` - `dispatch_fused_moe_kernel(...)` - `fused_moe_kernel(...)` - `moe_align_block_size(...)` 这几处连起来,基本就构成了整个 fused MoE 前向的主干。