-
Notifications
You must be signed in to change notification settings - Fork 735
[Feature] 线性注意力 (Lightning/Linear Attention) 循环状态接入 ForwardMeta.caches 缓存管理体系 #7199
Description
背景 / Background
ForwardMeta.caches 是 FastDeploy 推理引擎中统一管理注意力缓存的核心接口 [1]。它承载了 KV cache 的分配、调度器感知、block table 映射和生命周期管理。当前这套体系为标准 Transformer 的 KV cache([num_layers, 2, ...])量身设计 [1],能力完备。
但随着 线性注意力 (Linear Attention) 和 状态空间模型 (SSM) 架构进入生产推理(MiniMax-M1 [1]、RWKV*、Mamba/Jamba*、RetNet*、GLA* 等),一类新的缓存需求浮出水面:循环状态 (recurrent state)。这类状态的形状、语义和生命周期与 KV cache 存在根本差异。
现状与问题 / Current Limitation
以 MiniMax-M1 为例(PR #6994),其 70 层线性注意力各自维护一个形状为 [batch, heads, D, D] 的 _kv_history 张量 [1],作为跨 token 的循环累积状态。由于 ForwardMeta.caches 不支持此类形状 [1],当前实现只能将状态存储为 模型层实例属性 [1][2]:
# fastdeploy/model_executor/models/minimax_m1.py (L367-381)
# ⚠ 实例属性 — 绕过了调度器和内存管理
if not hasattr(self, "_kv_history") or self._kv_history is None or ...:
self._kv_history = paddle.zeros([batch_size, heads, D, D], dtype=q.dtype)
...
self._kv_history = new_kv_history这一模式可以工作,但存在以下系统性问题:
| 问题 | 影响 | 证据 |
|---|---|---|
| 多请求状态隔离缺失 | _kv_history 按 batch_size 重建,无法按 request 粒度管理不同请求的状态 |
[1][2] |
| 调度器不可见 | 调度器通过 ForwardMeta.caches / block table 追踪内存压力。实例属性上的隐式显存分配对调度器完全透明 |
[1] |
| 连续批处理 (continuous batching) 不兼容 | 请求动态加入/退出时,实例级状态无法跟随 slot 生命周期自动释放或迁移 | * |
| Prefix caching / 投机解码不可扩展 | 这些高级特性依赖 block table 抽象,无法适用于裸张量属性 | * |
这不是 MiniMax-M1 的孤立问题——任何引入循环状态的架构都会遇到相同瓶颈。
影响范围 / Scope
| 模型族 | 状态类型 | 形状 (per layer) | 备注 |
|---|---|---|---|
| MiniMax-M1 (Lightning Attention) [1] | kv_history | [B, H, D, D] |
PR #6994 — 已实现,使用实例属性 |
| RWKV-6/7 * | channel_state + time_state | [B, H, D, D] + [B, H, D] |
线性 RNN |
| Mamba / Jamba (S6) * | conv_state + ssm_state | [B, D_inner, D_conv] + [B, D_inner, D_state] |
状态空间模型 |
| RetNet * | kv_state | [B, H, D, D] |
类似 Lightning Attention |
| GLA (Gated Linear Attention) * | kv_state | [B, H, D, V] |
门控线性注意力 |
vLLM 在 v0.5+ 已引入 MambaCacheParams 及混合模型 (Jamba) 支持来解决同一类问题 [3],体现了社区对此需求的广泛共识。
期望结果 / Expected Outcome
在 ForwardMeta 中引入一个通用的 循环状态缓存接口,使:
- 循环状态可通过
ForwardMeta统一传递给模型层(与 KV cache 同级)[1][2] - 调度器能够感知循环状态的显存占用,纳入内存压力计算 [1]
- 状态生命周期与 slot / request 绑定,支持连续批处理下的动态加入/退出 *
- 为后续 prefix caching、投机解码等高级特性在循环状态上的扩展预留接口 *
具体设计方案可以后续讨论,本 issue 旨在明确需求和影响范围。
参考 / References
- PR [Feature]【Hackathon 10th Spring No.47】Add MiniMax-M1 model support #6994 — MiniMax-M1 model support(包含 TODO 注释标记此问题)
- PR [Feature]【Hackathon 10th Spring No.47】Add MiniMax-M1 model support #6994 AI Code Review — fastdeploy-bot 和 Copilot 在 4 轮独立代码审查中均将
_kv_history实例属性模式标记为多请求隔离风险,并建议迁移至ForwardMeta.caches - vLLM:
MambaCacheParams+ Jamba hybrid model support [3] — 解决同类问题的参考实现 - OpenNLPLab/lightning-attention — Lightning Attention 原始实现
证据标注 / Evidence key:
[1]已在 FastDeploy 源码中验证(forward_meta.pyL146:caches: Optional[list[paddle.Tensor]] = None # KV caches,minimax_m1.pyL367-381:_kv_history实例属性)[2]由 fastdeploy-bot AI 代码审查 和 GitHub Copilot 在 PR [Feature]【Hackathon 10th Spring No.47】Add MiniMax-M1 model support #6994 的 4 轮独立审查中均独立标记此问题:"多请求并发时可能导致状态污染"、"无法与调度器的 slot/block-table 缓存机制对齐"[3]已在 vLLM 源码中验证(jamba.py:mamba_cache、MambaStateDtypeCalculator、get_seqlen_agnostic_capture_inputs)*该模型族尚未在 FastDeploy 中实现——基于架构共性的前瞻性需求投射
我们在 PR #6994 中已用 TODO 标记了这个迁移点。如果社区对此方向有共识,我们愿意参与后续设计和实现。