Files
FlagDoc/flaggems/fused_moe.md
2026-04-23 16:36:50 +08:00

18 KiB
Raw Permalink Blame History

fused_moe.py 分析

本文整理 src/flag_gems/fused/fused_moe.py 的实现结构、调用路径、数据类型分支和计算原理,方便从源码角度理解 FlagGems 的 fused MoE 前向。

第一章MoE 与 fused MoE 的数学原理与计算原理

1.0 MoE 在大模型中的背景

MoEMixture of Experts是大模型中一种常见的稀疏前馈结构。它的核心思想是

  • 不再让每个 token 都经过同一组固定的前馈层
  • 而是准备多个 expert 子网络
  • 再由 router 为每个 token 动态选择少量 expert 参与计算

在 Transformer 大模型中MoE 通常替代原来的 dense FFN 层。也就是说,在注意力层之后,原本是:

Attention -> Dense FFN

引入 MoE 后会变成:

Attention -> Router + Experts

这样做的目的,是在尽量不线性增加单 token 计算量的前提下,大幅增加模型总参数量和表示能力。

可以把它理解为:

  • dense FFN 是“所有 token 共用同一套前馈参数”
  • MoE 是“不同 token 只激活一小部分 expert 参数”

因此MoE 在大模型中的价值主要体现在:

  • 提升模型容量
  • 控制单 token 计算成本
  • 让不同 token 能使用更有针对性的 expert 子网络

从运行时角度看MoE 算子负责的就是这一层稀疏前馈计算。它的输入通常是 attention 之后的 hidden states输出则是经过 router 选择 expert、完成两层 expert MLP 并聚合后的新 hidden states。

更具体地说,这个算子主要做三件事:

  1. 根据 router 给出的 topk_idstopk_weights,确定每个 token 应该送到哪些 expert
  2. 对被选中的 expert 执行前馈计算
  3. 把多个 expert 的输出按路由权重聚合回 token 维度

因此从模型结构上看MoE 算子本质上是在实现:

token -> route to experts -> expert MLP -> weighted combine

fused_moe.py 做的不是重新定义这个算子,而是为这一层 MoE 前向提供一个更高性能的实现,重点优化:

  • token 到 expert 的重排
  • expert GEMM 的执行方式
  • 量化路径
  • 最终聚合过程

1.1 MoE 的数学原理

MoE Principle

结合上图,标准 MoE 层的计算过程可以形式化为“路由选择 -> 分发到专家 -> 专家前馈计算 -> 加权聚合输出”四个阶段。

设第 t 个 token 的隐藏状态表示为:

x_t \in \mathbb{R}^{d}

其中 d 为隐藏维度。MoE 层包含 E 个 expert记第 e 个 expert 对应的前馈函数为 $f_e$。

第一步是路由计算。Router 对输入 x_t 生成对所有 expert 的路由打分。最常见的 router 形式是一个线性映射:

s_t = W_g^{\top} x_t + b_g

其中 $W_g \in \mathbb{R}^{d \times E}$、$b_g \in \mathbb{R}^{E}$,因此 router 输出的打分向量满足:

s_t = \operatorname{Router}(x_t) \in \mathbb{R}^{E}

通常可进一步写成 softmax 后的路由概率:

p_t = \operatorname{softmax}(s_t), \qquad p_t \in \mathbb{R}^{E}

第二步是 Top-k 选择。Router 选择概率最高的 k 个 expert记其索引集合为 $S_t$

\mathcal{S}_t = \operatorname{TopK}(p_t, k)

对每个被选中的 expert $e \in S_t$,记其对应的路由权重为 $\alpha_{t,e}$。这一步对应图中的 Top-k Selection

第三步是 dispatch。根据集合 $S_t$token x_t 会被逻辑上分发到这些被选中的 expert未被选中的 expert 不参与当前 token 的计算。

第四步是 expert 计算。若每个 expert 采用图中所示的两层 MLP 结构 Linear1 -> Activation(SiLU) -> Linear2,则第 e 个 expert 的输出可以写为:

f_e(x) = W_{2,e}^{\top}\,\mathrm{SiLU}(W_{1,e}^{\top}x + b_{1,e}) + b_{2,e}

