找回密码
 立即注册
首页 业界区 业界 探秘Transformer系列之(20)--- KV Cache

探秘Transformer系列之(20)--- KV Cache

琉艺戕 5 天前
探秘Transformer系列之(20)--- KV Cache


目录

  • 探秘Transformer系列之(20)--- KV Cache

    • 0x00 概述
    • 0x01 自回归推理的问题

      • 1.1 请求的生命周期
      • 1.2 简化推导
      • 1.3 冗余分析
      • 1.4 冗余根源



          • 1.4.1 看处理逻辑
          • 1.4.2 看处理过程


      • 1.5 如何改进

        • 1.5.1 从网络角度看
        • 1.5.2 从数学角度看
        • 1.5.3 结论


    • 0x02 用KV Cache来优化

      • 2.1 术语
      • 2.2 流程
      • 2.3 重新定义阶段

        • 2.3.1 定义
        • 2.3.2 分析

      • 2.4 思考

        • 2.4.1 历史上下文
        • 2.4.2 Q其实也被缓存了
        • 2.4.3 每层都有独立的KV Cache
        • 2.4.4 计算机架构
        • 2.4.5 适用前提


    • 0x03 实现

      • 3.1 总体思路
      • 3.2 存储结构

        • 3.2.1 llama3
        • 3.2.2 Transformer库

      • 3.3 如何使用

    • 0x04 资源占用

      • 4.1 维度变化
      • 4.2 存储量

        • 4.2.1 单层
        • 4.2.2 多层
        • 4.2.3 实际样例
        • 4.2.4 存储实现

      • 4.3 计算量

        • 4.3.1 查表
        • 4.3.2 \(W^Q, W^K, W^V\)计算
        • 4.3.3 Attention
        • 4.3.4 MLP
        • 4.3.5 对比

          • 没有KV cache时
          • KV Cache

        • 小结

      • 4.4 总结

    • 0xFF 参考


0x00 概述

随着输入给LLM的token列表增长,Transformer的自注意力阶段可能成为性能瓶颈。token列表越长,意味着相乘的矩阵越大。每次矩阵乘法都由许多较小的数值运算组成,这些运算称为浮点运算,其性能受限于GPU的每秒浮点运算能力(FLOPS)。这样,在LLM的部署过程中,推理延迟和吞吐量问题成为了亟待解决的难题。这些问题主要源于:

  • 生成推理的序列自回归特性,需要为所有先前的标记重新计算键和值向量。
  • 由于注意力机制与输入序列的大小呈二次方关系增长,因此在推理过程中,注意力机制往往会产生最大的延迟开销。
为解决推理延迟和吞吐量问题,最常用的优化技术是KV Cache。KV Cache是一种关键的性能优化机制。它通过缓存已计算的Key和Value矩阵,避免在自回归生成过程中重复计算,从而显著提升推理效率(本质就是用空间换时间)。这种机制类似于人类思维中的短期记忆系统,使模型能够高效地利用历史信息。通过复用 KV Cache,可以达到两大目的:

  • 提升 Prefill 效率。由于参与 Prefill 的 Tokens 数减少,所以计算量下降,Prefill 的延时也就下降,直接提升 TTFT 性能。特别适合优化多轮对话场景的性能。
  • 节省显存。KV缓存中存储了生成推理过程中至关重要的可重用中间数据。
本篇先介绍在不使用 KV Cache 的情况下是如何一步步预测下一个 token 的,然后介绍 KV Cache。
注意:本文的分析梳理可能与实际概念产生历史轨迹不同,这么梳理只是因为作者觉得这样更容易解释。
0x01 自回归推理的问题

