18 KiB
fused_moe.py 分析
本文整理 src/flag_gems/fused/fused_moe.py 的实现结构、调用路径、数据类型分支和计算原理,方便从源码角度理解 FlagGems 的 fused MoE 前向。
第一章:MoE 与 fused MoE 的数学原理与计算原理
1.0 MoE 在大模型中的背景
MoE,Mixture of Experts,是大模型中一种常见的稀疏前馈结构。它的核心思想是:
- 不再让每个 token 都经过同一组固定的前馈层
- 而是准备多个 expert 子网络
- 再由 router 为每个 token 动态选择少量 expert 参与计算
在 Transformer 大模型中,MoE 通常替代原来的 dense FFN 层。也就是说,在注意力层之后,原本是:
Attention -> Dense FFN
引入 MoE 后会变成:
Attention -> Router + Experts
这样做的目的,是在尽量不线性增加单 token 计算量的前提下,大幅增加模型总参数量和表示能力。
可以把它理解为:
- dense FFN 是“所有 token 共用同一套前馈参数”
- MoE 是“不同 token 只激活一小部分 expert 参数”
因此,MoE 在大模型中的价值主要体现在:
- 提升模型容量
- 控制单 token 计算成本
- 让不同 token 能使用更有针对性的 expert 子网络
从运行时角度看,MoE 算子负责的就是这一层稀疏前馈计算。它的输入通常是 attention 之后的 hidden states,输出则是经过 router 选择 expert、完成两层 expert MLP 并聚合后的新 hidden states。
更具体地说,这个算子主要做三件事:
- 根据 router 给出的
topk_ids和topk_weights,确定每个 token 应该送到哪些 expert - 对被选中的 expert 执行前馈计算
- 把多个 expert 的输出按路由权重聚合回 token 维度
因此,从模型结构上看,MoE 算子本质上是在实现:
token -> route to experts -> expert MLP -> weighted combine
而 fused_moe.py 做的不是重新定义这个算子,而是为这一层 MoE 前向提供一个更高性能的实现,重点优化:
- token 到 expert 的重排
- expert GEMM 的执行方式
- 量化路径
- 最终聚合过程
1.1 MoE 的数学原理
结合上图,标准 MoE 层的计算过程可以形式化为“路由选择 -> 分发到专家 -> 专家前馈计算 -> 加权聚合输出”四个阶段。
设第 t 个 token 的隐藏状态表示为:
x_t \in \mathbb{R}^{d}
其中 d 为隐藏维度。MoE 层包含 E 个 expert,记第 e 个 expert 对应的前馈函数为 $f_e$。
第一步是路由计算。Router 对输入 x_t 生成对所有 expert 的路由打分。最常见的 router 形式是一个线性映射:
s_t = W_g^{\top} x_t + b_g
其中 $W_g \in \mathbb{R}^{d \times E}$、$b_g \in \mathbb{R}^{E}$,因此 router 输出的打分向量满足:
s_t = \operatorname{Router}(x_t) \in \mathbb{R}^{E}
通常可进一步写成 softmax 后的路由概率:
p_t = \operatorname{softmax}(s_t), \qquad p_t \in \mathbb{R}^{E}
第二步是 Top-k 选择。Router 选择概率最高的 k 个 expert,记其索引集合为 $S_t$:
\mathcal{S}_t = \operatorname{TopK}(p_t, k)
对每个被选中的 expert $e \in S_t$,记其对应的路由权重为 $\alpha_{t,e}$。这一步对应图中的 Top-k Selection。
第三步是 dispatch。根据集合 $S_t$,token x_t 会被逻辑上分发到这些被选中的 expert;未被选中的 expert 不参与当前 token 的计算。
第四步是 expert 计算。若每个 expert 采用图中所示的两层 MLP 结构 Linear1 -> Activation(SiLU) -> Linear2,则第 e 个 expert 的输出可以写为:
f_e(x) = W_{2,e}^{\top}\,\mathrm{SiLU}(W_{1,e}^{\top}x + b_{1,e}) + b_{2,e}
其中:
- $W_{1,e}$、
W_{2,e}分别为第e个 expert 的两层线性变换参数 - $b_{1,e}$、
b_{2,e}为对应偏置 \mathrm{SiLU}(\cdot)为图中 expert 内部使用的激活函数
最后一步是聚合输出。对所有被选中的 expert 输出按对应路由权重进行加权求和,得到当前 token 的最终输出:
y_t = \sum_{e \in \mathcal{S}_t} \alpha_{t,e}\, f_e(x_t)
因此,上图中的完整计算链条可以概括为:
x_t
\xrightarrow{\text{Router}}
s_t
\xrightarrow{\text{Softmax}}
p_t
\xrightarrow{\text{TopK}}
\mathcal{S}_t
\xrightarrow{\text{Dispatch}}
\{f_e(x_t)\}_{e \in \mathcal{S}_t}
\xrightarrow{\text{Weighted Sum}}
y_t
也就是说,MoE 的本质是:先由 router 从全部 expert 中选择一小部分参与当前 token 的计算,再将这些被选中 expert 的输出按路由权重加权求和,从而得到最终输出 $y_t$。
1.2 fused_moe 的数学原理
把每个 (token, topk-slot) 看作一个 routed item $r$,则:
- 对应 token 为
t(r) - 对应 expert 为
e(r) - 对应路由权重为
a(r)
那么 fused_moe 实际仍然是在算同一个数学对象。
第一层:
h_r = W_{1,e(r)}^{\top}x_{t(r)} + b_{1,e(r)}
激活后:
z_r = \phi(h_r)
第二层:
o_r = W_{2,e(r)}^{\top}z_r + b_{2,e(r)}
最后对同一个 token 的 routed 输出求和:
y_t = \sum_{r: t(r)=t} a(r)\,o_r
代码中有一个开关 apply_router_weight_on_input,决定把 router weight 乘在第一层输出之后还是第二层输出之后。由于 a(r) 是标量,线性层满足:
W^{\top}(a z) = a\,W^{\top}z
因此两种写法数学上等价。
1.3 fused_moe 的计算原理
fused_moe 和普通 MoE 在数学上等价,主要差别在工程组织方式。
普通 MoE 更像是按逻辑步骤分开执行:
- router 选 top-k experts
- dispatch / gather token
- 对每个 expert 分别做
w1 - 激活
- 对每个 expert 分别做
w2 - 乘 router weight 并聚合
而 fused_moe 的核心思想是:
- 先把 routed token 按 expert 重排
- 把每个 expert 的 token 数量按
BLOCK_SIZE_M补齐 - 让一个 Triton program 处理“同一 expert 的一块 token × 一块输出列”
- 用统一 kernel 两次完成
w1和w2 - 在 kernel 内尽量融合量化、bias、router weight 等附加逻辑
从 routed token 视角看,kernel 做的是:
C[r, n] = \sum_{k=1}^{K} A[t(r), k] \cdot B[e(r), n, k]
其中:
t(r)是 routed itemr对应的原始 tokene(r)是 routed itemr对应的 expertn是输出通道
在实现上,这套组织方式依赖三个关键步骤:
moe_align_block_size(...)把 routed token 按 expert 排列成规则 blockfused_moe_kernel(...)按 block 执行 expert GEMMmoe_sum(...)对每个 token 的 top-k 输出做最终求和
1.4 和普通 MoE 实现的区别
从数学上看,fused_moe 和普通 MoE 没有本质差别,仍然是在做:
y_t = \sum_{i=1}^{k} \alpha_{t,i}\, f_{e_{t,i}}(x_t)
差别主要在工程组织方式:
- 普通 MoE 更像是按逻辑步骤分开执行
- dispatch
- expert matmul
- activation
- expert matmul
- combine
- fused MoE 则更强调:
- 按 expert 重排 token
- 让一个 Triton program 处理规则 tile
- 把量化、bias、router weight 等附加逻辑尽量揉进 kernel
这样做的收益是:
- 减少 kernel launch
- 改善权重访问局部性
- 提高大规模 expert 场景下的吞吐
文件定位
fused_moe.py 是一个面向 Triton 的 MoE 前向实现文件,核心职责包括:
- 选择或生成 MoE kernel 配置
- 对输入激活做量化预处理
- 启动 Triton MoE kernel
- 执行两层 expert MLP
- 在 top-k expert 输出之间做聚合
它依赖两个关键辅助模块:
flag_gems.fused.moe_align_block_size负责把 routed token 按 expert 重排并按 block 对齐flag_gems.fused.moe_sum负责把每个 token 的 top-k expert 输出求和
对外入口
对外可直接调用的入口有两个:
inplace_fused_experts(...)outplace_fused_experts(...)
两者都只是薄封装,最终都调用:
fused_experts_impl(...)
其中:
inplace_fused_experts把结果直接写回hidden_statesoutplace_fused_experts返回新分配的输出张量
文件内主要函数分组
1. 配置选择
get_embedded_moe_configs()从fused_moe_config.yaml读取内嵌调优配置_get_device_name()获取并规范化设备名get_moe_configs(...)按设备名、expert 数、输出维度、dtype、block shape 查询配置try_get_optimal_moe_config(...)优先使用内嵌配置,否则回退到启发式配置get_default_config(...)默认配置生成逻辑get_moe_wna16_block_config(...)针对 WNA16 路径生成 block 配置_ensure_block_size_k_divisible(...)修正BLOCK_SIZE_K
2. 数据类型与量化辅助
_get_config_dtype_str(...)把当前 dtype 和量化模式映射成配置查表 key_get_config_quant_dtype(...)把量化模式映射成激活量化 dtype_fp8_quantize(...)激活 FP8 量化_int8_quantize(...)激活 INT8 量化moe_kernel_quantize_input(...)在进入 GEMM 前对输入激活做量化dequant_mxfp4(...)dequant_mxfp6(...)
3. 激活函数
MoEActivation定义支持的激活枚举apply_moe_activation(...)执行激活_silu_and_mul_kernel(...)Triton 版 SiLU-and-mul
4. Triton kernel
write_zeros_to_output(...)expert 不在当前 rank 时写零fused_moe_kernel(...)通用 MoE kernel,覆盖普通浮点、FP8、INT8 路径fused_moe_kernel_gptq_awq(...)WNA16 专用 kernel,面向 GPTQ/AWQ 风格的量化权重
5. kernel 启动与分发
invoke_fused_moe_triton_kernel(...)启动通用 kernelinvoke_fused_moe_wna16_triton_kernel(...)启动 WNA16 kerneldispatch_fused_moe_kernel(...)根据量化模式分发到具体启动函数
主调用路径
主调用链可以概括为:
inplace_fused_experts / outplace_fused_experts
-> fused_experts_impl
-> _get_config_dtype_str
-> _get_config_quant_dtype
-> try_get_optimal_moe_config
-> chunk loop
-> moe_kernel_quantize_input (for w1 input)
-> moe_align_block_size 或 naive assignment
-> dispatch_fused_moe_kernel (for w1)
-> invoke_fused_moe_triton_kernel
-> fused_moe_kernel
-> apply_moe_activation
-> moe_kernel_quantize_input (for w2 input)
-> dispatch_fused_moe_kernel (for w2)
-> invoke_fused_moe_triton_kernel
-> fused_moe_kernel
-> moe_sum
如果走 WNA16 特化路径,则中间会改为:
dispatch_fused_moe_kernel
-> invoke_fused_moe_wna16_triton_kernel
-> fused_moe_kernel_gptq_awq
fused_experts_impl 的执行流程
fused_experts_impl(...) 是整个文件的调度核心。
1. 参数校验
它会检查:
hidden_states、w1、w2的 shape 是否匹配topk_weights和topk_ids形状是否一致hidden_states是否连续w1/w2最后一维 stride 是否为 1- 输入 dtype 是否属于:
torch.float32torch.float16torch.bfloat16
当前实现还显式限制:
activation == "silu"
虽然 MoEActivation 定义了多种激活,但当前入口只允许 SiLU 路径。
2. 配置选择
它先生成两个关键变量:
config_dtype用于 kernel 配置查表quant_dtype用于激活量化逻辑
之后用 try_get_optimal_moe_config(...) 选择当前 chunk 对应的 Triton 配置。
3. 中间缓存分配
会分配三块中间缓存:
intermediate_cache1保存w1输出,shape 为[M, top_k, N]intermediate_cache2保存激活后的结果,shape 为[M * top_k, activation_out_dim]intermediate_cache3保存w2输出,shape 为[M, top_k, K]
其中 cache13 复用了 cache1 和 cache3 的存储,因为两者生命周期不重叠。
4. 计算 dtype 选择
输入 dtype 与 kernel 中的 compute_type 对应关系:
torch.bfloat16 -> tl.bfloat16torch.float16 -> tl.float16torch.float32 -> tl.float32
5. 特殊量化格式处理
如果启用了 ocp_mx_scheme,会先把 MX 权重反量化成普通浮点权重,再走后续通用路径。
对于 use_int8_w8a16 和 use_int4_w4a16,当前公开入口中也会先把权重反量化成 hidden_states.dtype,随后把:
use_int8_w8a16 = Falseuse_int4_w4a16 = False
这意味着虽然文件内保留了 WNA16 特化 kernel,但从 fused_experts_impl(...) 这条入口实际运行时,常见路径仍然是通用 fused_moe_kernel(...)。
6. 分 chunk 执行
实现按 CHUNK_SIZE = 16 * 1024 分块处理 token,原因是:
- 限制中间缓存大小
- 让 kernel 配置更容易适配当前 chunk 规模
每个 chunk 内部执行:
- 量化第一层输入
- 把 routed token 按 expert 重排
- 跑
w1 - 激活
- 量化第二层输入
- 跑
w2 - 对 top-k expert 输出做求和
moe_align_block_size 的作用
MoE 的普通数学形式里,每个 token 会被路由到若干 expert。直接按 token 顺序计算时,expert 权重访问会非常离散。
moe_align_block_size(...) 的作用是:
- 把所有
(token, topk-slot)展平为 routed token - 按 expert 把 routed token 排序
- 对每个 expert 的 token 数量按
BLOCK_SIZE_M向上补齐
它输出三个核心量:
sorted_token_ids重排后的 routed token 下标expert_ids每个 M 方向 block 对应的 expert idnum_tokens_post_padded补齐后的 routed token 总数
这样后续 Triton kernel 就可以让一个 program 只处理:
- 一块 routed token
- 一个固定 expert
- 一块输出列
这就是 fused MoE 能把很多零散小计算组织成规则 block GEMM 的关键。
fused_moe_kernel 的核心原理
fused_moe_kernel(...) 是通用 kernel,也是阅读此文件时最值得重点关注的函数。
1. program 到 tile 的映射
kernel 先根据:
BLOCK_SIZE_MBLOCK_SIZE_NGROUP_SIZE_M
把 pid 映射到:
pid_mpid_n
这里使用 grouped ordering 来改善 L2 reuse。
2. 当前 block 对应哪些 token 和 expert
对每个 pid_m:
- 如果启用了重排,先从
sorted_token_ids取出当前 block 的 routed token - 从
expert_ids[pid_m]取出当前 block 对应的 expert
如果 expert_id == -1,则说明该 expert 不在当前 expert parallel rank,直接调用 write_zeros_to_output(...)。
3. 构造 A / B 指针
输入 A 的访问方式为:
offs_token // top_k
因为 routed token 是 (token, slot) 的展平形式,所以要除以 top_k 才能回到原始 token 下标。
权重 B 则按照:
- 当前 expert
- 当前 K block
- 当前 N block
来取对应子块。
4. 主循环
沿 K 维做 block 累加:
- 读取 A 子块
- 读取 B 子块
- 走普通浮点或量化分支
- 累加到
accumulator
accumulator 始终先用 fp32 保存,再在输出前转成 compute_type。
5. 后处理
主循环结束后依次执行:
- 可选 dequant
- 可选 bias
- 可选乘 router weight
- cast 到
compute_type - 写回 C
普通浮点与 FP8 路径
普通浮点
当:
use_fp8_w8a8 == Falseuse_int8_w8a8 == FalseA_scale is NoneB_scale is None
时,kernel 直接执行普通浮点路径:
accumulator += tl.dot(a, b)
这是 bfloat16、float16、float32 共享的主要逻辑。
FP8 W8A8
当启用 use_fp8_w8a8 时:
- Python 侧先调用
moe_kernel_quantize_input(...) moe_kernel_quantize_input(...)会进一步调用_fp8_quantize(...)- kernel 内根据量化粒度决定如何加载 scale
支持的粒度包括:
- tensor-wise
- per-channel
- block-wise
block-wise 路径下,会在每个 K block 内读取本 block 对应的:
a_scaleb_scale
然后执行近似:
C_block ~= dot(qA_block, qB_block) * a_scale * b_scale
bfloat16 支持
文件明确支持:
torch.bfloat16
其执行方式不是把整个 kernel 改成 BF16 累加,而是:
- 输入 / 输出使用 BF16
accumulator先用 FP32- 最后再 cast 回
tl.bfloat16
WNA16 路径
文件内还保留了一条 WNA16 特化路径:
invoke_fused_moe_wna16_triton_kernel(...)fused_moe_kernel_gptq_awq(...)
它主要用于:
int8_w8a16int4_w4a16
并支持:
- group scale
- zero point
- GPTQ/AWQ 风格权重布局
但需要注意:
- 当前
fused_experts_impl(...)中会先把 INT8/INT4 权重反量化成普通浮点 - 因此从当前公开入口来看,这条分支通常不会成为主路径
也就是说,代码结构上支持 WNA16 专用 kernel,但入口行为上更偏向统一走通用 kernel。
调用图
inplace_fused_experts
-> fused_experts_impl
outplace_fused_experts
-> fused_experts_impl
fused_experts_impl
-> _get_config_dtype_str
-> _get_config_quant_dtype
-> try_get_optimal_moe_config
-> get_moe_configs
-> _get_device_name
-> get_embedded_moe_configs
-> get_default_config
-> moe_kernel_quantize_input
-> _fp8_quantize / _int8_quantize
-> moe_align_block_size
-> dispatch_fused_moe_kernel
-> invoke_fused_moe_triton_kernel
-> fused_moe_kernel
-> write_zeros_to_output
-> invoke_fused_moe_wna16_triton_kernel
-> get_moe_wna16_block_config
-> _ensure_block_size_k_divisible
-> fused_moe_kernel_gptq_awq
-> write_zeros_to_output
-> apply_moe_activation
-> _silu_and_mul_kernel
-> moe_sum
当前实现的几个注意点
- 当前入口只允许
activation == "silu" - 明确支持输入 dtype:
float32float16bfloat16
- 通用主路径是
fused_moe_kernel(...) - 文件内虽然保留了 WNA16 特化 kernel,但当前公开入口通常会先把 INT8/INT4 权重反量化,实际常走通用 kernel
moe_align_block_size(...)是性能关键,它决定了 routed token 是否能被整理成规则 block
小结
fused_moe.py 的核心思想可以概括为三点:
- 把 routed token 按 expert 排序并按 block 对齐
- 用统一的 Triton kernel 两次完成 expert 的两层 GEMM
- 在 kernel 内尽量融合量化、bias、router weight 等附加逻辑
从源码阅读角度,推荐优先关注以下函数:
fused_experts_impl(...)moe_kernel_quantize_input(...)dispatch_fused_moe_kernel(...)fused_moe_kernel(...)moe_align_block_size(...)
这几处连起来,基本就构成了整个 fused MoE 前向的主干。