其中:

  • $W_{1,e}$、W_{2,e} 分别为第 e 个 expert 的两层线性变换参数
  • $b_{1,e}$、b_{2,e} 为对应偏置
  • \mathrm{SiLU}(\cdot) 为图中 expert 内部使用的激活函数

最后一步是聚合输出。对所有被选中的 expert 输出按对应路由权重进行加权求和,得到当前 token 的最终输出:

y_t = \sum_{e \in \mathcal{S}_t} \alpha_{t,e}\, f_e(x_t)

因此,上图中的完整计算链条可以概括为:

x_t
\xrightarrow{\text{Router}}
s_t
\xrightarrow{\text{Softmax}}
p_t
\xrightarrow{\text{TopK}}
\mathcal{S}_t
\xrightarrow{\text{Dispatch}}
\{f_e(x_t)\}_{e \in \mathcal{S}_t}
\xrightarrow{\text{Weighted Sum}}
y_t

也就是说MoE 的本质是:先由 router 从全部 expert 中选择一小部分参与当前 token 的计算,再将这些被选中 expert 的输出按路由权重加权求和,从而得到最终输出 $y_t$。

1.2 fused_moe 的数学原理

把每个 (token, topk-slot) 看作一个 routed item $r$,则:

  • 对应 token 为 t(r)
  • 对应 expert 为 e(r)
  • 对应路由权重为 a(r)

那么 fused_moe 实际仍然是在算同一个数学对象。

第一层:

h_r = W_{1,e(r)}^{\top}x_{t(r)} + b_{1,e(r)}

激活后:

z_r = \phi(h_r)

第二层:

o_r = W_{2,e(r)}^{\top}z_r + b_{2,e(r)}

最后对同一个 token 的 routed 输出求和:

y_t = \sum_{r: t(r)=t} a(r)\,o_r

代码中有一个开关 apply_router_weight_on_input,决定把 router weight 乘在第一层输出之后还是第二层输出之后。由于 a(r) 是标量,线性层满足:

W^{\top}(a z) = a\,W^{\top}z

因此两种写法数学上等价。

1.3 fused_moe 的计算原理

fused_moe 和普通 MoE 在数学上等价,主要差别在工程组织方式。

普通 MoE 更像是按逻辑步骤分开执行:

  1. router 选 top-k experts
  2. dispatch / gather token
  3. 对每个 expert 分别做 w1
  4. 激活
  5. 对每个 expert 分别做 w2
  6. 乘 router weight 并聚合

fused_moe 的核心思想是:

  1. 先把 routed token 按 expert 重排
  2. 把每个 expert 的 token 数量按 BLOCK_SIZE_M 补齐
  3. 让一个 Triton program 处理“同一 expert 的一块 token × 一块输出列”
  4. 用统一 kernel 两次完成 w1w2
  5. 在 kernel 内尽量融合量化、bias、router weight 等附加逻辑

从 routed token 视角看kernel 做的是:

C[r, n] = \sum_{k=1}^{K} A[t(r), k] \cdot B[e(r), n, k]

其中:

  • t(r) 是 routed item r 对应的原始 token
  • e(r) 是 routed item r 对应的 expert
  • n 是输出通道

在实现上,这套组织方式依赖三个关键步骤:

  • moe_align_block_size(...) 把 routed token 按 expert 排列成规则 block
  • fused_moe_kernel(...) 按 block 执行 expert GEMM
  • moe_sum(...) 对每个 token 的 top-k 输出做最终求和

1.4 和普通 MoE 实现的区别

从数学上看,fused_moe 和普通 MoE 没有本质差别,仍然是在做:

y_t = \sum_{i=1}^{k} \alpha_{t,i}\, f_{e_{t,i}}(x_t)

差别主要在工程组织方式:

  • 普通 MoE 更像是按逻辑步骤分开执行
    • dispatch
    • expert matmul
    • activation
    • expert matmul
    • combine
  • fused MoE 则更强调:
    • 按 expert 重排 token
    • 让一个 Triton program 处理规则 tile
    • 把量化、bias、router weight 等附加逻辑尽量揉进 kernel

这样做的收益是:

  • 减少 kernel launch
  • 改善权重访问局部性
  • 提高大规模 expert 场景下的吞吐

