FlashAttention-3 发布!有什么新优化点?

知乎热榜4个月前发布 NIUC!
502 0 0

方佳瑞的回答

Attention算子优化是大模型最重要Building Blocks,Flash Attention 也是我认为最成功的MLSys产品,它成功做到了把复杂留给自己,把简单交给用户。

看完技术报告,Flash Attention 3(FA3)在Hopper架构上的优化细致入微到warp-level了,当然作者列表里看到有NVIDIA高人指点,这也不觉意外。

FA3的推出标志着大模型的CUDA算子优化不能说进入深水区,而是进入了马里亚纳海沟了。很多人想打一个时间差,在H架构上闪击FA2搞大新闻了的路也被堵死了。

FA3优化有两个核心点:利用新硬件特性挖掘异步性和利用FP8低精度。

充分挖掘异步性:

核心在于利用Hopper新的硬件模块Tensor Memory Accelerator(TMA)

TMA允许在全局内存和共享内存之间进行高效的异步数据传输,减少了对寄存器的依赖。在TMA推出之前,数据从全局内存到共享内存的传输通常需要经过寄存器,这限制了数据传输效率,还增加了寄存器占用,并且多发送很多指令。用TMA类似于DMA方式在全局内存和共享内存之间异步拷贝,GPU可以把节省下来的指令cycle用来发射计算。

FA3把TMA用在GEMM里,通过异步性构成生产者和消费者方式来读和算。Q,K,V加载都用异步TMA读取相当于生产者,用Tensor Core WGMMA计算相当于消费者,这样充分重叠二者时间。作者把这个优化叫warp-specialization,specialization也有就是部分warp生产,部分warp消费,大家各自分工之意。

异步性也可以用在WarpGroup(4个连续的warp)粒度上来重叠GEMM和Softmax计算,前者用Tensor Core,后者用Multi-Function Unit,资源井水不犯河水,可以完全并行起来。为了精确控制计算的依赖关系,用了bar.sync指令。作者把这个优化叫ping-pong scheduling

在WarpGroup内部不同warp也可以重叠GEMM和Softmax的一些指令。

利用FP8低精度:

FA3之前时代,Attention都还是fb16算,FP8只能加速Linear层,尤其对于长序列FP8收效甚微。FA3让Attention计算也能用上FP8解决了大问题,利好FP8普及。

为了利用Tensor Core FP8能力,第一个优化点是给V矩阵做in-kernel transpose,来方便更好地适配计算的layout。

为了保持FP8量化精度,防止outliner,提出了两种技术。一个是block quatization,将Q,K,V分块,每块一个scaling factor。另一个是Incoherent Processing,这个比较有趣,它将Q,K分别乘一个随机正交的矩阵,这样每个Q,K outliner都减少了,而且不影响最终结果。这个技术值得大家关注一下,来自论文 QuIP: 2-Bit Quantization of Large Language Models With Guarantees

基于FA3很多下游配套的组件可以升级一下,我举两个例子:

  1. H20上适配,NV不会做H20,国人还需自强,计算显著降频之后异步性有什么影响值得研究。
  2. Ring Attention升级,a.k.a Megatron-LM的Context Attention。它依赖于FA,本质是FA分布式版本。
© 版权声明

相关文章

暂无评论

暂无评论...