多轮对话是现代大型语言模型(LLM)的基本功能。在这种对话中,一个多轮对话会话由一系列连续的对话组成,记作D = [d1, d2, ... dN]。在每个对话dj中,用户输入一个新的问题或命令qj,然后等待LLM的响应aj。
LLM使用的是自回归模式。自回归模型的推理过程很有特点:推理生成 tokens 的过程是迭代式的。用前文预测下一个字/词,并且前文中的最后一个词经过解码器的表征会映射为其下一个待预测词的概率分布。具体来说是,我们给定一个输入文本,模型会输出一个回答(长度为N)。但实际上该过程中执行了N次推理过程。即一次推理只输出一个token,当前轮输出的 token 会与之前输入 tokens 拼接在一起,并作为下一轮的输入 tokens,这样不断反复直到遇到终止符或生成的 token 数目达到设置的 max_new_token 才会停止。
1.jpeg

1.1 请求的生命周期

实际上对LLM的使用中,prompt都是较长的序列。在不考虑KV Cache的情况下,因为prompt的实际特点,导致LLM推理过程中存在着prompt phase(提示处理)和 token-generation phase(token生成)这两个截然不同的过程。

  • prompt phase:LM服务接受到用户请求(Is tomato a fruit?),根据输入 Tokens(Is, tomato, a, fruit, ?) 生成第一个输出 Token(Yes)。
  • token-generation phase:从生成第一个 Token(Processing) 之后开始,把 prompt 以及已生成的 tokens 组成新的模型输入,采用自回归方式一次生成一个 Token,直到生成一个特殊的 Stop Token(或者满足用户的某个条件,比如超过特定长度) 才会结束。该过程中,前后两轮的输入只相差一个 token,存在重复计算。
prompt phase整体算1个推理阶段, token-generation phase中的每个decode各算1个推理阶段,比如下图 token-generation phase阶段包括3次推理。
2.jpeg

我们对两个阶段的特点进行深入分析。
prompt phase(预填充阶段),也有叫启动阶段(initiation phase),其特点如下:

  • 时机:发生在计算第一个输出 token 过程中。
  • 输入:输入一个prompt序列。
  • 作用:一次性处理所有的用户输入。LLMs对输入序列(即输入提示)的上下文进行总结,并生成一个新标记作为解码阶段的初始输入。
  • 执行次数:其通过一次 Forward 就可以完成。
  • 计算类型:存在大量 GEMM (GEneral Matrix-Matrix multiply) 操作,属于 Compute-bound 类型(计算密集型)计算。
  • 并行:输入的Tokens之间以并行方式执行运算,是一种高度并行化的矩阵操作,具备比较高的执行效率。
token-generation phase的特点如下:

  • 时机:在prompt阶段生成第一个 Token之后,开始进入token-generation phase阶段。发生在计算第二个输出 token 至最后一个 token 过程中。
  • 输入:新生成的token会与输入tokens 拼接在一起,作为下一次推理的输入。
  • 作用:新生成的标记被反馈回解码阶段作为输入,从而创建了一个用于标记生成的自回归过程。
  • 执行次数:假设输出总共有 N 个 Token,则 token-generation phase阶段需要执行 N-1 次 Forward。
  • 计算类型:存在大量 GEMM (GEneral Matrix-Matrix multiply) 操作,属于 Compute-bound 类型(计算密集型)计算。
  • 并行:假设输出总共有 N 个 Token,则 Decoding 阶段需要执行 N-1 次 Forward,这 N-1 次 Forward 只能串行执行,因此效率相对比较低。另外,在生成过程中,需要关注的 Token 越来越多(每个 Token 的生成都需要 Attention 之前的 Token),计算量也会适当增大。
自回归的生成模式是两阶段的根本原因,两阶段是自回归的生成模式的外在体现形式,KV cache是优化手段。
注:在SplitWise论文中,分别把这两个阶段称为prompt phase 和 token-generation phase。在实践中,“预填充(pre-fill)”和“初始化(initiation)”这两个术语可以互换。为了更好的说明,现在我们将更倾向于使用前者。
1.2 简化推导

我们用实例来看看LLM类模型对于给定文本的回答过程。为了更好的梳理,此处的prompt只是一个词(与实际情况不符)。我们可以将回答过程分解为下列推理:输入“新”,模型逐步预测出“年”,“大”,“吉”,[EOS]这几个词。具体推理步骤如下。
  1. 第一次推理: 输入=[BOS]新;输出=年
  2. 第二次推理: 输入=[BOS]新年;输出=大
  3. 第三次推理: 输入=[BOS]新年大;输出=吉
  4. 第四次推理: 输入=[BOS]新年大吉;输出=[EOS]