文件定位

fused_moe.py 是一个面向 Triton 的 MoE 前向实现文件,核心职责包括:

  • 选择或生成 MoE kernel 配置
  • 对输入激活做量化预处理
  • 启动 Triton MoE kernel
  • 执行两层 expert MLP
  • 在 top-k expert 输出之间做聚合

它依赖两个关键辅助模块:

  • flag_gems.fused.moe_align_block_size 负责把 routed token 按 expert 重排并按 block 对齐
  • flag_gems.fused.moe_sum 负责把每个 token 的 top-k expert 输出求和

对外入口

对外可直接调用的入口有两个:

  • inplace_fused_experts(...)
  • outplace_fused_experts(...)

两者都只是薄封装,最终都调用:

  • fused_experts_impl(...)

其中:

  • inplace_fused_experts 把结果直接写回 hidden_states
  • outplace_fused_experts 返回新分配的输出张量

文件内主要函数分组

1. 配置选择

  • get_embedded_moe_configs()fused_moe_config.yaml 读取内嵌调优配置
  • _get_device_name() 获取并规范化设备名
  • get_moe_configs(...) 按设备名、expert 数、输出维度、dtype、block shape 查询配置
  • try_get_optimal_moe_config(...) 优先使用内嵌配置,否则回退到启发式配置
  • get_default_config(...) 默认配置生成逻辑
  • get_moe_wna16_block_config(...) 针对 WNA16 路径生成 block 配置
  • _ensure_block_size_k_divisible(...) 修正 BLOCK_SIZE_K

2. 数据类型与量化辅助

  • _get_config_dtype_str(...) 把当前 dtype 和量化模式映射成配置查表 key
  • _get_config_quant_dtype(...) 把量化模式映射成激活量化 dtype
  • _fp8_quantize(...) 激活 FP8 量化
  • _int8_quantize(...) 激活 INT8 量化
  • moe_kernel_quantize_input(...) 在进入 GEMM 前对输入激活做量化
  • dequant_mxfp4(...)
  • dequant_mxfp6(...)

3. 激活函数

  • MoEActivation 定义支持的激活枚举
  • apply_moe_activation(...) 执行激活
  • _silu_and_mul_kernel(...) Triton 版 SiLU-and-mul

4. Triton kernel

  • write_zeros_to_output(...) expert 不在当前 rank 时写零
  • fused_moe_kernel(...) 通用 MoE kernel覆盖普通浮点、FP8、INT8 路径
  • fused_moe_kernel_gptq_awq(...) WNA16 专用 kernel面向 GPTQ/AWQ 风格的量化权重

5. kernel 启动与分发

  • invoke_fused_moe_triton_kernel(...) 启动通用 kernel
  • invoke_fused_moe_wna16_triton_kernel(...) 启动 WNA16 kernel
  • dispatch_fused_moe_kernel(...) 根据量化模式分发到具体启动函数

主调用路径

主调用链可以概括为:

inplace_fused_experts / outplace_fused_experts
  -> fused_experts_impl
     -> _get_config_dtype_str
     -> _get_config_quant_dtype
     -> try_get_optimal_moe_config
     -> chunk loop
        -> moe_kernel_quantize_input (for w1 input)
        -> moe_align_block_size 或 naive assignment
        -> dispatch_fused_moe_kernel (for w1)
           -> invoke_fused_moe_triton_kernel
           -> fused_moe_kernel
        -> apply_moe_activation
        -> moe_kernel_quantize_input (for w2 input)
        -> dispatch_fused_moe_kernel (for w2)
           -> invoke_fused_moe_triton_kernel
           -> fused_moe_kernel
        -> moe_sum

如果走 WNA16 特化路径,则中间会改为:

dispatch_fused_moe_kernel
  -> invoke_fused_moe_wna16_triton_kernel
  -> fused_moe_kernel_gptq_awq

fused_experts_impl 的执行流程

fused_experts_impl(...) 是整个文件的调度核心。

1. 参数校验

它会检查:

  • hidden_statesw1w2 的 shape 是否匹配
  • topk_weightstopk_ids 形状是否一致
  • hidden_states 是否连续
  • w1/w2 最后一维 stride 是否为 1
  • 输入 dtype 是否属于:
    • torch.float32
    • torch.float16
    • torch.bfloat16

