探秘Transformer系列之(18)--- FlashAttention
目录
- 0x00 概述
- 0.1 问题
- 0.2 其它解决方案
- 0.3 Flash Attention
- 0x01 背景知识
- 1.1 GPU相关概念
- 硬件概念
- 软件概念
- 运行模式
- 线程模型
- Grid & Device
- Block & SM
- Thread & SP
- Thread & Warp
- 总结
- 1.2 Transformer的内存和计算
- 基本概念
- 计算受限与内存受限
- 注意力机制的计算强度
- 如何平衡
- 1.3 Tiling
- 1.4 算子融合
- 0x02 优化注意力机制
- 0x03 Softmax改进
- 3.1 原生softmax
- 3.2 历程
- 3.3 3-Pass Safe Softmax
- 3.4 online softmax 2-pass
- 3.5 Multi-pass Self-Attention
- 动机
- Multi-pass Self-Attention算法
- 引入到FlashAttention
- 3.6 1-pass FlashAttention
- 3.7 Algorithm FlashAttention (Tiling)
- 3.8 小结
- 0x04 FlashAttention V1
- 4.1 总体思路
- 4.2 算法
- 4.3 证明
- 定义
- 推导
- 常规softmax
- safe softmax
- 结合O来分析
- 4.4 分块
- 4.5 流程
- 前置条件
- 第一步
- 第二步
- 第三步
- 第四步
- 循环计算
- 第五步
- 第六步
- 第七步
- 第八步
- 循环内计算
- 第九步
- 第十步
- 第十一步
- 第十二步
- 第十三步
- 第十四、十五、十六步
- 总结
- 0x05 计算量与显存占用
- 5.1 IO复杂度
- 标准注意力
- FlashAttention
- 反向传播
- 5.2 计算复杂度
- 0xFF 参考
0x00 概述
0.1 问题
Transformer架构的核心是自注意力机制这个强大的组件。然而,自注意力机制的执行速度很慢,并且内存占用很大,特别是在处理长上下文长度时。对于Transformer模型,假设其输入序列长度为N,则其Transformer模型的计算复杂度和空间复杂度都是\(O(N^2)\),即模型的计算量和存储空间随着序列长度N呈二次方增长。当输入序列(sequence length)较长时,Transformer的计算过程缓慢且耗费内存,这限制了大语言模型的最大序列长度N的大小,这就是在发展初期,大模型往往只支持2K或4K token输入的原因。所以人们寻求降低Transformer模型的\(O(N^2)\)复杂度,争取让复杂度逼近\(O(N)\)或者降到\(O(N)\)。
0.2 其它解决方案
在FlashAttention之前,人们已经做了很多尝试,基本上有两条路径:降低注意力机制的计算复杂度和降低注意力机制的空间复杂度。通常将由这些方法改进得到的模型称为Efficient Transformer。
在计算复杂度方面,一些工作尝试提出近似的注意力机制算法,来降低 attention 的理论上的计算复杂度。主要可以分为稀疏 (sparse) 估计、低秩 (low-rank) 估计等。其中,稀疏估计的基本思想是通过一个稀疏的矩阵来近似完整的、稠密 (dense) 的注意力矩阵,比如,Reformer]对 Q 和 K 进行局部敏感哈希(Local Sensitive Hashing),只对同一个 桶 (bucket) 中的 Q 和 V 计算 attention,将 attention 的时间复杂度从 $O(n^2) $降低到 \(o(nlog(n))\) 。再比如,低秩近似的基本思想通过一个低秩 (low-rank) 矩阵来估计注意力矩阵,比如,linear transformer引入核函数 \(\phi(x)\) ,将 \(score=softmax(QK^T)V\) 形式化成 \(score=\phi(Q)(\phi(K)^TV)\) ,来解耦开 softmax 运算中的 Q 和 K 。这样操作之后,可以先计算\(score=\phi(Q)(\phi(K)^TV)\),该运算的时间复杂度为 \(O(n)\)。虽然降低注意力机制的计算复杂度在理论上非常具有吸引力,但是在实际应用中仍然存在一些短板,比如以下两点:
- 性能比不上原始注意力机制。不论是稀疏估计、低秩估计还是其他,这些方法都采用了某种近似算法来估算注意力权重矩阵,难免会丢失信息。目前主流的还是原始的注意力机制;
- 无法减少内存读取的时间消耗。这些方法只能降低注意力机制的计算复杂度,但是无法对注意力机制的运算过程中的空间复杂度等进行控制,无法减少内存读写带来的时间损耗。
在空间复杂度方面,这方面工作的基本思路是降低注意力机制对于显存的需求,减少 HBM 和 SRAM 之间的换入换出,进而减少注意力机制运算的时间消耗。一种具有代表性的方法是 kernel fusion,其思想很简单,即将需要通过多个 CUDA kernel 来分步完成的操作融合到一个或者少数几个 CUDA kernel,从而减少数据在HBM和SRAM之间换入换出的次数,进而节省运算时间。
0.3 Flash Attention
FlashAttention的作者们发现,这些Efficient Transformer虽然能够有效降低模型的FLOPS,但它们的计算速度并没有显著降低。导致该现象的根本原因是大多数Efficient Transformer通常只关注FLOPS(Floating Point Operations Per Second),该指标是计算密集型应用程序和深度学习模型性能的常用指标。然而,模型的计算速度除了与FLOPS有很大关系,同时也与MAC(Memory Access Cost,存储访问开销)有关。尤其是当计算本身已经很高效的情况下,MAC的开销更加不能忽略。MAC的开销主要来自两方面。一是从存储中读取数据;二是向存储中写数据。与CPU的情况类似,在GPU中,当需要计算时,需将数据从显存中读取并由计算单元进行计算操作。在计算完毕后,再写回到显存中。
Flash Attention所作的工作体现在其论文题目“FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness”中,具体如下:
- Fast(with IO-Awareness),计算快。Flash Attention之前加速Transformer计算方法的着眼点在于“减少计算量FLOPs”,比如用稀疏Attention来近似计算。但Flash Attention作者发现计算慢的瓶颈是IO读写速度而非计算能力,因此Flash Attention通过减少访问显存(HBM)的次数来提高整体运算速度,这就是IO感知(with IO-Awareness)。具体而言,减少访问显存(HBM)的次数是通过分块计算(tiling)和核函数融合(kernel fusion)技术来实现的。
- Memory Efficicent,节省显存。在标准Attention场景中,前向传播时会保存\(N^2\)大小的注意力矩阵\(P,S\),反向传播时又会读取注意力矩阵来计算梯度,这就是显存复杂度为\(O(N^2)\)的原因。Flash Attention通过引入统计量来改变注意力机制的计算顺序,避免了实例化注意力矩阵,从而使得存储压力降至 O(N) 。
- Exact Attention,精准注意力,计算结果完全相同。Flash Attention之前的“稀疏Attention”属于近似计算,虽然可以减少计算量,但是其计算结果与标准Attention计算结果不同。Flash Attention的计算结果与标准Attention计算结完全相同。
简单来说,注意力公式为:\(Attention(Q, K, V) = softmax(\frac{QK^T}{\sqrt{d_k}})V\),FlashAttention不需要在全局内存上实现 中间矩阵,而是将上述公式中的整个计算融合到单个 CUDA 内核中,这样,我们就不需要大量的I/O。另外,对于矩阵乘法等经典算法,还会使用平铺(tiling)来确保片上内存不超过硬件限制。
0x01 背景知识
因为大模型主要是在GPU上进行训练和推理,所以我们首先看看GPU相关知识,然后看看Transformer的计算特点。
1.1 GPU相关概念
我们在学习和使用CUDA时候,经常见到很多概念,比如SM,SP,Grid等,通常令人感到疑惑。接下来就带领大家做简要的解读。这些概念通常分为两类:
- 硬件资源或者概念,包括:SP,SM,HBM和SRAM;
- 软件抽象或者概念,包括:Thread、Warp、Block和Grid;
硬件概念
首先来看看一些硬件概念。
运行单元
这里主要包括SM(Streaming Multiprocessors,流式多处理器)和SP(Streaming Processor,计算单元)概念。GPU由一系列SM组成。SM是GPU的基本计算单元,其好比多核的CPU芯片里面的一个核。不同之处在于,CPU的一个核一般是运行一个线程,而SM能够运行多个轻量线程。每个SM都拥有一定数量的寄存器、片上内存(on-chip memory)、控制单元和若干SP或其他加速计算单元。这些片上内存和控制单元被所有的SP共享。此外,每个SM都配备了基于硬件的线程调度器,用于执行线程。
内存
我们用A100-40GB为例来揭示GPU的内存状况。下面是A100-40GB的内存层级结构图。
上面是三层金字塔,最下面是CPU上的内存,量大,但是很慢。上面两层则属于GPU,GPU的内存由多个不同大小和不同读写速度的内存组成,可以按照是否在芯片上分为片上内存和片下内存(off chip),在NVIDIA A100-40GB卡上两种内存的信息如下。
类型名称作用大小读写速度特点片上内存SRAM(Static Random-Access Memory)主要用于缓存(cache)及少量特殊存储单元(例如texture)分布在108个流式多处理器上,每个处理器大小为192K。合计为 192∗108KB=20,736KM=20MB19TB/s存储空间小,带宽大片下内存HBM(High Bandwidth Memory)主要用于全局存储(global memory),即我们常说的显存40~80GB1.5~2.0TB/s存储空间大,带宽小这里要再次强调一点:SRAM 是 L1 Cache(组合共享内存和数据缓存)。
可以看到,显存的带宽相比SRAM要小的多,读一次数据很费时,但是SRAM存储又太小,装不下太多数据。所以我们就以SRAM的存储为上限,尽量保证每次加载数据都把SRAM给打满,节省数据读取时间。
软件概念
运行模式
一个CUDA程序可以分为两个部分(两者拥有各自的存储器):
- 在CPU上运行的称为Host程序,或者可以把CPU理解为Host。
- 在GPU上运行的称为Device程序,又被叫做Kernel函数。或者可以把GPU理解为Device。
对应的GPU执行操作的典型方式分为以下几步:
- CPU把计算指令传送给GPU;
- 把数据从CPU的内存拷贝到GPU的内存,即HBM;
- GPU将输入数据从低速的HBM中加载到高速的SRAM中;
- GPU把计算任务分配到各个SM并行处理;
- SM从SRAM读取数据进行计算操作;
- 计算完毕后将计算结果从SRAM写到HBM里;
- 计算结果再从HBM拷贝到CPU内存;
线程模型
在GPU上需要启用多个线程来执行kernel。比如在向量相加的示例中,如果我们要对256维的向量进行相加运算,那么可以使用256个线程并行处理,这样每个线程就可以处理向量的一个元素。如果数据更大,GPU上也许没有足够的线程可用,这时我们可能需要每个线程能够处理多个数据点。因此需要程序员依据数据的大小和我们所需的并行度来仔细配置线程。
为了方便程序员设计、组织线程,在CUDA编程上把软件资源抽象成为一个线程模型,该模型包括Grid、Block、Thread和Warp等概念,每个概念对应的软件抽象和硬件资源对应如下。
- Thread:并行执行的基本单元。一个CUDA并行程序由多个thread来执行,thread是最基本的执行单元(the basic unit of execution)。Thread的执行由SP来完成。一个SP可以执行一个thread。
- Block:数个threads组成一个block。一个block占用一个SM运行。
- Grid(线程网格):多个blocks则会再构成Grid。一个Kernel函数对应一个Grid。Grid运行在device之上。
- Warp: 执行程序时的调度单位,32/16个threads组成一个warp。每个warp中的thread可以同时执行相同的指令,从而实现SIMT(单指令多线程)并行。warp是SM中最小的调度单位(the smallest scheduling unit on an SM),一个SM可以同时处理多个warp;
Grid、Block、Thread是线程组织的三个层次,是一种软件架构,和硬件无关。因此理论上我们可以以任意的维度(一维、二维、三维)去排列Grid,Block,Thread。这个软件架构落实到硬件上就分别对应一个个的SM或者SP。硬件并没有维度这一说,只是软件上抽象成了具有维度的概念。具体如下图所示。
这些软件概念和硬件资源的具体解释如下,我们这次按照从上到下的层级来进行介绍。
Grid & Device
Grid的作用是线程数量控制和差异性执行。CUDA让Host程序里的一个个Kernel函数按照Grid的概念在device上执行。一个Kernel函数对应一个Grid;Grid跑在device上的时候,可能是独占一个device,也可能是多个kernel并发占用一个 device;
Block & SM
Block是线程块,同一个block中的threads可以同步,也可以通过共享内存来加速通信。每个Grid承接了一个kernel函数的任务。当执行任务时,每一个Grid又把任务分成若干Block(线程块)在SM上运行。Grid和SM的关系是:
- 同一 Grid 下的不同 Block 可能会被分发到不同的 SM 上执行。一个Block的thread只能在一个SM上调度,即Block不能跨SM。
- SM上可以同时执行多个Block,这些Block不一定来自同一个kernel函数。有时候即便SM上剩余资源不足以再容纳一个kernel A的Block,但却仍可能容纳下一个kernel B的Block。多个block需要轮流进入SM。
- 每个线程会占用一定数量的寄存器和Shared Memory,因此SM上同时存活的Block数目不应当超过这些硬件资源的限制。
- 一个thread block可以包含多个warp,同一个block中的thread可以同步,也可以通过shared memory进行通信。thread block是GPU执行的最小单位(the smallest unit of execution on the GPU)。一个warp中的threads必然在同一个block中,如果block所含thread数量不是warp大小的整数倍,那么多出的那个warp中会剩余一些inactive的thread。也就是说,即使warp的thread数量不足,硬件也会为warp凑足thread,只不过这些thread是inactive状态,但也会消耗SM资源。
Thread & SP
一个CUDA的程序(即kernel的任务)最终被拆分到线程来完成。每个Thread中的局域变量被映射到SM的寄存器上,而Thread的执行则由CUDA Core也就是SP来完成。
Thread & Warp
因为Block的大小不定,所以我们实际上无法对一个任意大小的Block都给出一个同等大小的CUDA核心阵列去并行计算。为了更好的管理和执行Thread,GPU采用了SIMT(Single Instruction Multiple Threads)架构,提出了Wrap(线程束)概念。我们首先看看SIMT和SIMD的区别。
- CPU中通过SIMD来处理矢量数据。纯粹使用SIMD不能并行的执行有条件跳转的函数,很显然条件跳转会根据输入数据不同在不同的线程中有不同表现。
- GPU则使用SIMT来处理数据。无需开发者费力把数据凑成合适的矢量长度,并且SIMT允许每个线程有不同的分支,利用SIMT 才能做到不同分支的并行操作。
Wrap是GPU编程架构中的最小调度/执行单元,同一个Warp里的线程执行相同的指令,即SIMT。Block被划分成一块块的warp分别映射到CUDA核心阵列上执行,每一个warp就都可以理解为是一个线程的集装箱,为的是线程数量固定统一可以给他分配统一的硬件资源,每个集装箱只装一种货物,也就是同步执行的意思。一般为32个线程为一个warp,它们在同一个时钟周期内并行执行相同的指令,实现了单指令、多线程。每个线程能够访问自己的寄存器,不同的warp在计算时会从SRAM中读取计算所需的数据(即共享存储寄存器),即不同的Warp从不同的地址加载和存储,并遵循不同的控制流路径。
总结
现在,我们将GPU的计算核心SM及不同层级GPU存储结构综合起来,绘制一张简化图。
- 寄存器:GPU中的每个SM都拥有大量寄存器。这些寄存器在核心之间共享,并根据线程需求动态分配。在执行过程中,每个线程都被分配了私有寄存器,其他线程无法读取或写入这些寄存器。
- L1缓存/shared memory:每个SM都有自己的L1缓存,用于存储SM内的数据,被SM内所有的cuda cores共享。SM间不能互相访问彼此的L1。NV Volta架构后(Volta架构前只有Kepler做过合并),L1和shared memory合并,目的是为了进一步降低延迟。合并过后,用户能写代码直接控制的依然是shared memory,同时可控制从L1中分配多少存储给shared memory。其中FlashAttention中SRAM指的就是L1 cache/shared memory。
- L2缓存:所有SM共享L2缓存。L1/L2缓存的带宽都要比显存的带宽要大,也就是读写速度更快,但是它们的存储量更小。
- HBM:即显存。
1.2 Transformer的内存和计算
从计算科学角度来看,操作的性能瓶颈有两类:计算受限(Compute-bound或者math-bound)和内存受限(Bandwidth-bound或者Memory-bound)。而想降低Transformer模型的计算复杂度和空间复杂度,就需要找出Transformer核心组件注意力机制的资源瓶颈究竟是计算能力还是显存,这样我们就可以知道应该在哪个方面进行优化。
基本概念
我们接下来从基本概念入手进行分析。
- 计算带宽(math bandwidth)\(\pi\)。此概念可以理解为算力,具体指的是处理器每秒钟可以执行的数学计算次数,单位通常是OPS(operations/second)。如果用浮点数进行计算,则单位是FLOPS(每秒执行的浮点数操作次数)。
- 内存带宽 (memory bandwidth)\(\beta\)。此概念指的是处理器每秒钟从内存中读取的数据量,单位是bytes/second。
- 计算强度 (arithmetic intensity) \(I = \frac{N_{op}}{N_{byte}}\)。此概念指的是算法对于内存带宽的需求,即在此算法中,平均每读入单位数据(IO)能支持多少次浮点运算操作(FLOP)。它可以通过将FLOPs的总数除以访问的字节总数(也称为MOPs或内存操作)来计算。
- 计算强度上限 \({I_{max} = \frac{\pi}{\beta}}\)。它描述的是在这个计算平台上,单位内存交换最多用来进行多少次计算。单位是FLOPs/Byte。计算带宽和内存带宽这两个指标相除即可得到计算平台的计算强度上限。
- 模型的理论性能 \(P\):模型在计算平台上所能达到的每秒浮点运算次数(理论值)。单位是FLOPSorFLOP/s。
计算受限与内存受限
程序的执行时间主要花在两个地方:计算和读写数据。因此我们得到以下两个时间。
\[计算时间 = \frac{计算次数}{计算带宽} \\访存时间 = \frac{内存访问量}{内存带宽}\]
一般来说,计算时间和访存时间可以重叠,即“一边计算,一边读/写下一个”,因此总的运行时间为\(max(计算时间,访存时间)\)。
- 计算受限(math-bound)。当计算时间大于访存时间,即完成某操作的大部分时间是在GPU的流多处理器上计算(GPU执行块状并行计算),就说明计算带宽是算法的瓶颈。读得快,算得慢,这就是计算受限(math-bound)。此时HBM访问所花费的时间相对较低,不管模型的计算强度有多大,它的理论性能最大只能等于计算平台的算力。比如:大矩阵乘法、通道数很大的卷积运算。
- 内存受限(memory-bound)。当访存时间大于计算时间,即完成某操作的部分时间是将数据从内存移动到流多处理器(而不是实际在流多处理器上计算),就说明内存带宽是算法的瓶颈。算得快,读得慢,这就是内存受限(memory-bound)。当模型的计算强度 小于计算平台的计算强度上限时,此时模型理论性能的大小完全由计算平台的带宽上限以及模型自身的计算强度决定。逐点运算的操作大多是内存受限的,比如:激活函数、dropout、mask;另外规约类(reduction)操作也是内存受限的,比如:sum,softmax,batch normalization和layer normalization。
注意力机制的计算强度
为了评估Transformer中的瓶颈,需要对计算Transformer仅编码器和仅解码器模型所需的浮点运算(FLOPs)数量以及这些网络的算术强度进行建模。注意力机制的计算过程中最重要的部分是计算注意力权重,我们来看看其计算强度。假定有 \(Q,K \in R^{N \times d}\),计算 \(P=QK^T \in R ^{N \times N}\),\(O=PV \in R ^{N \times N}\),其中d是注意力头维度。参考下面图例,得出注意力权重的计算强度如下:
\[ops/bytes = \frac{4N^2d}{2Nd + 2Nd + 4N^2} = \frac{4N^2d}{4Nd + 4N^2} = \frac{N^2d}{Nd + N^2}\]
注意:有的论文或者博客省略了第3,4步,所以计算MAC会和本文不同。
矩阵乘法是计算受限还是内存受限,取决于这个公式和所在平台计算强度$I_{max} \(的比较结果。A100-40GB SXM的平台计算强度\)I_{max} $为201 flops/bytes。因此,如果矩阵乘法的计算强度大于201,此时的性能受限于计算带宽;反之,性能受限于内存带宽。而GPU的计算速度会“远快于”显存带宽。因此,对于注意力机制这类访存密集型任务,决定生成速度的不是GPU的计算能力,而是显存的带宽。另外,注意力机制中的一些操作也是内存受限的逐点运算,比如对S的mask操作、softmax操作和对P的dropout操作,这些逐点操作的性能也受限于内存带宽。
如何平衡
有研究人员对BERT Base和BERT Large编码器以及GPT-2解码器在不同序列长度上的算术强度进行分析。
- 对于短序列长度(例如128-512),大多数计算在FFN模块的投影层中,而MHA计算的大部分在投影层中。
- 随着序列长度的增加,矩阵乘法开始占主导地位,因为它们都是按序列长度二次缩放的。这导致算术强度在起初会增加,因为较大的矩阵维度允许每个加载的参数执行更多的计算。
- 然而,在较高的序列长度下,算术强度会降低。这是因为,对于长序列长度,MHA模块的矩阵乘法和Softmax计算开始占主导地位。与FFN模块中的投影层相比,这些具有相对较低的算术强度。
这些观察结果证实了,解码器推理是一个内存约束问题,而不是计算约束问题。那么要平衡利用 GPU 算力和内存带宽,batch size 需要是多少呢?其计算公式是 2 byte * 参数量 / 卡的数量 / 内存带宽 = batch size * 2 * 参数量 / 卡的数量 / 算力。等式左右两边参数量和卡的数量互相抵消,最终得到 batch size = 算力 / 内存带宽。这就需要依据不同芯片的参数来进行调节。另外,也要考虑网络延迟以及通信库本身的开销。
1.3 Tiling
Tiling(平铺)是一种通过分割输入和维护一些中间变量来递推式地完成操作,从而减少内存消耗的技术。这种平铺方法是有效的原因是:加法是关联的,允许将整个矩阵乘法分解为许多平铺矩阵乘法的总和。
对于大矩阵,如果对整个矩阵直接进行操作,则会消耗巨大的内存。我们知道矩阵乘具有分块和累加的特性,因此一个大的矩阵乘法可以通过Tiling技术来分解成更小的子矩阵,然后分别把这些小矩阵从慢速HBM加载到快速SRAM,在SRAM中对这些小矩阵进行计算,最后再把各个分块矩阵乘的结果进行累加获得最后的正确结果。
下图简要解释了如何对矩阵乘法\(C=A \times B\)的输入和输出矩阵进行划分。每个矩阵被划分为\(T \times T\)分片。对于每个输出分片,我们从左到右扫描A中的相关分片,从上到下扫描B中的相关分片,并将值从全局内存加载到片上内存(颜色为蓝色,整个片上内存占用面积为\(O(T^2)\))。对于位置(i,j),我们从片上存储器为分片内的所有k来加载A[i,k]和B[k,j](用红色表示),然后在片上存储器中将\(A[i,k]\times B[k,j]\)聚合到C[i,j]。在一个分片的通信完成后,我们将片上C分片写回主存,然后继续处理下一个分片。
另外,我们也可以将计算所需的数据提前或者异步的方式从HBM加载到SRAM,结合流水线编排就可以进一步隐藏掉数据加载所需时间。
该操作对于的伪代码如下:- a = A_i
- b = B_j
- c = C_ij
- for k in range(k):
- c += a[k] * b[k]
- final c done
复制代码 1.4 算子融合
在推理引擎实现中,对于性能受限于内存带宽的操作进行加速的常用方式就是算子融合,其基本思想是:在SRAM存储容许的情况下,将多个操作融合成一个操作来完成,从而避免反复执行“从HBM中读取输入数据,执行计算,将计算结果写入到HBM中”。
我们通过实例来进行分析。假设要连续执行算子A和算子B,其中算子A的输出是算子B的输入。最朴素的执行顺序如下:
- 启动算子A,把A所需要的数据从HBM拷贝到SRAM。
- 运行算子A。
- 把算子A的结果写回到HBM。
- 启动算子B,把B需要的数据从HBM拷贝到SRAM。
- 运行算子B。
- 把算子B的结果写回到HBM。
这个序列涉及到四次读写HBM操作和两次启动算子操作,会造成运行时间增加。
在算子融合的思路下,如果发现SRAM完全有能力存下算子A的输出结果,我们会把算子A和算子B合并成一个操作。这样A的输出就直接暂存在SRAM中让B来读取,从而可以减少读写HBM的次数,启动算子的动作等,从而有效减少内存受限操作的运行时间。
0x02 优化注意力机制
因为FlashAttention优化了注意力计算过程中的访存(HBM)的过程。所以我们先来看下标准注意力机制的计算访存。
2.1 标准注意力机制
计算公式
回顾缩放点积注意力(Scaled Dot-Product Attention)模块的公式如下:
\[Attention(Q,K,V) = softmax( \frac{QK^T}{\sqrt d_k} ) \times V\]
这个公式中,Q和K的维度均是\((N,d_k)\),V的维度是\((N,d_v)\),其中\(N\)是输入序列长度,\(d_k,d_v\)是特征维度。\(softmax(QK^T)\)的维度是\((N,N)\),\(Attention(Q,K,V)\)的输出维度是\((N,d_v)\)。
为了描述方便,后续在讨论中省略了Mask和Scale。由于多头注意力各个头的计算逻辑是一致的。这里也只描述单个头的情况。因此,假设一共有 N 个token,每个token向量的维度为 d ,则一个简化版注意力计算过程如下图:
实现算法
FlashAttention论文中给出的标准注意力机制的实现算法如下图所示。算法具体分成三步(也叫做3-pass算法):
- \(S=QK^T\)(计算注意力分数)。\(QK^T\)目的是获得每个query相对于所有key的点积。直观上,点积越大,某个Q行和某个\(K^T\)的列的相关性就大。具体操作时,注意力机制会从HBM中加载\(Q,K\)矩阵,执行计算点积\(S=QK^T\)的操作得到相似度得分\(S\),再将结果\(S\)写回HBM。
- \(P=softmax(S)\)(计算注意力权重)。softmax操作的目的是对注意力分数进行归一化。具体操作是将\(S\)从HBM中读取出来,执行\(P=softmax(S)\)的计算得到注意力权重,再将\(P\)写回HBM。
- \(O=PV\)(计算最终注意力结果)。将\(P\)和\(V\)从HBM中读取出来,执行\(O=PV\)的计算,最后把向量\(O\)写回HBM中。
注:算法中省略了mask和dropout操作,Q,、K、V、O都是2D矩阵,形状为(N,d)。N为序列长度,d为注意力头维度。
我们将上述算法用图例展示如下。
细化拆解
上面的图没有展示出SRAM和HBM之间的交互,我们从其他论文中找出更加详细的算法实现如下。
下图展示了算法中SRAM和HBM之间的交互流程和读写的数据量大小,图中的序号和上面算法的序号一致。
注意:有的论文或者博客省略了第3,4步,所以计算MAC会和本文不同。
问题所在
标准注意力算法在GPU内存分级存储的架构下存在两个缺陷:显存占用多和HBM读写次数多。造成缺陷的罪魁祸是\(QK^⊤\)操作。该操作一方面决定了注意力机制的算法复杂度是\(O(N^2)\),另一方面其产生的两个中间矩阵S和P的内存占用过大,需要在HBM和SRAM中搬运,而 HBM 的读写带宽 相比 SRAM 低很多,于是减慢了运行时间(wall-clock time)。我们接下来一一进行分析。
- 显存占用多。3-pass算法的输入和输出变量Q,K,V,O 所需要的内存为$ O(Nd)$ ,步骤一和步骤二会分别产生两个中间矩阵S和P,内存需求均是\(O(N^2)\),因此总内存需求是\(O(N^2+Nd)\)。当序列长度N很大(即 N≫d)时P和S 需要的内存 $O(N^2) \(远大于 Q,K,V,O 所需要的内存\) O(Nd)$ ,这样会耗尽显存,同时GPU HBM的访存压力也会急剧变大为\(O(Nd+N^2)\)。
- HBM读写次数多。因为中间矩阵内存占用过大,无法被SRAM容纳,因此需要从SRAM转移到HBM中。但是因为计算需要,S和P在存入HBM后又立即被访问,所以导致多次读写HBM操作。3-pass算法的三个步骤分别对应三个kernel(具体在算法图中有标明):gemm、softmax和gemm。三个kernel依次执行。每个kernel的计算过程都存在如下操作:从HBM读取数据;计算;写回HBM。一共包含八次HBM的矩阵读写操作,总HBM访问次数为\(O(Nd+N^2))\)。具体八次操作分别为:
- 第一步有三次操作。两次读操作为从HBM中读取完整的Q和K矩阵(每个大小为\(R^{N×d}\) ),一次写操作为把相似度得分S(大小为\(R^{N×N}\) )写回到HBM。总共需要进行\(O(Nd + N^2)\)次HBM访问,其中涉及到一次超大矩阵S的读取。
- 第二步有两次操作。一次读操作为从HBM中读取完整的S矩阵,一次写操作为把P(大小为\(R^{N×N}\) )回写到HBM。总共需要进行\(O(N^2)\)次HBM访问,而且涉及到两次超大矩阵的读写。
- 第三步有三次操作。两次读操作为从HBM中读取完整的P和V矩阵(大小为\(R^{N×d}\)),一次写操作为把输出向量O(大小为\(R^{N×d}\))写回到HBM,总共需要进行\(O(Nd)\)次HBM访问,其中涉及到一次超大矩阵P的读取。
2.2 解决方案
既然知道了\(QK^⊤\)操作是罪魁祸首,我们就思考下如何把计算过程中间结果所需的内存空间减低,让中间结果可以暂存在SRAM中,从而减少I/O读写,优化IO时间。
思路
我们的目标是计算O,一般来说,我们需要获取所有的Q,K,V,然后分三步计算;我们也可以先获取一小块Q,K,V,一次计算得到部分的O,再想办法将部分的O合成全部的O。
前面提到,注意力机制(\(softmax( \frac{QK^T}{\sqrt d_k} ) \times V\))的三个主要计算模块为计算注意力分数,归一化和根据注意力权重的加权求和,分别对应依次执行的三个kernel:gemm(query×key)、point-wise的softmax、gemm(attn_score×value)。如果SRAM可以存储中间结果,我们将这三个kernel融合起来,让中间结果数据停留在SRAM上面,就会避免重复从HBM上读写中间全局内存,从而达到对 pointwise 操作加速的目的。注,我们暂时抛开softmax计算的特殊性,假设其可以融合。
因此,我们的总体方案是:用“融合+分块”来避免频繁从HBM读写大型矩阵。即抹去对大型矩阵S,P的读写。融合+分块是一个硬币的两面,互相交织,需要统一对方案思路进行分析,即:
- 因为要减少IO,所以要以两个gemm kernel为中心来进行算子融合。
- 融合的前提是要把所有中间变量都存起来,不写回HBM;
- 而SRAM没有这么大空间来容纳中间矩阵,因此就需要做融合时候考虑分块。只要分块矩阵和中间注意力结果可以在SRAM内存放,就可以在计算过程中只访问SRAM了。
我们接下来就看看算子融合和分块计算。
算子融合
针对注意力计算,我们的思路就是:针对数据的换入换出进行优化,把两个gemm和softmax融合成一个算子:\(softmax( QK^T) \times V\)一次性在SRAM中处理,从而减少S和P的读写。
标准注意力的算法是:在 SRAM 上计算 \(S=QK^T\) ,将矩阵 S 写入到 HBM 中,然后再将矩阵 S 从 HBM 读入到 SRAM 中,计算 P=softmax(S) 。
在算子融合方案下,上述操作可以合并在一个 kernel 中完成,即在 SRAM 中计算完 S 之后紧接着就通过 S 计算 P ,这样就可以避免在 HBM 和 SRAM 交换 S 。
分块计算
前面提到,算子融合的前提是SRAM存储足够大,或者说,只有SRAM能够容纳中间结果,才有算子融合的可行性。这是因为虽然算子融合有效,但是解决不了内存开销太大的问题。
比如下图中SRAM能容纳10000个数据,但是Q和K都是5000个数据。如果一次性运行融合算子\(softmax( QK^T) \times V\),则需要加载10000个数据到SRAM,但是这样就无法容纳中间计算结果,会造成OOM,因此只能通过迭代方式进行计算,依然导致大量对HBM的读写操作。
因为SRAM的内存大小有限,不可能一次性计算完整的注意力。而全连接层和根据注意力权重的加权求和其实都是通过矩阵乘法实现的,因此可以通过tiling操作来进行分块计算。在分块计算中只加载必要的参与计算的Q,K,V的分块到SRAM ,这样其总体内存不超过SRAM的大小,并且计算完成S后,直接使用S来计算P。借此来提高整体读写速度(减少了HBM访问次数)。具体如下图所示
- 将Q [100,50] 切分成两个矩阵
- 将 K [100,50] 切分成两个矩阵
- 此时\(softmax( QK^T) \times V\)算子可以在SRAM 一次性算完这些小块的注意力操作。
因此,我们得到了总体思路如下:QK^T 生成了一个形状为 (b, n, s, s) 的临时输出,而我们只需要 Softmax(QK^T)V 的最终结果,其形状为 (b, n, s, d)。只要 s 和 d 相对较小,我们就可以将这三个矩阵的乘法融合成一个单独的Cuda核(Kernel)函数,直接产生 Softmax(QK^T)V。
限制
看过了Softmax(QK^T)V 得大致思路,我们再仔细看看如何计算O以及的SRAM的限制,这里做几点说明。
- 生成\(O_j\)是累加更新操作。我们以\(O_1\)为例进行说明。
- \(O_1\)是从\(O_1 = Q1K1V1\)一直累积更新,最终得到\(O_1 = Q1K1V1 + Q1K2V2 + Q1K3V3 + Q1K4V4\)。
- 为了更好的分析,我们把\(O_j\)看做是包含 i 个元素的一行向量,即把加法的每项看作是一个元素(如下图所示,\(O_1^2\)和\(O_1^3\)是两个元素),即\(O_1^2\)是第一行的第二列。
- 每次更新\(O_1\)需要把\(O_1\)前面一列加载到SRAM,然后才可以对\(O_1\)进行增加新列的操作。
- 省略softmax操作,即kernel函数是计算 (QK^T)V。
- SRAM每次只能够容纳Q、K、V、O的小块。
如何计算O是难点所在,我们接下来看看两个方案。
方案1
我们得到方案1的逻辑方案如下图,具体思路是:
首先将K、V切成了Tc个小块,将Q和O切分为Tr个小块。
接下来开始进行循环计算。j 是外循环, i 是内循环,或者说K和V是外循环j,Q和O是内循环i。
外循环逻辑如下:
- 外层第j次循环拿到了K矩阵和V矩阵的第j个块 \(K_j\),\(V_j\),加载到SRAM中。
- 每次外循环都对 \(O_{1}\)到\(O_{tr}\)全部进行更新,但是每次分别只更新 \(O_{1}\)到\(O_{tr}\)的一部分。
- 最终所有j循环结束后,得到的最新的完整O就是期望的结果, \(O_{1}\)到\(O_{tr}\)是直到外循环结束之后才一次性全部更新完成。
第j个外循环的内循环 i 会逐行更新\(O\)的每一行$$O_i$$,其逻辑如下:
- 把Q矩阵的第i个块\(Q_i\)和O矩阵的第i个块\(O_i\)(即\(O_i\)行的前一个状态,可以简单理解为\(O_i\)的前一列\(O_i^{j-1}\))加载到SRAM。
- 用\(K_j\)和\(Q_i\)计算得到了S和P,再和\(V_j\)相乘得到了\(O_i\)行的新一列$$O_i^j$$。
- 用\(O_i^{j-1}\)和\(O_i^j\)累积,更新\(O_i\)。
- 把\(O_i\)回写到HBM。
- 内循环期间一共对O进行\(Tr\)次更新。
写成伪代码如下。
[code]# ---------------------# Tc: K和V的分块数# Tr: Q和O的分块数量# ---------------------O_0 = 0for 1 dropout --> 矩阵乘法”,矩阵乘法和逐点操作(scale,mask,dropout)的分块计算是容易实现的。</p>因为我们只关注了矩阵乘,所以目前看起来一切美好。然而我们的前进路上还有一个拦路虎:softmax。制约注意力机制性能的关键因素,其实是Softmax。我们接下来就看看softmax的问题所在,以及如何解决。
0x03 Softmax改进
我们首先从原生softmax开始看看其存在的问题,以及如何改进。先给出问题概述:矩阵是可加的,但是softmax是不可加的。即,Self-Attention 包含一个不直接关联的 softmax 运算符,因此很难简单地平铺 Self-Attention。
3.1 原生softmax
Softmax 函数是一种常用于机器学习,特别是多分类问题中的激活函数。它的作用是将一个任意实数向量转换为一个概率分布,并确保输出的概率和为 1。
公式
假设某一数组是 \([x_1, x_2, ..., x_V]\),\(x_i\)是数组中某一个元素,原生softmax的计算公式如下:
\[softmax(x_i) = \frac{e^{x_i}}{\sum _{j=1}^V e^{x_j}}\]
具体算法如下图所示,算法流程需要两个循环,涉及两次从内存读取和一次写回内存操作:
<ul>计算归一化项(normalization term) \(d_V\)。Softmax 函数中,分母的求和项被叫做归一化项 \(
来源:程序园用户自行投稿发布,如果侵权,请联系站长删除
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作! |