复制代码
其中[BOS]和[EOS]分别是起始符号和终止符号。
3.jpeg

我们接下来深入到Transformer内部逐一看看上述推理流程。注意:下面的示例图只给出了和 KV Cache 相关的细节。
第一步输入“新”,输出“年"。本步骤具体数据流如下图所示。
4.jpeg

第二步会将”年“拼接到”新“的后面作为新的输入,即本次推理的输入为”新年“,预测得到”快“。本步骤具体数据流如下图所示。
5.jpeg

第三步会将”快“拼接到”新年“的后面作为新的输入,即本次推理的输入为”新年快“,预测得到”乐“。本步骤具体数据流如下图所示。
6.jpeg

1.3 冗余分析

我们把上面三步汇总起来如下图所示。会发现其中存在大量的冗余计算,每生成一个token需重新计算所有历史token的Key/Value,复杂度为 \(O(n^2)\) ,显存和计算时间随序列长度急剧增长,比如:

  • 生成embedding有冗余计算。
  • KV生成有冗余计算。
  • \(QK^T\)有冗余计算。
  • softmax操作以及与V相乘有冗余计算。
7.jpeg

因为每一步中前面的操作都是为计算注意力做准备,因此我们针对注意力部分进行重点分析。每一步中涉及注意力的计算如下(下面的\(\theta\)指代softmax操作后的结果,比如第二步中,\(\theta(Q_2K_1^T)\)可能是0.4,\(\theta(Q_2K_2^T)\)可能是0.6)。

  • 第一步涉及的计算为:\(\theta(Q_1K_1^T)V_1\)。
  • 第二步涉及的计算为:\(\theta(Q_1K_1^T)V_1\),\(\theta(Q_2K_1^T)V_1 + \theta(Q_2K_2^T)V_2\)。

    • 有一步重复计算\(\theta(Q_1K_1^T)V_1\),这步重复计算仅仅依赖于\(Q_1K_1V_1\),和\(Q_2K_2V_2\)没有关系。
    • \(V_2\)的计算是新增计算,从\(\theta(Q_2K_1^T)V_1 + \theta(Q_2K_2^T)V_2\)中可以看到,\(V_2\)的计算仅与\(Q_2\)相关,与\(Q_1\)无关。

  • 第三步涉及的计算为:\(\theta(Q_1K_1^T)V_1\),\(\theta(Q_2K_1^T)V_1 + \theta(Q_2K_2^T)V_2\),\(\theta(Q_3K_1^T)V_1 + \theta(Q_3K_2^T)V_2 + \theta(Q_3K_3^T)V_3\)。

    • 有两步重复计算,具体道理和第二步类似。
    • \(V_3\)的计算是新增计算,其仅与\(Q_3\)相关,与\(Q_1\),\(Q_2\)无关。

看起来,在预测第i个字时,只有最后一步引入了新的计算,而第1个到第i-1步的计算和前面是完全重复的。
1.4 冗余根源

现在我们探寻冗余计算的原因,即为什么之前的词不需要重复计算。
1.4.1 看处理逻辑

为了生成与上下文紧密相关的新标记,LLMs需要在注意力层中计算最后一个token与所有之前token(包括输入序列中的token)之间的关系。一种简单的方法是在每个迭代中重新计算所有之前标记的键和值。因此每一步中,当前轮输出token与输入tokens拼接作为下一轮的输入tokens。第i+1轮输入数据只比第i轮输入数据新增了一个token,其他全部相同。然而,这样第i+1轮推理时必然包含了第 i轮的部分计算,再对前面的单词做计算就是冗余。而且计算开销随着之前标记数量的增加而线性增长,即对于更长的序列,开销会更大。
对于每次token生成,其查询是从当前token计算出来的,而键和值是从所有token派生出来的,并且对于后续token不会更改。vanilla Transformer的实现会在生成每个新token时重新计算键和值们,从而不必要地增加了 GPU 每个注意力块所需的计算量。
1.4.2 看处理过程

