FA3量化:Flash Attention 3激活量化算法说明¶
简介¶
- 背景:一方面,在长序列下,Attention 的中间激活 Q、K、V 张量在显存中占比高,对其进行量化将有效降低显存占用并提升计算效率;另一方面,Q、K、V 的激活动态范围大且分布高度不均,直接进行全局量化可能会导致精度损失严重。
- 核心思想:Flash Attention 3(FA3)是一种针对注意力机制激活的 per-head(逐注意力头)量化算法,对注意力机制中的 Q、K、V 激活进行 per-head 粒度的 INT8 量化,在保持模型精度的前提下提升推理性能和降低显存占用。FA3 量化通常与线性量化配合使用,以实现全量化方案。
使用前准备¶
安装 msModelSlim 工具,详情请参见《msModelSlim工具安装指南》。
原理和实现¶
原理¶
核心思想:
- 量化目标:对注意力机制中的 Q、K、V 激活值进行 per-head 粒度的 INT8 量化。
- 量化粒度:per-head(每个注意力头独立量化参数),适应不同 head 的激活分布差异。
- 量化时机:在 Multi-head Latent Attention (MLA) 计算的关键位置插入量化节点
- 量化策略:per-head:对每个注意力头独立计算量化参数。
算法流程:
-
收集每个 head 的激活统计数据:
- 输入:激活张量 x,shape 为
(B, H, S, D) - 其中 B=batch_size, H=num_heads, S=seq_len, D=head_dim
- 将 x reshape 为
(H, N),N = B * S * D - 每个 head 独立收集 N 个数据点
- 输入:激活张量 x,shape 为
-
对每个 head 使用 Recall Window 算法找到最小量化范围:
- 输入:head_data (N,), ratio (默认 0.9999)
- 对 N 个数据点进行排序:sorted_data = sort(head_data)
- 计算目标元素数量:target_num = int(ratio * N)
-
滑动窗口搜索最小范围:
-
遍历所有可能的窗口起点 i = 0 到 (N - target_num)
- 窗口范围:[sorted_data[i], sorted_data[i + target_num - 1]]
- 计算窗口长度:window_length = sorted_data[i + target_num - 1] - sorted_data[i]
-
保留窗口长度最小的窗口
-
输出:该 head 的 (min_val, max_val)
-
跨批次累积统计:
- 对每个校准批次,计算当前批次的 (min_val, max_val)
-
更新累积统计值,确保量化范围覆盖所有校准数据:
-
min_values[h] = min(min_values[h], current_min[h])
- max_values[h] = max(max_values[h], current_max[h])
-
计算每个 head 的量化参数:
-
对称量化公式:
-
abs_max[h] = max(abs(min_values[h]), abs(max_values[h]))
-
scale[h] = abs_max[h] / 127
-
输出:量化参数 q_param
-
实现¶
代码实现¶
- FA3 量化在 processor.py 中实现,处理流程分三阶段。
注入阶段¶
- 阶段:
preprocess。 - 调用模型适配器的
inject_fa3_placeholders()方法。 - 适配器负责在 MLA 计算流程中的关键位置插入占位器
FA3QuantPlaceHolder。 - 支持通过
include/exclude配置选择性注入。
校准阶段¶
- 阶段:
process。 - 占位符被替换为监听器
_FA3PerheadObserver。 - 校准数据流经注意力层时,监听器收集每个 head 的激活统计信息。
- 根据滑动窗口的思想找到包含指定比例数据的最小数值分布区间。
伪量化部署阶段¶
- 阶段:
postprocess。 - 从监听器提取每个 head 的 min/max 值。
- 调用
calculate_qparam()计算对称量化参数。 - 创建 IR 替换监听器。
适用要求¶
- 模型结构要求:
- 必须有支持 FA3 的模型适配器实现
FA3QuantAdapterInterface。 - 适用于基于 MLA 的注意力机制。
-
需要明确的 Q、K、V 激活计算路径以插入量化节点。
-
量化方式限制:
- 当前仅支持 INT8 对称量化。
- 量化参数在校准后固定,尚不支持动态调整。
功能介绍¶
模型支持¶
- DeepSeek-R1-0528
- DeepSeek-V3.1
YAML配置示例¶
作为Processor使用,YAML配置示例如下:
spec:
process:
- type: "fa3_quant"
include: [ "*" ] # 包含的注意力层
exclude: [ "model.layers.0.self_attn" ] # 排除的注意力层
YAML配置字段详解¶
| 字段名 | 作用 | 数据类型 | 默认值 | 说明 |
|---|---|---|---|---|
| type | 处理器类型标识 | string | - | 固定值"fa3_quant",用于标识该对象为FA3量化处理器。 |
| include | 包含的注意力层 | array[string] | ["*"] | 支持通配符匹配,指定要执行FA3量化的注意力层。 |
| exclude | 排除的注意力层 | array[string] | [] | 支持通配符匹配,优先级高于 include。 |
模型适配¶
接口与数据结构¶
目前已支持 DeepSeek-V3 系列模型,其他基于 MLA 的模型需要实现相应的适配器。
# 模型适配 FA3 量化接口
class ModelAdapter(FA3QuantAdapterInterface):
def inject_fa3_placeholders(
self,
root_name: str,
root_module: nn.Module,
should_inject: Callable[[str], bool],
) -> None: ...
适配步骤¶
-
前置要求:
-
模型基于 Transformer 架构,包含明确的注意力层。
- 注意力层的 Q、K、V 激活值在计算流程中可定位。
-
适配器能够访问模型的注意力模块并修改其 forward 方法。
-
步骤:
可参考 DeepSeek 的 model_adapter.py 的实现:
- 模型适配器继承
FA3QuantAdapterInterface接口。 - 遍历模型,通过
should_inject在注意力层中选择性注入占位器FA3QuantPlaceHolder作为子模块。 - 定位 Q、K、V 激活流向 Attention 计算的临界位置,该位置即为需要插入 FA3 量化的节点。
- 包裹注意力层的
forward方法,在定位到的临界位置插入对 FA3 量化的调用。