Files
FlagDoc/flaggems/fused_moe.md
2026-04-23 16:36:50 +08:00

696 lines
18 KiB
Markdown
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# fused_moe.py 分析
本文整理 `src/flag_gems/fused/fused_moe.py` 的实现结构、调用路径、数据类型分支和计算原理,方便从源码角度理解 FlagGems 的 fused MoE 前向。
## 第一章MoE 与 fused MoE 的数学原理与计算原理
### 1.0 MoE 在大模型中的背景
MoEMixture of Experts是大模型中一种常见的稀疏前馈结构。它的核心思想是
- 不再让每个 token 都经过同一组固定的前馈层
- 而是准备多个 expert 子网络
- 再由 router 为每个 token 动态选择少量 expert 参与计算
在 Transformer 大模型中MoE 通常替代原来的 dense FFN 层。也就是说,在注意力层之后,原本是:
```text
Attention -> Dense FFN
```
引入 MoE 后会变成:
```text
Attention -> Router + Experts
```
这样做的目的,是在尽量不线性增加单 token 计算量的前提下,大幅增加模型总参数量和表示能力。
可以把它理解为:
- dense FFN 是“所有 token 共用同一套前馈参数”
- MoE 是“不同 token 只激活一小部分 expert 参数”
因此MoE 在大模型中的价值主要体现在:
- 提升模型容量
- 控制单 token 计算成本
- 让不同 token 能使用更有针对性的 expert 子网络
从运行时角度看MoE 算子负责的就是这一层稀疏前馈计算。它的输入通常是 attention 之后的 hidden states输出则是经过 router 选择 expert、完成两层 expert MLP 并聚合后的新 hidden states。
更具体地说,这个算子主要做三件事:
1. 根据 router 给出的 `topk_ids``topk_weights`,确定每个 token 应该送到哪些 expert
2. 对被选中的 expert 执行前馈计算
3. 把多个 expert 的输出按路由权重聚合回 token 维度
因此从模型结构上看MoE 算子本质上是在实现:
```text
token -> route to experts -> expert MLP -> weighted combine
```
`fused_moe.py` 做的不是重新定义这个算子,而是为这一层 MoE 前向提供一个更高性能的实现,重点优化:
- token 到 expert 的重排
- expert GEMM 的执行方式
- 量化路径
- 最终聚合过程
### 1.1 MoE 的数学原理
![MoE Principle](./fused_moe_image/moe.png)
结合上图,标准 MoE 层的计算过程可以形式化为“路由选择 -> 分发到专家 -> 专家前馈计算 -> 加权聚合输出”四个阶段。
设第 $t$ 个 token 的隐藏状态表示为:
```math
x_t \in \mathbb{R}^{d}
```
其中 $d$ 为隐藏维度。MoE 层包含 $E$ 个 expert记第 $e$ 个 expert 对应的前馈函数为 $f_e$。
第一步是路由计算。Router 对输入 $x_t$ 生成对所有 expert 的路由打分。最常见的 router 形式是一个线性映射:
```math
s_t = W_g^{\top} x_t + b_g
```
其中 $W_g \in \mathbb{R}^{d \times E}$、$b_g \in \mathbb{R}^{E}$,因此 router 输出的打分向量满足:
```math
s_t = \operatorname{Router}(x_t) \in \mathbb{R}^{E}
```
通常可进一步写成 softmax 后的路由概率:
```math
p_t = \operatorname{softmax}(s_t), \qquad p_t \in \mathbb{R}^{E}
```
第二步是 Top-k 选择。Router 选择概率最高的 $k$ 个 expert记其索引集合为 $S_t$
```math
\mathcal{S}_t = \operatorname{TopK}(p_t, k)
```
对每个被选中的 expert $e \in S_t$,记其对应的路由权重为 $\alpha_{t,e}$。这一步对应图中的 `Top-k Selection`
第三步是 dispatch。根据集合 $S_t$token $x_t$ 会被逻辑上分发到这些被选中的 expert未被选中的 expert 不参与当前 token 的计算。
第四步是 expert 计算。若每个 expert 采用图中所示的两层 MLP 结构 `Linear1 -> Activation(SiLU) -> Linear2`,则第 $e$ 个 expert 的输出可以写为:
```math
f_e(x) = W_{2,e}^{\top}\,\mathrm{SiLU}(W_{1,e}^{\top}x + b_{1,e}) + b_{2,e}
```
其中:
- $W_{1,e}$、$W_{2,e}$ 分别为第 $e$ 个 expert 的两层线性变换参数
- $b_{1,e}$、$b_{2,e}$ 为对应偏置
- $\mathrm{SiLU}(\cdot)$ 为图中 expert 内部使用的激活函数
最后一步是聚合输出。对所有被选中的 expert 输出按对应路由权重进行加权求和,得到当前 token 的最终输出:
```math
y_t = \sum_{e \in \mathcal{S}_t} \alpha_{t,e}\, f_e(x_t)
```
因此,上图中的完整计算链条可以概括为:
```math
x_t
\xrightarrow{\text{Router}}
s_t
\xrightarrow{\text{Softmax}}
p_t
\xrightarrow{\text{TopK}}
\mathcal{S}_t
\xrightarrow{\text{Dispatch}}
\{f_e(x_t)\}_{e \in \mathcal{S}_t}
\xrightarrow{\text{Weighted Sum}}
y_t
```
也就是说MoE 的本质是:先由 router 从全部 expert 中选择一小部分参与当前 token 的计算,再将这些被选中 expert 的输出按路由权重加权求和,从而得到最终输出 $y_t$。
### 1.2 fused_moe 的数学原理
把每个 `(token, topk-slot)` 看作一个 routed item $r$,则:
- 对应 token 为 $t(r)$
- 对应 expert 为 $e(r)$
- 对应路由权重为 $a(r)$
那么 `fused_moe` 实际仍然是在算同一个数学对象。
第一层:
```math
h_r = W_{1,e(r)}^{\top}x_{t(r)} + b_{1,e(r)}
```
激活后:
```math
z_r = \phi(h_r)
```
第二层:
```math
o_r = W_{2,e(r)}^{\top}z_r + b_{2,e(r)}
```
最后对同一个 token 的 routed 输出求和:
```math
y_t = \sum_{r: t(r)=t} a(r)\,o_r
```
代码中有一个开关 `apply_router_weight_on_input`,决定把 router weight 乘在第一层输出之后还是第二层输出之后。由于 `a(r)` 是标量,线性层满足:
```math
W^{\top}(a z) = a\,W^{\top}z
```
因此两种写法数学上等价。
### 1.3 fused_moe 的计算原理
`fused_moe` 和普通 MoE 在数学上等价,主要差别在工程组织方式。
普通 MoE 更像是按逻辑步骤分开执行:
1. router 选 top-k experts
2. dispatch / gather token
3. 对每个 expert 分别做 `w1`
4. 激活
5. 对每个 expert 分别做 `w2`
6. 乘 router weight 并聚合
`fused_moe` 的核心思想是:
1. 先把 routed token 按 expert 重排
2. 把每个 expert 的 token 数量按 `BLOCK_SIZE_M` 补齐
3. 让一个 Triton program 处理“同一 expert 的一块 token × 一块输出列”
4. 用统一 kernel 两次完成 `w1``w2`
5. 在 kernel 内尽量融合量化、bias、router weight 等附加逻辑
从 routed token 视角看kernel 做的是:
```math
C[r, n] = \sum_{k=1}^{K} A[t(r), k] \cdot B[e(r), n, k]
```
其中:
- $t(r)$ 是 routed item $r$ 对应的原始 token
- $e(r)$ 是 routed item $r$ 对应的 expert
- $n$ 是输出通道
在实现上,这套组织方式依赖三个关键步骤:
- `moe_align_block_size(...)`
把 routed token 按 expert 排列成规则 block
- `fused_moe_kernel(...)`
按 block 执行 expert GEMM
- `moe_sum(...)`
对每个 token 的 top-k 输出做最终求和
### 1.4 和普通 MoE 实现的区别
从数学上看,`fused_moe` 和普通 MoE 没有本质差别,仍然是在做:
```math
y_t = \sum_{i=1}^{k} \alpha_{t,i}\, f_{e_{t,i}}(x_t)
```
差别主要在工程组织方式:
- 普通 MoE 更像是按逻辑步骤分开执行
- dispatch
- expert matmul
- activation
- expert matmul
- combine
- fused MoE 则更强调:
- 按 expert 重排 token
- 让一个 Triton program 处理规则 tile
- 把量化、bias、router weight 等附加逻辑尽量揉进 kernel
这样做的收益是:
- 减少 kernel launch
- 改善权重访问局部性
- 提高大规模 expert 场景下的吞吐
## 文件定位
`fused_moe.py` 是一个面向 Triton 的 MoE 前向实现文件,核心职责包括:
- 选择或生成 MoE kernel 配置
- 对输入激活做量化预处理
- 启动 Triton MoE kernel
- 执行两层 expert MLP
- 在 top-k expert 输出之间做聚合
它依赖两个关键辅助模块:
- `flag_gems.fused.moe_align_block_size`
负责把 routed token 按 expert 重排并按 block 对齐
- `flag_gems.fused.moe_sum`
负责把每个 token 的 top-k expert 输出求和
## 对外入口
对外可直接调用的入口有两个:
- `inplace_fused_experts(...)`
- `outplace_fused_experts(...)`
两者都只是薄封装,最终都调用:
- `fused_experts_impl(...)`
其中:
- `inplace_fused_experts` 把结果直接写回 `hidden_states`
- `outplace_fused_experts` 返回新分配的输出张量
## 文件内主要函数分组
### 1. 配置选择
- `get_embedded_moe_configs()`
`fused_moe_config.yaml` 读取内嵌调优配置
- `_get_device_name()`
获取并规范化设备名
- `get_moe_configs(...)`
按设备名、expert 数、输出维度、dtype、block shape 查询配置
- `try_get_optimal_moe_config(...)`
优先使用内嵌配置,否则回退到启发式配置
- `get_default_config(...)`
默认配置生成逻辑
- `get_moe_wna16_block_config(...)`
针对 WNA16 路径生成 block 配置
- `_ensure_block_size_k_divisible(...)`
修正 `BLOCK_SIZE_K`
### 2. 数据类型与量化辅助
- `_get_config_dtype_str(...)`
把当前 dtype 和量化模式映射成配置查表 key
- `_get_config_quant_dtype(...)`
把量化模式映射成激活量化 dtype
- `_fp8_quantize(...)`
激活 FP8 量化
- `_int8_quantize(...)`
激活 INT8 量化
- `moe_kernel_quantize_input(...)`
在进入 GEMM 前对输入激活做量化
- `dequant_mxfp4(...)`
- `dequant_mxfp6(...)`
### 3. 激活函数
- `MoEActivation`
定义支持的激活枚举
- `apply_moe_activation(...)`
执行激活
- `_silu_and_mul_kernel(...)`
Triton 版 SiLU-and-mul
### 4. Triton kernel
- `write_zeros_to_output(...)`
expert 不在当前 rank 时写零
- `fused_moe_kernel(...)`
通用 MoE kernel覆盖普通浮点、FP8、INT8 路径
- `fused_moe_kernel_gptq_awq(...)`
WNA16 专用 kernel面向 GPTQ/AWQ 风格的量化权重
### 5. kernel 启动与分发
- `invoke_fused_moe_triton_kernel(...)`
启动通用 kernel
- `invoke_fused_moe_wna16_triton_kernel(...)`
启动 WNA16 kernel
- `dispatch_fused_moe_kernel(...)`
根据量化模式分发到具体启动函数
## 主调用路径
主调用链可以概括为:
```text
inplace_fused_experts / outplace_fused_experts
-> fused_experts_impl
-> _get_config_dtype_str
-> _get_config_quant_dtype
-> try_get_optimal_moe_config
-> chunk loop
-> moe_kernel_quantize_input (for w1 input)
-> moe_align_block_size 或 naive assignment
-> dispatch_fused_moe_kernel (for w1)
-> invoke_fused_moe_triton_kernel
-> fused_moe_kernel
-> apply_moe_activation
-> moe_kernel_quantize_input (for w2 input)
-> dispatch_fused_moe_kernel (for w2)
-> invoke_fused_moe_triton_kernel
-> fused_moe_kernel
-> moe_sum
```
如果走 WNA16 特化路径,则中间会改为:
```text
dispatch_fused_moe_kernel
-> invoke_fused_moe_wna16_triton_kernel
-> fused_moe_kernel_gptq_awq
```
## fused_experts_impl 的执行流程
`fused_experts_impl(...)` 是整个文件的调度核心。
### 1. 参数校验
它会检查:
- `hidden_states``w1``w2` 的 shape 是否匹配
- `topk_weights``topk_ids` 形状是否一致
- `hidden_states` 是否连续
- `w1/w2` 最后一维 stride 是否为 1
- 输入 dtype 是否属于:
- `torch.float32`
- `torch.float16`
- `torch.bfloat16`
当前实现还显式限制:
- `activation == "silu"`
虽然 `MoEActivation` 定义了多种激活,但当前入口只允许 SiLU 路径。
### 2. 配置选择
它先生成两个关键变量:
- `config_dtype`
用于 kernel 配置查表
- `quant_dtype`
用于激活量化逻辑
之后用 `try_get_optimal_moe_config(...)` 选择当前 chunk 对应的 Triton 配置。
### 3. 中间缓存分配
会分配三块中间缓存:
- `intermediate_cache1`
保存 `w1` 输出shape 为 `[M, top_k, N]`
- `intermediate_cache2`
保存激活后的结果shape 为 `[M * top_k, activation_out_dim]`
- `intermediate_cache3`
保存 `w2` 输出shape 为 `[M, top_k, K]`
其中 `cache13` 复用了 `cache1``cache3` 的存储,因为两者生命周期不重叠。
### 4. 计算 dtype 选择
输入 dtype 与 kernel 中的 `compute_type` 对应关系:
- `torch.bfloat16 -> tl.bfloat16`
- `torch.float16 -> tl.float16`
- `torch.float32 -> tl.float32`
### 5. 特殊量化格式处理
如果启用了 `ocp_mx_scheme`,会先把 MX 权重反量化成普通浮点权重,再走后续通用路径。
对于 `use_int8_w8a16``use_int4_w4a16`,当前公开入口中也会先把权重反量化成 `hidden_states.dtype`,随后把:
- `use_int8_w8a16 = False`
- `use_int4_w4a16 = False`
这意味着虽然文件内保留了 WNA16 特化 kernel但从 `fused_experts_impl(...)` 这条入口实际运行时,常见路径仍然是通用 `fused_moe_kernel(...)`
### 6. 分 chunk 执行
实现按 `CHUNK_SIZE = 16 * 1024` 分块处理 token原因是
- 限制中间缓存大小
- 让 kernel 配置更容易适配当前 chunk 规模
每个 chunk 内部执行:
1. 量化第一层输入
2. 把 routed token 按 expert 重排
3.`w1`
4. 激活
5. 量化第二层输入
6.`w2`
7. 对 top-k expert 输出做求和
## moe_align_block_size 的作用
MoE 的普通数学形式里,每个 token 会被路由到若干 expert。直接按 token 顺序计算时expert 权重访问会非常离散。
`moe_align_block_size(...)` 的作用是:
- 把所有 `(token, topk-slot)` 展平为 routed token
- 按 expert 把 routed token 排序
- 对每个 expert 的 token 数量按 `BLOCK_SIZE_M` 向上补齐
它输出三个核心量:
- `sorted_token_ids`
重排后的 routed token 下标
- `expert_ids`
每个 M 方向 block 对应的 expert id
- `num_tokens_post_padded`
补齐后的 routed token 总数
这样后续 Triton kernel 就可以让一个 program 只处理:
- 一块 routed token
- 一个固定 expert
- 一块输出列
这就是 fused MoE 能把很多零散小计算组织成规则 block GEMM 的关键。
## fused_moe_kernel 的核心原理
`fused_moe_kernel(...)` 是通用 kernel也是阅读此文件时最值得重点关注的函数。
### 1. program 到 tile 的映射
kernel 先根据:
- `BLOCK_SIZE_M`
- `BLOCK_SIZE_N`
- `GROUP_SIZE_M`
`pid` 映射到:
- `pid_m`
- `pid_n`
这里使用 grouped ordering 来改善 L2 reuse。
### 2. 当前 block 对应哪些 token 和 expert
对每个 `pid_m`
- 如果启用了重排,先从 `sorted_token_ids` 取出当前 block 的 routed token
-`expert_ids[pid_m]` 取出当前 block 对应的 expert
如果 `expert_id == -1`,则说明该 expert 不在当前 expert parallel rank直接调用 `write_zeros_to_output(...)`
### 3. 构造 A / B 指针
输入 A 的访问方式为:
```text
offs_token // top_k
```
因为 routed token 是 `(token, slot)` 的展平形式,所以要除以 `top_k` 才能回到原始 token 下标。
权重 B 则按照:
- 当前 expert
- 当前 K block
- 当前 N block
来取对应子块。
### 4. 主循环
沿 K 维做 block 累加:
- 读取 A 子块
- 读取 B 子块
- 走普通浮点或量化分支
- 累加到 `accumulator`
`accumulator` 始终先用 `fp32` 保存,再在输出前转成 `compute_type`
### 5. 后处理
主循环结束后依次执行:
- 可选 dequant
- 可选 bias
- 可选乘 router weight
- cast 到 `compute_type`
- 写回 C
## 普通浮点与 FP8 路径
### 普通浮点
当:
- `use_fp8_w8a8 == False`
- `use_int8_w8a8 == False`
- `A_scale is None`
- `B_scale is None`
kernel 直接执行普通浮点路径:
```text
accumulator += tl.dot(a, b)
```
这是 `bfloat16``float16``float32` 共享的主要逻辑。
### FP8 W8A8
当启用 `use_fp8_w8a8` 时:
- Python 侧先调用 `moe_kernel_quantize_input(...)`
- `moe_kernel_quantize_input(...)` 会进一步调用 `_fp8_quantize(...)`
- kernel 内根据量化粒度决定如何加载 scale
支持的粒度包括:
- tensor-wise
- per-channel
- block-wise
block-wise 路径下,会在每个 K block 内读取本 block 对应的:
- `a_scale`
- `b_scale`
然后执行近似:
```text
C_block ~= dot(qA_block, qB_block) * a_scale * b_scale
```
### bfloat16 支持
文件明确支持:
- `torch.bfloat16`
其执行方式不是把整个 kernel 改成 BF16 累加,而是:
- 输入 / 输出使用 BF16
- `accumulator` 先用 FP32
- 最后再 cast 回 `tl.bfloat16`
## WNA16 路径
文件内还保留了一条 WNA16 特化路径:
- `invoke_fused_moe_wna16_triton_kernel(...)`
- `fused_moe_kernel_gptq_awq(...)`
它主要用于:
- `int8_w8a16`
- `int4_w4a16`
并支持:
- group scale
- zero point
- GPTQ/AWQ 风格权重布局
但需要注意:
- 当前 `fused_experts_impl(...)` 中会先把 INT8/INT4 权重反量化成普通浮点
- 因此从当前公开入口来看,这条分支通常不会成为主路径
也就是说,代码结构上支持 WNA16 专用 kernel但入口行为上更偏向统一走通用 kernel。
## 调用图
```text
inplace_fused_experts
-> fused_experts_impl
outplace_fused_experts
-> fused_experts_impl
fused_experts_impl
-> _get_config_dtype_str
-> _get_config_quant_dtype
-> try_get_optimal_moe_config
-> get_moe_configs
-> _get_device_name
-> get_embedded_moe_configs
-> get_default_config
-> moe_kernel_quantize_input
-> _fp8_quantize / _int8_quantize
-> moe_align_block_size
-> dispatch_fused_moe_kernel
-> invoke_fused_moe_triton_kernel
-> fused_moe_kernel
-> write_zeros_to_output
-> invoke_fused_moe_wna16_triton_kernel
-> get_moe_wna16_block_config
-> _ensure_block_size_k_divisible
-> fused_moe_kernel_gptq_awq
-> write_zeros_to_output
-> apply_moe_activation
-> _silu_and_mul_kernel
-> moe_sum
```
## 当前实现的几个注意点
- 当前入口只允许 `activation == "silu"`
- 明确支持输入 dtype
- `float32`
- `float16`
- `bfloat16`
- 通用主路径是 `fused_moe_kernel(...)`
- 文件内虽然保留了 WNA16 特化 kernel但当前公开入口通常会先把 INT8/INT4 权重反量化,实际常走通用 kernel
- `moe_align_block_size(...)` 是性能关键,它决定了 routed token 是否能被整理成规则 block
## 小结
`fused_moe.py` 的核心思想可以概括为三点:
1. 把 routed token 按 expert 排序并按 block 对齐
2. 用统一的 Triton kernel 两次完成 expert 的两层 GEMM
3. 在 kernel 内尽量融合量化、bias、router weight 等附加逻辑
从源码阅读角度,推荐优先关注以下函数:
- `fused_experts_impl(...)`
- `moe_kernel_quantize_input(...)`
- `dispatch_fused_moe_kernel(...)`
- `fused_moe_kernel(...)`
- `moe_align_block_size(...)`
这几处连起来,基本就构成了整个 fused MoE 前向的主干。