add fused_moe analysis

This commit is contained in:
2026-04-23 16:36:50 +08:00
parent bbcc7e3935
commit 9addbf2a78
3 changed files with 695 additions and 0 deletions

695
flaggems/fused_moe.md Normal file
View File

@@ -0,0 +1,695 @@
# fused_moe.py 分析
本文整理 `src/flag_gems/fused/fused_moe.py` 的实现结构、调用路径、数据类型分支和计算原理,方便从源码角度理解 FlagGems 的 fused MoE 前向。
## 第一章MoE 与 fused MoE 的数学原理与计算原理
### 1.0 MoE 在大模型中的背景
MoEMixture of Experts是大模型中一种常见的稀疏前馈结构。它的核心思想是
- 不再让每个 token 都经过同一组固定的前馈层
- 而是准备多个 expert 子网络
- 再由 router 为每个 token 动态选择少量 expert 参与计算
在 Transformer 大模型中MoE 通常替代原来的 dense FFN 层。也就是说,在注意力层之后,原本是:
```text
Attention -> Dense FFN
```
引入 MoE 后会变成:
```text
Attention -> Router + Experts
```
这样做的目的,是在尽量不线性增加单 token 计算量的前提下,大幅增加模型总参数量和表示能力。
可以把它理解为:
- dense FFN 是“所有 token 共用同一套前馈参数”
- MoE 是“不同 token 只激活一小部分 expert 参数”
因此MoE 在大模型中的价值主要体现在:
- 提升模型容量
- 控制单 token 计算成本
- 让不同 token 能使用更有针对性的 expert 子网络
从运行时角度看MoE 算子负责的就是这一层稀疏前馈计算。它的输入通常是 attention 之后的 hidden states输出则是经过 router 选择 expert、完成两层 expert MLP 并聚合后的新 hidden states。
更具体地说,这个算子主要做三件事:
1. 根据 router 给出的 `topk_ids``topk_weights`,确定每个 token 应该送到哪些 expert
2. 对被选中的 expert 执行前馈计算
3. 把多个 expert 的输出按路由权重聚合回 token 维度
因此从模型结构上看MoE 算子本质上是在实现:
```text
token -> route to experts -> expert MLP -> weighted combine
```
`fused_moe.py` 做的不是重新定义这个算子,而是为这一层 MoE 前向提供一个更高性能的实现,重点优化:
- token 到 expert 的重排
- expert GEMM 的执行方式
- 量化路径
- 最终聚合过程
### 1.1 MoE 的数学原理
![MoE Principle](./fused_moe_image/moe.png)
结合上图,标准 MoE 层的计算过程可以形式化为“路由选择 -> 分发到专家 -> 专家前馈计算 -> 加权聚合输出”四个阶段。
设第 $t$ 个 token 的隐藏状态表示为:
```math
x_t \in \mathbb{R}^{d}
```
其中 $d$ 为隐藏维度。MoE 层包含 $E$ 个 expert记第 $e$ 个 expert 对应的前馈函数为 $f_e$。
第一步是路由计算。Router 对输入 $x_t$ 生成对所有 expert 的路由打分。最常见的 router 形式是一个线性映射:
```math
s_t = W_g^{\top} x_t + b_g
```
其中 $W_g \in \mathbb{R}^{d \times E}$、$b_g \in \mathbb{R}^{E}$,因此 router 输出的打分向量满足:
```math
s_t = \operatorname{Router}(x_t) \in \mathbb{R}^{E}
```
通常可进一步写成 softmax 后的路由概率:
```math
p_t = \operatorname{softmax}(s_t), \qquad p_t \in \mathbb{R}^{E}
```
第二步是 Top-k 选择。Router 选择概率最高的 $k$ 个 expert记其索引集合为 $S_t$
```math
\mathcal{S}_t = \operatorname{TopK}(p_t, k)
```
对每个被选中的 expert $e \in S_t$,记其对应的路由权重为 $\alpha_{t,e}$。这一步对应图中的 `Top-k Selection`
第三步是 dispatch。根据集合 $S_t$token $x_t$ 会被逻辑上分发到这些被选中的 expert未被选中的 expert 不参与当前 token 的计算。
第四步是 expert 计算。若每个 expert 采用图中所示的两层 MLP 结构 `Linear1 -> Activation(SiLU) -> Linear2`,则第 $e$ 个 expert 的输出可以写为:
```math
f_e(x) = W_{2,e}^{\top}\,\mathrm{SiLU}(W_{1,e}^{\top}x + b_{1,e}) + b_{2,e}
```
其中:
- $W_{1,e}$、$W_{2,e}$ 分别为第 $e$ 个 expert 的两层线性变换参数
- $b_{1,e}$、$b_{2,e}$ 为对应偏置
- $\mathrm{SiLU}(\cdot)$ 为图中 expert 内部使用的激活函数
最后一步是聚合输出。对所有被选中的 expert 输出按对应路由权重进行加权求和,得到当前 token 的最终输出:
```math
y_t = \sum_{e \in \mathcal{S}_t} \alpha_{t,e}\, f_e(x_t)
```
因此,上图中的完整计算链条可以概括为:
```math
x_t
\xrightarrow{\text{Router}}
s_t
\xrightarrow{\text{Softmax}}
p_t
\xrightarrow{\text{TopK}}
\mathcal{S}_t
\xrightarrow{\text{Dispatch}}
\{f_e(x_t)\}_{e \in \mathcal{S}_t}
\xrightarrow{\text{Weighted Sum}}
y_t
```
也就是说MoE 的本质是:先由 router 从全部 expert 中选择一小部分参与当前 token 的计算,再将这些被选中 expert 的输出按路由权重加权求和,从而得到最终输出 $y_t$。
### 1.2 fused_moe 的数学原理
把每个 `(token, topk-slot)` 看作一个 routed item $r$,则:
- 对应 token 为 $t(r)$
- 对应 expert 为 $e(r)$
- 对应路由权重为 $a(r)$
那么 `fused_moe` 实际仍然是在算同一个数学对象。
第一层:
```math
h_r = W_{1,e(r)}^{\top}x_{t(r)} + b_{1,e(r)}
```
激活后:
```math
z_r = \phi(h_r)
```
第二层:
```math
o_r = W_{2,e(r)}^{\top}z_r + b_{2,e(r)}
```
最后对同一个 token 的 routed 输出求和:
```math
y_t = \sum_{r: t(r)=t} a(r)\,o_r
```
代码中有一个开关 `apply_router_weight_on_input`,决定把 router weight 乘在第一层输出之后还是第二层输出之后。由于 `a(r)` 是标量,线性层满足:
```math
W^{\top}(a z) = a\,W^{\top}z
```
因此两种写法数学上等价。
### 1.3 fused_moe 的计算原理
`fused_moe` 和普通 MoE 在数学上等价,主要差别在工程组织方式。
普通 MoE 更像是按逻辑步骤分开执行:
1. router 选 top-k experts
2. dispatch / gather token
3. 对每个 expert 分别做 `w1`
4. 激活
5. 对每个 expert 分别做 `w2`
6. 乘 router weight 并聚合
`fused_moe` 的核心思想是:
1. 先把 routed token 按 expert 重排
2. 把每个 expert 的 token 数量按 `BLOCK_SIZE_M` 补齐
3. 让一个 Triton program 处理“同一 expert 的一块 token × 一块输出列”
4. 用统一 kernel 两次完成 `w1``w2`
5. 在 kernel 内尽量融合量化、bias、router weight 等附加逻辑
从 routed token 视角看kernel 做的是:
```math
C[r, n] = \sum_{k=1}^{K} A[t(r), k] \cdot B[e(r), n, k]
```
其中:
- $t(r)$ 是 routed item $r$ 对应的原始 token
- $e(r)$ 是 routed item $r$ 对应的 expert
- $n$ 是输出通道
在实现上,这套组织方式依赖三个关键步骤:
- `moe_align_block_size(...)`
把 routed token 按 expert 排列成规则 block
- `fused_moe_kernel(...)`
按 block 执行 expert GEMM
- `moe_sum(...)`
对每个 token 的 top-k 输出做最终求和
### 1.4 和普通 MoE 实现的区别
从数学上看,`fused_moe` 和普通 MoE 没有本质差别,仍然是在做:
```math
y_t = \sum_{i=1}^{k} \alpha_{t,i}\, f_{e_{t,i}}(x_t)
```
差别主要在工程组织方式:
- 普通 MoE 更像是按逻辑步骤分开执行
- dispatch
- expert matmul
- activation
- expert matmul
- combine
- fused MoE 则更强调:
- 按 expert 重排 token
- 让一个 Triton program 处理规则 tile
- 把量化、bias、router weight 等附加逻辑尽量揉进 kernel
这样做的收益是:
- 减少 kernel launch
- 改善权重访问局部性
- 提高大规模 expert 场景下的吞吐
## 文件定位
`fused_moe.py` 是一个面向 Triton 的 MoE 前向实现文件,核心职责包括:
- 选择或生成 MoE kernel 配置
- 对输入激活做量化预处理
- 启动 Triton MoE kernel
- 执行两层 expert MLP
- 在 top-k expert 输出之间做聚合
它依赖两个关键辅助模块:
- `flag_gems.fused.moe_align_block_size`
负责把 routed token 按 expert 重排并按 block 对齐
- `flag_gems.fused.moe_sum`
负责把每个 token 的 top-k expert 输出求和
## 对外入口
对外可直接调用的入口有两个:
- `inplace_fused_experts(...)`
- `outplace_fused_experts(...)`
两者都只是薄封装,最终都调用:
- `fused_experts_impl(...)`
其中:
- `inplace_fused_experts` 把结果直接写回 `hidden_states`
- `outplace_fused_experts` 返回新分配的输出张量
## 文件内主要函数分组
### 1. 配置选择
- `get_embedded_moe_configs()`
`fused_moe_config.yaml` 读取内嵌调优配置
- `_get_device_name()`
获取并规范化设备名
- `get_moe_configs(...)`
按设备名、expert 数、输出维度、dtype、block shape 查询配置
- `try_get_optimal_moe_config(...)`
优先使用内嵌配置,否则回退到启发式配置
- `get_default_config(...)`
默认配置生成逻辑
- `get_moe_wna16_block_config(...)`
针对 WNA16 路径生成 block 配置
- `_ensure_block_size_k_divisible(...)`
修正 `BLOCK_SIZE_K`
### 2. 数据类型与量化辅助
- `_get_config_dtype_str(...)`
把当前 dtype 和量化模式映射成配置查表 key
- `_get_config_quant_dtype(...)`
把量化模式映射成激活量化 dtype
- `_fp8_quantize(...)`
激活 FP8 量化
- `_int8_quantize(...)`
激活 INT8 量化
- `moe_kernel_quantize_input(...)`
在进入 GEMM 前对输入激活做量化
- `dequant_mxfp4(...)`
- `dequant_mxfp6(...)`
### 3. 激活函数
- `MoEActivation`
定义支持的激活枚举
- `apply_moe_activation(...)`
执行激活
- `_silu_and_mul_kernel(...)`
Triton 版 SiLU-and-mul
### 4. Triton kernel
- `write_zeros_to_output(...)`
expert 不在当前 rank 时写零
- `fused_moe_kernel(...)`
通用 MoE kernel覆盖普通浮点、FP8、INT8 路径
- `fused_moe_kernel_gptq_awq(...)`
WNA16 专用 kernel面向 GPTQ/AWQ 风格的量化权重
### 5. kernel 启动与分发
- `invoke_fused_moe_triton_kernel(...)`
启动通用 kernel
- `invoke_fused_moe_wna16_triton_kernel(...)`
启动 WNA16 kernel
- `dispatch_fused_moe_kernel(...)`
根据量化模式分发到具体启动函数
## 主调用路径
主调用链可以概括为:
```text
inplace_fused_experts / outplace_fused_experts
-> fused_experts_impl
-> _get_config_dtype_str
-> _get_config_quant_dtype
-> try_get_optimal_moe_config
-> chunk loop
-> moe_kernel_quantize_input (for w1 input)
-> moe_align_block_size 或 naive assignment
-> dispatch_fused_moe_kernel (for w1)
-> invoke_fused_moe_triton_kernel
-> fused_moe_kernel
-> apply_moe_activation
-> moe_kernel_quantize_input (for w2 input)
-> dispatch_fused_moe_kernel (for w2)
-> invoke_fused_moe_triton_kernel
-> fused_moe_kernel
-> moe_sum
```
如果走 WNA16 特化路径,则中间会改为:
```text
dispatch_fused_moe_kernel
-> invoke_fused_moe_wna16_triton_kernel
-> fused_moe_kernel_gptq_awq
```
## fused_experts_impl 的执行流程
`fused_experts_impl(...)` 是整个文件的调度核心。
### 1. 参数校验
它会检查:
- `hidden_states``w1``w2` 的 shape 是否匹配
- `topk_weights``topk_ids` 形状是否一致
- `hidden_states` 是否连续
- `w1/w2` 最后一维 stride 是否为 1
- 输入 dtype 是否属于:
- `torch.float32`
- `torch.float16`
- `torch.bfloat16`
当前实现还显式限制:
- `activation == "silu"`
虽然 `MoEActivation` 定义了多种激活,但当前入口只允许 SiLU 路径。
### 2. 配置选择
它先生成两个关键变量:
- `config_dtype`
用于 kernel 配置查表
- `quant_dtype`
用于激活量化逻辑
之后用 `try_get_optimal_moe_config(...)` 选择当前 chunk 对应的 Triton 配置。
### 3. 中间缓存分配
会分配三块中间缓存:
- `intermediate_cache1`
保存 `w1` 输出shape 为 `[M, top_k, N]`
- `intermediate_cache2`
保存激活后的结果shape 为 `[M * top_k, activation_out_dim]`
- `intermediate_cache3`
保存 `w2` 输出shape 为 `[M, top_k, K]`
其中 `cache13` 复用了 `cache1``cache3` 的存储,因为两者生命周期不重叠。
### 4. 计算 dtype 选择
输入 dtype 与 kernel 中的 `compute_type` 对应关系:
- `torch.bfloat16 -> tl.bfloat16`
- `torch.float16 -> tl.float16`
- `torch.float32 -> tl.float32`
### 5. 特殊量化格式处理
如果启用了 `ocp_mx_scheme`,会先把 MX 权重反量化成普通浮点权重,再走后续通用路径。
对于 `use_int8_w8a16``use_int4_w4a16`,当前公开入口中也会先把权重反量化成 `hidden_states.dtype`,随后把:
- `use_int8_w8a16 = False`
- `use_int4_w4a16 = False`
这意味着虽然文件内保留了 WNA16 特化 kernel但从 `fused_experts_impl(...)` 这条入口实际运行时,常见路径仍然是通用 `fused_moe_kernel(...)`
### 6. 分 chunk 执行
实现按 `CHUNK_SIZE = 16 * 1024` 分块处理 token原因是
- 限制中间缓存大小
- 让 kernel 配置更容易适配当前 chunk 规模
每个 chunk 内部执行:
1. 量化第一层输入
2. 把 routed token 按 expert 重排
3.`w1`
4. 激活
5. 量化第二层输入
6.`w2`
7. 对 top-k expert 输出做求和
## moe_align_block_size 的作用
MoE 的普通数学形式里,每个 token 会被路由到若干 expert。直接按 token 顺序计算时expert 权重访问会非常离散。
`moe_align_block_size(...)` 的作用是:
- 把所有 `(token, topk-slot)` 展平为 routed token
- 按 expert 把 routed token 排序
- 对每个 expert 的 token 数量按 `BLOCK_SIZE_M` 向上补齐
它输出三个核心量:
- `sorted_token_ids`
重排后的 routed token 下标
- `expert_ids`
每个 M 方向 block 对应的 expert id
- `num_tokens_post_padded`
补齐后的 routed token 总数
这样后续 Triton kernel 就可以让一个 program 只处理:
- 一块 routed token
- 一个固定 expert
- 一块输出列
这就是 fused MoE 能把很多零散小计算组织成规则 block GEMM 的关键。
## fused_moe_kernel 的核心原理
`fused_moe_kernel(...)` 是通用 kernel也是阅读此文件时最值得重点关注的函数。
### 1. program 到 tile 的映射
kernel 先根据:
- `BLOCK_SIZE_M`
- `BLOCK_SIZE_N`
- `GROUP_SIZE_M`
`pid` 映射到:
- `pid_m`
- `pid_n`
这里使用 grouped ordering 来改善 L2 reuse。
### 2. 当前 block 对应哪些 token 和 expert
对每个 `pid_m`
- 如果启用了重排,先从 `sorted_token_ids` 取出当前 block 的 routed token
-`expert_ids[pid_m]` 取出当前 block 对应的 expert
如果 `expert_id == -1`,则说明该 expert 不在当前 expert parallel rank直接调用 `write_zeros_to_output(...)`
### 3. 构造 A / B 指针
输入 A 的访问方式为:
```text
offs_token // top_k
```
因为 routed token 是 `(token, slot)` 的展平形式,所以要除以 `top_k` 才能回到原始 token 下标。
权重 B 则按照:
- 当前 expert
- 当前 K block
- 当前 N block
来取对应子块。
### 4. 主循环
沿 K 维做 block 累加:
- 读取 A 子块
- 读取 B 子块
- 走普通浮点或量化分支
- 累加到 `accumulator`
`accumulator` 始终先用 `fp32` 保存,再在输出前转成 `compute_type`
### 5. 后处理
主循环结束后依次执行:
- 可选 dequant
- 可选 bias
- 可选乘 router weight
- cast 到 `compute_type`
- 写回 C
## 普通浮点与 FP8 路径
### 普通浮点
当:
- `use_fp8_w8a8 == False`
- `use_int8_w8a8 == False`
- `A_scale is None`
- `B_scale is None`
kernel 直接执行普通浮点路径:
```text
accumulator += tl.dot(a, b)
```
这是 `bfloat16``float16``float32` 共享的主要逻辑。
### FP8 W8A8
当启用 `use_fp8_w8a8` 时:
- Python 侧先调用 `moe_kernel_quantize_input(...)`
- `moe_kernel_quantize_input(...)` 会进一步调用 `_fp8_quantize(...)`
- kernel 内根据量化粒度决定如何加载 scale
支持的粒度包括:
- tensor-wise
- per-channel
- block-wise
block-wise 路径下,会在每个 K block 内读取本 block 对应的:
- `a_scale`
- `b_scale`
然后执行近似:
```text
C_block ~= dot(qA_block, qB_block) * a_scale * b_scale
```
### bfloat16 支持
文件明确支持:
- `torch.bfloat16`
其执行方式不是把整个 kernel 改成 BF16 累加,而是:
- 输入 / 输出使用 BF16
- `accumulator` 先用 FP32
- 最后再 cast 回 `tl.bfloat16`
## WNA16 路径
文件内还保留了一条 WNA16 特化路径:
- `invoke_fused_moe_wna16_triton_kernel(...)`
- `fused_moe_kernel_gptq_awq(...)`
它主要用于:
- `int8_w8a16`
- `int4_w4a16`
并支持:
- group scale
- zero point
- GPTQ/AWQ 风格权重布局
但需要注意:
- 当前 `fused_experts_impl(...)` 中会先把 INT8/INT4 权重反量化成普通浮点
- 因此从当前公开入口来看,这条分支通常不会成为主路径
也就是说,代码结构上支持 WNA16 专用 kernel但入口行为上更偏向统一走通用 kernel。
## 调用图
```text
inplace_fused_experts
-> fused_experts_impl
outplace_fused_experts
-> fused_experts_impl
fused_experts_impl
-> _get_config_dtype_str
-> _get_config_quant_dtype
-> try_get_optimal_moe_config
-> get_moe_configs
-> _get_device_name
-> get_embedded_moe_configs
-> get_default_config
-> moe_kernel_quantize_input
-> _fp8_quantize / _int8_quantize
-> moe_align_block_size
-> dispatch_fused_moe_kernel
-> invoke_fused_moe_triton_kernel
-> fused_moe_kernel
-> write_zeros_to_output
-> invoke_fused_moe_wna16_triton_kernel
-> get_moe_wna16_block_config
-> _ensure_block_size_k_divisible
-> fused_moe_kernel_gptq_awq
-> write_zeros_to_output
-> apply_moe_activation
-> _silu_and_mul_kernel
-> moe_sum
```
## 当前实现的几个注意点
- 当前入口只允许 `activation == "silu"`
- 明确支持输入 dtype
- `float32`
- `float16`
- `bfloat16`
- 通用主路径是 `fused_moe_kernel(...)`
- 文件内虽然保留了 WNA16 特化 kernel但当前公开入口通常会先把 INT8/INT4 权重反量化,实际常走通用 kernel
- `moe_align_block_size(...)` 是性能关键,它决定了 routed token 是否能被整理成规则 block
## 小结
`fused_moe.py` 的核心思想可以概括为三点:
1. 把 routed token 按 expert 排序并按 block 对齐
2. 用统一的 Triton kernel 两次完成 expert 的两层 GEMM
3. 在 kernel 内尽量融合量化、bias、router weight 等附加逻辑
从源码阅读角度,推荐优先关注以下函数:
- `fused_experts_impl(...)`
- `moe_kernel_quantize_input(...)`
- `dispatch_fused_moe_kernel(...)`
- `fused_moe_kernel(...)`
- `moe_align_block_size(...)`
这几处连起来,基本就构成了整个 fused MoE 前向的主干。

Binary file not shown.

After

Width:  |  Height:  |  Size: 870 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 871 KiB