从网络结构来看,Transformer的主要模块决定了不需要重复计算:

  • 注意力模块(对应下图中标号1)。

    • 推理时,前面生成的token看不到后续生成的token,所以前面已经生成的 token不需要与后面的 token进行注意力计算。在“单向 attention”的影响下,序列预测过程的第 i 个时间步的 query 向量 \(q_i\) 不会影响前序所有时间步的 \([k_1, k_2,..., k_{i-1}]\) 和\([v_1, v_2,..., v_{i-1}]\) 。比如, i=3 时的 \(k_2\) 和 i=4 时的\(k_2\) 完全相同。在 Transformer 的每一层,Key 和 Value 都不会被重复计算。
    • 训练时,由于掩码技术的使用,在生成当前 tokens 的输出表征时,仅使用之前已生成 tokens 的信息,而不使用之后生成的 tokens 的信息。即\(Q_i\)与\(K_{i+j}\),\(V_{i+j}\)的计算会被mask掉,不需要计算。掩码的主要优点是将(自)注意力机制的FLOPs需求从与总序列长度呈二次方扩展变为线性扩展。在每个生成步骤中,我们实际上可以避免重新计算过去token的键和值,而只需计算最后生成的token。每次计算新的键和值时,我们的确可以将它们缓存到GPU内存中以供未来重复使用,因此节省了重新计算它们时所需的浮点运算次数。

  • FFN(对应下图中标号2)。在FFN计算中,序列中各个词对应的特征不会交互信息,不会互相影响,并且最终只取最后一个位置的输出特征作为下一个token的概率分布。因此,经过FNN层后,第 i 个输出的新增计算只和第 i 个输入有关,和其他输入无关,比如下面\(Y_1\)的计算只和\(X_1\)相关。

    \[\begin{bmatrix}   X_0 \\   X_1 \\  X_2 \\   X_3 \\ \end{bmatrix}W^T = \begin{bmatrix}   X_0 W^T\\   X_1 W^T\\  X_2 W^T\\   X_3 W^T\\\end{bmatrix} = \begin{bmatrix}   Y_0 \\   Y_1 \\ Y_2 \\ Y_3 \\  \end{bmatrix}\]

    • Add & Norm(对应下图中标号3)。对于LayerNorm,它是在 d_model 方向上计算均值和方差,然后进行归一化,因此它的输出也只与输入 hidden_state 的最后一行相关。
    • Linear(对应下图中标号4)。这是一个将 hidden_state 的维度从 d_model 变换到 vocab_size的线性映射,根据矩阵乘法的性质,可以知道 logits 的最后一行只与 hidden_state 的最后一行相关。
    • Softmax(对应下图中标号5)。softmax只要把之前的计算结果存储起来,就可以结合新计算的结果来进行计算。

8.jpeg

1.5 如何改进

虽然我们推导出来有冗余计算,但是vanilla Transformer在推理的时候可不管这些,无论你是不是只要最后一个字的输出,它都把所有输入计算一遍,导致输出结果中间有很多我们用不到的计算,这样就造成了浪费。这就是问题所在。因此我们要看看如何改进。因为涉及到对某些和前文相关的中间变量进行缓存或者丢弃,我们需要仔细斟酌究竟缓存哪些、丢弃哪些。
1.5.1 从网络角度看