当前实现还显式限制:

  • activation == "silu"

虽然 MoEActivation 定义了多种激活,但当前入口只允许 SiLU 路径。

2. 配置选择

它先生成两个关键变量:

  • config_dtype 用于 kernel 配置查表
  • quant_dtype 用于激活量化逻辑

之后用 try_get_optimal_moe_config(...) 选择当前 chunk 对应的 Triton 配置。

3. 中间缓存分配

会分配三块中间缓存:

  • intermediate_cache1 保存 w1 输出shape 为 [M, top_k, N]
  • intermediate_cache2 保存激活后的结果shape 为 [M * top_k, activation_out_dim]
  • intermediate_cache3 保存 w2 输出shape 为 [M, top_k, K]

其中 cache13 复用了 cache1cache3 的存储,因为两者生命周期不重叠。

4. 计算 dtype 选择

输入 dtype 与 kernel 中的 compute_type 对应关系:

  • torch.bfloat16 -> tl.bfloat16
  • torch.float16 -> tl.float16
  • torch.float32 -> tl.float32

5. 特殊量化格式处理

如果启用了 ocp_mx_scheme,会先把 MX 权重反量化成普通浮点权重,再走后续通用路径。

对于 use_int8_w8a16use_int4_w4a16,当前公开入口中也会先把权重反量化成 hidden_states.dtype,随后把:

  • use_int8_w8a16 = False
  • use_int4_w4a16 = False

这意味着虽然文件内保留了 WNA16 特化 kernel但从 fused_experts_impl(...) 这条入口实际运行时,常见路径仍然是通用 fused_moe_kernel(...)

6. 分 chunk 执行

实现按 CHUNK_SIZE = 16 * 1024 分块处理 token原因是

  • 限制中间缓存大小
  • 让 kernel 配置更容易适配当前 chunk 规模

每个 chunk 内部执行:

  1. 量化第一层输入
  2. 把 routed token 按 expert 重排
  3. w1
  4. 激活
  5. 量化第二层输入
  6. w2
  7. 对 top-k expert 输出做求和

moe_align_block_size 的作用

MoE 的普通数学形式里,每个 token 会被路由到若干 expert。直接按 token 顺序计算时expert 权重访问会非常离散。

moe_align_block_size(...) 的作用是:

  • 把所有 (token, topk-slot) 展平为 routed token
  • 按 expert 把 routed token 排序
  • 对每个 expert 的 token 数量按 BLOCK_SIZE_M 向上补齐

它输出三个核心量:

  • sorted_token_ids 重排后的 routed token 下标
  • expert_ids 每个 M 方向 block 对应的 expert id
  • num_tokens_post_padded 补齐后的 routed token 总数

这样后续 Triton kernel 就可以让一个 program 只处理:

  • 一块 routed token
  • 一个固定 expert
  • 一块输出列

这就是 fused MoE 能把很多零散小计算组织成规则 block GEMM 的关键。

fused_moe_kernel 的核心原理

fused_moe_kernel(...) 是通用 kernel也是阅读此文件时最值得重点关注的函数。

1. program 到 tile 的映射

kernel 先根据:

  • BLOCK_SIZE_M
  • BLOCK_SIZE_N
  • GROUP_SIZE_M

pid 映射到:

  • pid_m
  • pid_n

这里使用 grouped ordering 来改善 L2 reuse。

2. 当前 block 对应哪些 token 和 expert

对每个 pid_m

  • 如果启用了重排,先从 sorted_token_ids 取出当前 block 的 routed token
  • expert_ids[pid_m] 取出当前 block 对应的 expert

如果 expert_id == -1,则说明该 expert 不在当前 expert parallel rank直接调用 write_zeros_to_output(...)

3. 构造 A / B 指针

输入 A 的访问方式为:

offs_token // top_k

因为 routed token 是 (token, slot) 的展平形式,所以要除以 top_k 才能回到原始 token 下标。

权重 B 则按照:

  • 当前 expert
  • 当前 K block
  • 当前 N block

来取对应子块。

4. 主循环

沿 K 维做 block 累加:

  • 读取 A 子块
  • 读取 B 子块
  • 走普通浮点或量化分支
  • 累加到 accumulator

