为什么MiniMax M2是一个Full Attention模型?
作为MiniMax M2预训练的负责人, 我收到了很多来自社区的询问:“为什么你们在MiniMax M2上开倒车, 采用了 Full Attention 机制?”在一次又一次的聊天中解释了背后的故事后, 我觉得是时候在一篇blog里写下我们的心路历程。
我可以花一整个下午来讨论为什么应该构建应该做 Linear/Sparse Attention。同样, 我也可以反过来花一整个下午来讨论为什么不应该去做。但所有这些纸上谈兵又有什么意义呢? 回到实际情况里, 我们要不要做呢?
先说结论: 我们一直在研究它。但在一个现实的工业系统中, Efficient Attention想要打败Full Attention还有些距离。LLM发展到今天, 整个链路越来越复杂, 服务的场景越来越来越多, 结构设计上需要考虑的点也爆炸式增长: Code/Math场景效果咋样、Agent场景下效果怎么样、多模态是什么表现、Long CoT行不行、RL能不能Scale上去、低精度运算下有没有坑、Interleaved Thinking怎么弄、线上怎么做cache、怎么做Speculative Decoding等等。
简而言之, 现实与理论存在着较大差异, 要为 Linear/Sparse Attention 正名, 就要在满足了条件1到条件n, 并解决了问题1到问题n之后。
为什么要做Efficient Attention?
如果你有无限算力, 你会选择研究Linear Attention或者Sparse Attention吗? 也许有人会说infinite context场景下, Softmax Attention的Attention Score会oversmoothing....但谁知道呢?在目前算力bound的前提下, 还没有哪个模型真的把Softmax Attention打到能力上限。因此, 从实际应用的角度来看, 目前大家做Linear Attention或者Sparse Attention都是奔着省算力去的。
有没有可能奔着省token去呢——达到同样效果, 需要的token量更少。如果你相信Scaling Law, 想达到这个目的, 应该不会选择走Efficient Attention这条路, 而是其他途径。
说穿了就是, 算力有限, 我们需要有一个更省算力的结构, 同算力消耗下 (training compute and inference compute) 有更好的效果。
需要解决的问题
我们希望能做一个能商用的模型, 我们必须从用户关心的问题出发: 效果、速度 (TPS) 和价格。效果是底线, 一个效果差的模型, 即便免费也毫无价值。那么, 怎么做一个效果足够好的Linear/Sparse/Hybrid Attention模型呢? 这里最大的问题不是结构设计, 而是评测的局限性。速度和价格受推理系统影响, 当然优秀的模型自然有优秀的工程师来帮忙优化。
1. 观测局限性
“只要你把bench建出来, 我自然能找到办法打上去。”
纵观大模型发展的几年, 榜单分数提升的速度是惊人的, 不管多难的榜, 就算榜单刚出来的时候SOTA只有个位数分数, 只要入了各家的法眼, 总能在几版迭代后刷爆。
怎么建一个全面、真能反应模型能力差异的评测链路? 这是一个很难的课题, 也是大模型迭代的重中之重。这个问题在模型结构, 特别是Attention迭代中, 将变得更加严峻。
Benchmark不够全面
“没有免费的午餐”, 把attention的复杂度降下去,付出的代价是什么?
在做MiniMax-Text-01的时候, 大家还普遍在看MMLU/BBH/Math/LongBench这类的榜单 (现在已经被刷爆了)。以一年多前的视角来看, Ligntning Attention + Full Attention完全能打全Full Attention, 毕竟这些榜单上都不差 (我们端到端训了个Hybrid架构的小模型来验证)。
难道真有免费午餐? 其实不然。这个代价在更大的模型上暴露出来了: 复杂多跳推理任务有明显缺陷。
当问题暴露出来就好办了, 对代理指标迭代优化! 经过几轮迭代, 代理指标看起来能打MHA了。但是scale上去后, 代理指标和下游实际场景还能match上吗? 还有其他的问题吗? 谁知道呢, 还没实验到这里。
模型越进步, 评测越难做。但这是必经之路, 加油评测人!
观测成本高
针对复杂多跳推理任务, 我们能找到训练更早期就可以被观测的代理指标, 但并不是所有任务在预训练阶段都可以被观测 (起码现在还不行)。同时随着任务越来越难, 想要到对应指标测的置信区间, 需要付出的算力代价也越来越大, 这也导致了实验迭代比较缓慢 (算力不够才来研究这玩意, 研究这玩意吃算力也不少)。
除了评测榜单, 还有模型优化问题, 不scale上去, 永远不知道会发生什么, 很多问题在小规模试验中无法暴露。看过M1论文的朋友应该会发现M1 RL训练过程中有严重的精度问题, 不做到这一步确实很难发现这个雷。基于这个现象再回过头去对Lightning Attention做数值收敛性分析, 要怎么解决真的很难通透了。发现问题, 真的比解决问题要难得的多。
其他变量
训练模型的变量太多太多, 不同结构在不同数据分布下的表现大相径庭, 不同结构适配的优化器也差异巨大。在数据高速迭代的周期里, 用一个月前的数据做实验可能会得完全相反的结论。
我们很难做完备的观测, 但在试图找更靠谱的实验策略。
2. Efficient Attention的基建
相比Full Attention, Linear Attention和Sparse Attention的基建要差的多, 想要真的拿到收益, 要补不少课。
以Linear Attention的基建问题为例: 如果对现在已有的线性结构做计算强度分析, 会发现不少结构都是访存bound (没错, 训练的时候)。如果不能做极致的IO优化, 从GPU算力利用来讲是吃亏的。
把视角再转到推理, 这里需要解决的问题比训练要多不少了: 如何提供一个真正更快更便宜的推理服务? Linear Attention的优势体现在线性计算强度, 常数显存占用。那么和Full Attention的计算消耗和显存消耗必然存在一个交叉点, 通常这个交叉点理论值在几K的大小, 对于今天的大模型, 这个长度并不长。但是注意, 这里是理论值, 我们需要解决下面几个问题来逼近这个数值:
- States的低精度存储: 当前Linear Attention对精度要求比Full Attention高得多;
- 如何解决Prefix Cache: 正常业务命中Cache的概念是很高的;
- 如何优化Linear Attention上的投机解码
幸好, 这些问题目前看起来都是可以解决的。
下一步是什么
Scaling这件事依旧是主旋律, Context Length是其中的关键之一, 不管是Pretrain还是PostTrain, Context Length增长的趋势越来越明显。当GPU算力的增速慢于Data长度增长带来的算力压力增加的时候, Linear/Sparse Attention的收益会被逐渐释放。我们需要提前准备些东西:
- 更多模态、信息更加丰富的长文数据
- 更合理的评测体系和实验范式, 帮助更快的迭代
- 更完善的训推基建, 榨干GPU的潜力
补充
开源推理代码swa的实现忘记删掉了, 看到有人问为什么最后没有用。这里也简单回复下: 当然是效果不行。
这个地方实验的比较早, 当时GPT-OSS还没有开源, 看到GPT-OSS结构长这样还挺吃惊的。这里可以简单讲讲我们的一些失败经验, 我们是CPT范式变成Hybrid SWA的思路。这里考虑了做层间混合和层内混合两种, 做层内混合的出发点是这样层间的计算强度是均衡的, 不管是训练做PP策略, 还是推理的时候做PP或者AFD分离都更友好。
当然都没work, 具体表现为Context越长性能下降越显著, 这在Agent场景是不太能接受的。
在我们的分析里, 这里有很多Global Attention Pattern (如retrieval head和induction head) 在前期预训练阶段已经形成, 通过CPT很难调整这些Attention Pattern。如果构建数据探针去检索对应的head并将其保留为Full Attention能极大的缓解对应问题, 但是不幸的是, 根据人类先验很难把这些Pattern全都探出来。
另外, 这个问题和Attention sink没有关系。
如果大家对这种思路感兴趣的话, GPT-OSS、CWM、Gemma的性能大家可以分析下, 特别是长文。
Intelligence with everyone