我们从模型架构来看看几种选择方式。
选择结论原因丢弃前面的X(输入的token)不行下面详细解释缓存X可以,但不是最优选择因为即便缓存了X,还需要计算K和V缓存\(QK^T\)不行实际计算下一个token时候并没有使用到之前的\(QK^T\)丢弃之前的query可以模型的第i个输出只和query'的第 i 个token有关,和其他query无关,新增计算只和当前\(Q_i\)关联,但是和之前的\(Q_{0,i-1}\)没有关联,所以完全没有必要缓存之前的query。丢弃之前的KV不行下面详细解释缓存之前的KV可以下面详细解释为何不能丢弃前面的输入token
我们知道,推理最终只会选取最后一个位置的输出特征作为下一个token的概率分布,即下一个token是由当前最后一个token的网络输出所决定的。但这不代表可以仅输入最后一个token来进行推理。因为虽然在结果层仅由最后一个token来决定,但是中间的注意力过程依赖于前文所提供的Key、Value向量来携带前文信息,因此也不能抛弃前文不管。
或者说,由X生成Q、K、V三个分支,因为前面的K和V不能丢弃。所以不能单纯丢弃前面的X。但是由于Q在自回归Transformer模型中的使用特性和计算过程中的不对称性,缓存Q不会带来推理效率的提升,因此LLM推理过程中通常不缓存Q。
当然,因为X派生了K和V,如果缓存K和V,就可以丢弃输入X。
为何不能丢弃之前的KV
前面提到了KV不可或缺。我们接下来再深入分析。
在注意力机制中,第 i 个输出 $O_i \((可以拓展到每个transformer block的输出)和完整的K、V以及当前时刻的\)Q_i\(都有关。我们以第二步计算为例:红圈表示\)O_0\(计算所涉及的元素,蓝圈表示\)O_1$计算所涉及的元素。可以看到蓝圈涉及到所有K和V。
9.jpeg

我们再用高阶向量来细化到具体运算,从下图可以看到,\(O_3\)的计算涉及所有的QKV。
10.jpeg

缓存之前KV的可行性
既然之前的KV是必需的,我们接下来就看看缓存的可行性。

  • 首先,K、V的历史值只和历史的O有关,和当前的O无关,从这个角度看可以缓存K和V。
  • 其次,先前的token在后续迭代过程中保持不变,因此对于该特定token的输出表征对于所有后续迭代也将是相同的。在推理时,模型的权重已经固定(\(W^Q\),\(W_K\),\(W^V\)的权重固定),对于同一个词,如果它的Token Embedding和位置编码都是固定的,则从\(W^Q\),\(W_K\),\(W^V\)计算得到的Q,K,V是固定的。因此计算一次即可。
因此,我们可以通过缓存历史的K、V来避免重复计算历史K、V。
1.5.2 从数学角度看

假设矩阵A和矩阵B相乘,我们将矩阵A拆分为[:s], 两部分,分别和矩阵B相乘,那么最终结果可以直接拼接,该结果与不分拆结果一致。注意力和FFN都是矩阵乘法操作,因此将[:s]部分缓存,来避免[:]整体输入导致的重复计算。
11.jpeg

1.5.3 结论

以上的分析证明了缓存KV再拼接计算的结果和正常的输入全序列计算是等价的,但是计算量大大减少了,这就是KV Cache。
0x02 用KV Cache来优化

KV Cache 的想法很直观:用空间换时间,缓存上一轮的 K, V,从而避免每次生成token时重新计算key、value向量,利用预先计算好的key值和value值就可以生成新token,这样可达到减少计算,提速的效果。KV Cache的大体作用如下。

  • KV Cache充当自回归生成模型的内存库,来存储所有之前标记的键(K)和值(V),以便将来重复使用,保证KV是全的。
  • 每次迭代计算新的键向量和值向量时,KV缓存都会更新生成的标记的键和值。
  • 模型的第一次输入是完整的prompt,后续输入只有上一次推理生成的 token,而不是整个 prompt 序列。
  • 当计算第 K+1 个token的注意力分数时,模型不需要重新计算所有先前K个token的键和值,而仅需从缓存中检索先前K个token的键和值并串接至当前向量。
2.1 术语


我们首先看看KV-cache的结构和术语。LLM由多个transformer块层组成,每个层都维护其自己的键和值的缓存。在本文中,我们将所有transformer块的缓存统称为KV-cache,同时使用术语K-cache或V-cache分别表示键和值。在深度学习框架中,每个层的K-cache(或V-cache)通常表示为形状为[
来源:程序园用户自行投稿发布,如果侵权,请联系站长删除
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!
您需要登录后才可以回帖 登录 | 立即注册