accumulator 始终先用 fp32 保存,再在输出前转成 compute_type

5. 后处理

主循环结束后依次执行:

  • 可选 dequant
  • 可选 bias
  • 可选乘 router weight
  • cast 到 compute_type
  • 写回 C

普通浮点与 FP8 路径

普通浮点

当:

  • use_fp8_w8a8 == False
  • use_int8_w8a8 == False
  • A_scale is None
  • B_scale is None

kernel 直接执行普通浮点路径:

accumulator += tl.dot(a, b)

这是 bfloat16float16float32 共享的主要逻辑。

FP8 W8A8

当启用 use_fp8_w8a8 时:

  • Python 侧先调用 moe_kernel_quantize_input(...)
  • moe_kernel_quantize_input(...) 会进一步调用 _fp8_quantize(...)
  • kernel 内根据量化粒度决定如何加载 scale

支持的粒度包括:

  • tensor-wise
  • per-channel
  • block-wise

block-wise 路径下,会在每个 K block 内读取本 block 对应的:

  • a_scale
  • b_scale

然后执行近似:

C_block ~= dot(qA_block, qB_block) * a_scale * b_scale

bfloat16 支持

文件明确支持:

  • torch.bfloat16

其执行方式不是把整个 kernel 改成 BF16 累加,而是:

  • 输入 / 输出使用 BF16
  • accumulator 先用 FP32
  • 最后再 cast 回 tl.bfloat16

WNA16 路径

文件内还保留了一条 WNA16 特化路径:

  • invoke_fused_moe_wna16_triton_kernel(...)
  • fused_moe_kernel_gptq_awq(...)

它主要用于:

  • int8_w8a16
  • int4_w4a16

并支持:

  • group scale
  • zero point
  • GPTQ/AWQ 风格权重布局

但需要注意:

  • 当前 fused_experts_impl(...) 中会先把 INT8/INT4 权重反量化成普通浮点
  • 因此从当前公开入口来看,这条分支通常不会成为主路径

也就是说,代码结构上支持 WNA16 专用 kernel但入口行为上更偏向统一走通用 kernel。

调用图

inplace_fused_experts
  -> fused_experts_impl

outplace_fused_experts
  -> fused_experts_impl

fused_experts_impl
  -> _get_config_dtype_str
  -> _get_config_quant_dtype
  -> try_get_optimal_moe_config
     -> get_moe_configs
        -> _get_device_name
           -> get_embedded_moe_configs
     -> get_default_config
  -> moe_kernel_quantize_input
     -> _fp8_quantize / _int8_quantize
  -> moe_align_block_size
  -> dispatch_fused_moe_kernel
     -> invoke_fused_moe_triton_kernel
        -> fused_moe_kernel
           -> write_zeros_to_output
     -> invoke_fused_moe_wna16_triton_kernel
        -> get_moe_wna16_block_config
           -> _ensure_block_size_k_divisible
        -> fused_moe_kernel_gptq_awq
           -> write_zeros_to_output
  -> apply_moe_activation
     -> _silu_and_mul_kernel
  -> moe_sum

当前实现的几个注意点

  • 当前入口只允许 activation == "silu"
  • 明确支持输入 dtype
    • float32
    • float16
    • bfloat16
  • 通用主路径是 fused_moe_kernel(...)
  • 文件内虽然保留了 WNA16 特化 kernel但当前公开入口通常会先把 INT8/INT4 权重反量化,实际常走通用 kernel
  • moe_align_block_size(...) 是性能关键,它决定了 routed token 是否能被整理成规则 block

小结

fused_moe.py 的核心思想可以概括为三点:

  1. 把 routed token 按 expert 排序并按 block 对齐
  2. 用统一的 Triton kernel 两次完成 expert 的两层 GEMM
  3. 在 kernel 内尽量融合量化、bias、router weight 等附加逻辑

从源码阅读角度,推荐优先关注以下函数:

  • fused_experts_impl(...)
  • moe_kernel_quantize_input(...)
  • dispatch_fused_moe_kernel(...)
  • fused_moe_kernel(...)
  • moe_align_block_size(...)

这几处连起来,基本就构成了整个 fused MoE 前向的主干。