大模型KV Cache产生背景及技术细节


本文将针对仅编码器Transformer架构(Decoder-Only Transformer)的模型必备显存优化技术 KV Cache 进行讲解。

大模型KV Cache产生背景及技术细节

KV Cache 简介

KV Cache 是大模型推理性能优化的一个常用技术,该技术可以在不影响任何计算精度的前提下,通过空间换时间的思想,提高推理性能。

KV Cache 诞生的背景

对于仅编码器Transformer架构的模型的推理,我们给一个输入文本,模型会输出一个回答(长度为 N),其实该过程中执行了 N 次推理过程。即类 GPT 的仅编码器模型一次推理只输出一个token,输出的 token 会与输入 tokens 拼接在一起,然后作为下一次推理的输入,这样不断反复直到遇到终止符。

针对一个仅编码器Transformer架构的模型,假设用户输入为“recite the first law”,模型续写得到的输出为“A robot may not ”,模型的生成过程如下:

  1. 将“ecite the first law”输入模型,得到每个token的注意力表示。使用“law”的注意力表示,预测得到下一个token为“A”(实际还需要将该注意力表示映射成概率分布logits,为了方便叙述,我们忽略该步骤)。
  2. 将“A”拼接到原来的输入,得到“recite the first law A”,将其输入模型,得到注意力表示,使用“A”的注意力表示,预测得到下一个token为“robot”。
  3. 将“robot”拼接到原来的输入,依此类推,预测得到“robot”,最终得到“recite the first law A robot may not”
大模型KV Cache产生背景及技术细节

仅编码器Transformer架构的自回归模型为带 Masked 的 Self Attention。因此,在没有KV Cache的情况下,其计算过程如下所示。

大模型KV Cache产生背景及技术细节

正常情况下,Attention的计算公式如下:

大模型KV Cache产生背景及技术细节

为了看上去方便,我们暂时忽略scale项,因此,Attention的计算公式如下所示(softmaxed 表示已经按行进行了softmax):

大模型KV Cache产生背景及技术细节
大模型KV Cache产生背景及技术细节

变为矩阵时,softmax 会针对行进行计算,详细如下(softmaxed 表示已经按行进行了softmax):

大模型KV Cache产生背景及技术细节

其中,表示 Attention 的第一行, 表示 Attention 的第二行。

大模型KV Cache产生背景及技术细节

对于,由于这个值会mask掉,你会发现, 在第二步参与的计算与第一步是完全一样的,并且 参与计算Attention时也仅仅依赖于 ,与 毫无关系。

对于 参与计算Attention时也仅仅依赖于 ,与 毫无关系。

大模型KV Cache产生背景及技术细节

其计算方式如 Step2 所示。

大模型KV Cache产生背景及技术细节
大模型KV Cache产生背景及技术细节

其计算方式如 Step2 所示。

大模型KV Cache产生背景及技术细节

对于 参与计算Attention时也仅仅依赖于 

看上面图和公式,我们可以得出以下结论:

  1. 当前计算方式存在大量冗余计算,每一次生成新的Token都需要计算之前的KV。
  2. 的计算过程中,主要与 有关。 参与计算Attention时也仅仅依赖于 
  3. 每一步中,其实只需要根据 计算  就可以,之前已经计算的Attention完全不需要重新计算。但是 K 和 V 是全程参与计算的,所以这里我们需要把每一步的 K 、 V 缓存起来。

KV Cache 步骤

正是因为 Self Attention 中带 Masked ,因此,在推理的时候,前面已经生成的 Token 不需要与后面的 Token 产生 Attention ,从而使得前面已经计算的 K 和 V 可以缓存起来。

一个典型的带有 KV cache 优化的生成大模型的推理过程包含了两个阶段:

1. 预填充阶段:输入一个prompt序列,为每个transformer层生成 key cache 和 value cache(KV cache)。

2. 解码阶段:使用并更新KV cache,一个接一个地生成token,当前生成的token词依赖于之前已经生成的token。

预填充阶段计算过程如下:

大模型KV Cache产生背景及技术细节

解码阶段计算过程如下:

大模型KV Cache产生背景及技术细节

使不使用 KV Cache 的对比

下图展示了使用KV Cache和不使用KV Cache的对比,其中,紫色部分表示从缓存获取,灰色部分表示会被Masked。

大模型KV Cache产生背景及技术细节
大模型KV Cache产生背景及技术细节
大模型KV Cache产生背景及技术细节
大模型KV Cache产生背景及技术细节

下面使用 transformers 来比较有 KV Cache 和没有 KV Cache的情况下,GPT-2的生成速度。

import numpy as np  
import time  
import torch  
from transformers import AutoModelForCausalLM, AutoTokenizer  
  
device = "cuda" if torch.cuda.is_available() else "cpu"  
tokenizer = AutoTokenizer.from_pretrained("gpt2")  
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)  
  
for use_cache in (TrueFalse):  
    times = []  
    for _ in range(10): # measuring 10 generations  
        start = time.time()  
        model.generate(**tokenizer("What is KV caching?", return_tensors="pt").to(device), use_cache=use_cache, max_new_tokens=1000)  
        times.append(time.time() - start)  
    print(f"{'with' if use_cache else 'without'} KV caching: {round(np.mean(times), 3)} +- {round(np.std(times), 3)} seconds")

运行结果:

  • 使用 KV caching: 11.885 +- 0.272 秒
  • 不使用 KV caching: 56.197 +- 1.855 秒

可以看到使不使用 KV cache 推理性能果差异显存。

使用 KV Cache 解码阶段计算量分析

FLOPs,floating point operations,表示浮点数运算次数,衡量了计算量的大小。
如何计算矩阵乘法的FLOPs呢?
对于  ,计算 ?? 需要进行 ? 次乘法运算和 ? 次加法运算,共计 2? 次浮点数运算,需要  的FLOPs。对于  ,计算 ?? 需要的浮点数运算次数为  。

下面来看看在一个 Token 生成过程中一层 Transformer 的计算量。

首先,分析 self-attention 块的计算,计算公式如下:

我们来看看不使用 KV Cache 时,假设输入数据的形状为 [b, s],隐藏层维度为 h,则输入的形状为 [b, s, h]。self-attention块的计算如下:

  1. 计算 ?,?,? :矩阵乘法的输入和输出形状为 [?,?,ℎ]×[ℎ,ℎ]→[?,?,ℎ] 。计算量为 $ 3* bs2hh = 3∗2??ℎ^2=6??ℎ^2$ 。
  2.  矩阵乘法的输入和输出形状为 [?,ℎ???_???, ?, ???_ℎ???_ℎ?????_????]×[?, ℎ???_???, ???_ℎ???_ℎ?????_????, ?]→[?, ℎ???_???, ?, ?],计算量为 
  3. 计算在 ? 上的加权 ?????⋅? ,矩阵乘法的输入和输出形状为 [?,ℎ???_???,?,?]×[?,ℎ???_???,?,???_ℎ???_ℎ?????_????]→[?,ℎ???_???,?,???_ℎ???_ℎ?????_????] 。计算量为  。
  4. attention后的线性映射,矩阵乘法的输入和输出形状为 [?,?,ℎ]×[ℎ,ℎ]→[?,?,ℎ] 。计算量为  。

不使用 KV Cache 时,输入的形状为 [b, 1, h ],kv cache中含有 个 past word。self-attention块的计算如下:

  1. 计算  :矩阵乘法的输入和输出形状为 [?, 1, ℎ]×[ℎ, ℎ]→[?, 1, ℎ] 。计算量为  。
  2.  矩阵乘法的输入和输出形状为 [b, head_num, 1, per_head_hidden_size]×[b, head_num, per_head_hidden_size, kv_length+1]→[b, head_num, 1, kv_length+1] 。计算量为  。
  3. 计算在?上的加权 ?????·? ,矩阵乘法的输入和输出形状为 [b, head_num, 1, kv_length+1]×[b,head_num,kv_length+1,per_head_hidden_size]→[b,head_num,1,per_head_hidden_size] 。计算量为  。
  4. attention后的线性映射,矩阵乘法的输入和输出形状为 [?,1,ℎ]×[ℎ,ℎ]→[?,1,ℎ] 。计算量为  。

接下来分析MLP块的计算,计算公式如下:

不使用 KV Cache 时:

  1. 第一个线性层,矩阵乘法的输入和输出形状为 [?,?,ℎ]×[ℎ,4ℎ]→[?,?,4ℎ] 。计算量为  。
  2. 第二个线性层,矩阵乘法的输入和输出形状为 [?,?,4ℎ]×[4ℎ,ℎ]→[?,?,ℎ] 。计算量为 

使用 KV Cache 时:

  1. 第一个线性层,矩阵乘法的输入和输出形状为 [?, 1, ℎ]×[ℎ, 4ℎ]→[?,1,4ℎ] 。计算量为  。
  2. 第二个线性层,矩阵乘法的输入和输出形状为 [?, 1, 4ℎ]×[4ℎ, ℎ]→[?,1,ℎ] 。计算量为  。

将上述self-attention块和MLP块计算量相加,得到:

  • 采用kv cache时,得到每个transformer层的计算量大约为  。
  • 不采用kv cache时,得到每个transformer层的计算量大约为: 

此外,另一个计算量的大头是logits的计算,将隐藏向量映射为词表大小。

  • 采用kv cache时,矩阵乘法的输入和输出形状为 [?,1,ℎ]×[ℎ,?]→[?,1,?] ,计算量为  。
  • 不采用kv cache时为,矩阵乘法的输入和输出形状为 [?,?,ℎ]×[ℎ,?]→[?,?,?] ,计算量为  。

KV Cache 显存占用分析

假设输入序列的长度为s ,输出序列的长度为n ,transformer层数为l,隐藏层维度 h,KV Cache 存储 kv_seq_len 个 KV value,形状为 [b, head_num, kv_seq_len, head_dim], 峰值kv_seq_len为 s+n ,以float16来保存KV cache,那么KV cache的峰值显存占用大小为 b(s+n)hl2*2=4blh(s+n) 。这里第一个 2 表示 K/V cache,第二个2表示float16占 2 个 bytes。

以GPT3-175B为例,对比KV cache与模型参数占用显存的大小。模型配置如下:

模型名 参数量 层数 隐藏维度 注意力头数
GPT3 175B 96 12288 96

GPT3 模型占用显存大小为350GB。假设批次大小b=64 ,输入序列长度s=512 ,输出序列长度n=32 ,则KV cache 峰值占用显存为 4blh(s+n) = 164,282,499,072 bytes ≈ 164 ?? ,大约是模型参数显存的0.5倍。

KV Cache 存在的问题以及优化措施

当将LLMs应用于无限输入流时,使用原始的 Dense Attention 会出现两个主要挑战:

  • 上下文越长,那么矩阵占用的内存也会越多,不仅如此还会增加Decoder时候的延迟。
  • 现有模型的长度外推能力有限,即当序列长度超出预训练期间设置的attention窗口大小时,其性能会下降。

因此,目前提出了一些优化方法,比如:使用滑动窗口的注意力机制,主要有如下几种方式。

  • 一种方式是如下图 B 的窗口注意力(Window Attention):只缓存最近的 L 个 Token 的 KV。虽然推理效率很高,但一旦起始Token的键和值被驱逐,性能就会急剧下降。
  • 一种方式是下图 C 的滑动窗口重计算(Sliding Window w/ Re-computation):根据每个新 Token 的 L 个最近 Token 重建 KV 状态。虽然它在长文本上表现良好,但其 的复杂性(源于上下文重新计算中的二次注意力)使其相当慢。
大模型KV Cache产生背景及技术细节
  • 还有一种方式是StreamingLLM,在当前滑动窗口方法的基础上,重新引入了一些最初的 tokens 的KV在注意力计算中使用。StreamingLLM 中的KV缓存可以概念上分为两部分,如下图所示:(1)attention sink 是 4 个最初的 tokens,稳定了注意力计算;(2)Rolling KV Cache 保留了最近的token,这个窗口值是固定的。此外,还需要有些小改动来给attention注入位置信息,StreamingLLM就可以无缝地融入任何使用相对位置编码的自回归语言模型,如RoPE和ALiBi。
大模型KV Cache产生背景及技术细节

KV Cache 源码分析

GPT2 中 KV Cache 代码实现:

class GPT2Attention(nn.Module):
    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    )
 -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:

        ...
  
        # 拆分 Q、K、V
        query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
        
        ...
        
        # [batch, sequence_len, embeded_dim] -> [batch, heads, sequence_len, head_dim]
        query = self._split_heads(query, self.num_heads, self.head_dim) # 当前token对应的query
        key = self._split_heads(key, self.num_heads, self.head_dim) # 当前token对应的key
        value = self._split_heads(value, self.num_heads, self.head_dim) # 当前token对应的value

        ##################################
        # KV Cache 核心代码逻辑
        if layer_past is not None
            past_key, past_value = layer_past # 从 KV Cache 去数据
            key = torch.cat((past_key, key), dim=-2# 将当前token的key与历史的K拼接
            value = torch.cat((past_value, value), dim=-2# 将当前token的value与历史的V拼接

        if use_cache is True:
            present = (key, value) # 将数据存到 KV Cache
        else:
            present = None
        ##################################
        ...
        
        # 使用当前token的query与K和V计算注意力表示
        attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) # 返回att输出(激活)和权重

        # 合并多头注意力
        # attn_output: [batch, heads, sequence_len, head_dim] -> [batch, heads, embed_dim]
        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
        
        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)

        outputs = (attn_output, present)
        if output_attentions:
            outputs += (attn_weights,)

        return outputs  # a, present, (attentions)

Baichuan2 中 KV Cache 代码实现:

class Attention(nn.Module):

    def forward(
            self,
            hidden_states: torch.Tensor,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            past_key_value: Optional[Tuple[torch.Tensor]] = None,
            output_attentions: bool = False,
            use_cache: bool = False,
    )
 -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:

        bsz, q_len, _ = hidden_states.size()

        proj = self.W_pack(hidden_states)
        proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0-2).squeeze(-2)
        query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(12)
        key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(12)
        value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(12)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
        # [bsz, nh, t, hd]

        if past_key_value is not None:
            # 取出 KV Cache 中的值
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        
        # 保存 KV Cache 中的值
        past_key_value = (key_states, value_states) if use_cache else None

Huggingface Transformer 库中 LLaMA 中 KV Cache 代码实现:

class LlamaAttention(nn.Module):
    ...
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    )
 -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:

    
        ...

        past_key_value = getattr(self, "past_key_value", past_key_value)
        cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            # 将当前 Token 的 kv 值更新到 KV Cache,并返回新的 KV
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        ...

        return attn_output, attn_weights, past_key_value

Huggingface Transformer 库中对Cache进行了抽象,里面实现了各种Cache,如:生成模型默认的动态缓存DynamicCache、StaticCache 和 StreamingLLM 论文中提到的SinkCache。


@dataclass
class Cache:
    """
    所有Cache的基础抽象类。实际数据结构由每个子类决定。
    """


    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    )
 -> Tuple[torch.Tensor, torch.Tensor]:

        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. These are specific to each subclass and allow new types of
                cache to be created.

        Return:
            A tuple containing the updated key and value states.
        """

        raise NotImplementedError("Make sure to implement `update` in a subclass.")

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.")

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states, if there is any."""
        raise NotImplementedError("Make sure to implement `get_max_length` in a subclass.")

    def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
        """Given the sequence length of the new inputs, returns the usable length of the cache."""
        # Cache without size limit -> all cache is usable
        # Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache
        #   length, we will need to evict part of the cache (and thus not all cache is usable)
        max_length = self.get_max_length()
        previous_seq_length = self.get_seq_length(layer_idx)
        if max_length is not None and previous_seq_length + new_seq_length > max_length:
            return max_length - new_seq_length
        return previous_seq_length

    @property
    def seen_tokens(self):
        logger.warning_once(
            "The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` "
            "model input instead."
        )
        if hasattr(self, "_seen_tokens"):
            return self._seen_tokens
        else:
            return None


class DynamicCache(Cache):
    # 随着生成更多 Token 而动态增长的Cache。这是生成模型的默认设置。
    # 它将键和值状态存储为张量列表,每层一个张量。每个张量的期望形状是
    # [batch_size, num_heads, seq_len, head_dim]。
    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    )
 -> Tuple[torch.Tensor, torch.Tensor]:

        
        # Update the number of seen tokens
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        # Update the cache
        if len(self.key_cache) <= layer_idx:
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
        else:
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

        return self.key_cache[layer_idx], self.value_cache[layer_idx]
        
        
class StaticCache(Cache):
    """
    与 torch.compile(model) 一起使用的静态 Cache 类
    """

    ...
    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    )
 -> Tuple[torch.Tensor, torch.Tensor]:

        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
        使用张量进行索引是非常重要的,否则你会向设备引入一个副本。
        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for. Kept for backward compatibility
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. The `StaticCache` just needs the `q_len`
                to know how much of the cache it should overwrite.

        Return:
            A tuple containing the updated key and value states.
        """

        new_cache_positions = cache_kwargs.get("cache_position")
        k_out = self.key_cache
        v_out = self.value_cache

        k_out[:, :, new_cache_positions] = key_states
        v_out[:, :, new_cache_positions] = value_states

        return k_out, v_out
    
class SinkCache(Cache):
    """
    # 正如[Attention Sinks 论文](https://arxiv.org/abs/2309.17453)中所描述的缓存。
    # 它允许模型生成超出其上下文窗口的长度,而不会失去会话的流畅性。
    # 因为它抛弃了过去tokens,模型将失去生成依赖于被丢弃的上下文的tokens的能力。
    # 它将键和值状态存储为张量列表,每层一个张量。每个张量的期望形状是
    # [batch_size, num_heads, seq_len, head_dim]
    """

    ...
    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    )
 -> Tuple[torch.Tensor, torch.Tensor]:

        # Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models
        # with partially rotated position embeddings, like Phi or Persimmon.
        sin = cache_kwargs.get("sin")
        cos = cache_kwargs.get("cos")
        partial_rotation_size = cache_kwargs.get("partial_rotation_size")
        using_rope = cos is not None and sin is not None

        # Update the number of seen tokens
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        # [bsz, num_heads, seq_len, head_dim]
        if len(self.key_cache) <= layer_idx:
            # Empty cache
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)

        elif key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length:
            # Growing cache
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

        else:
            # Shifting cache
            keys_to_keep = self.key_cache[layer_idx][
                :, :, -self.window_length + self.num_sink_tokens + key_states.shape[-2] :
            ]

            # On RoPE models, we need to recompute the Key rotation as the tokens are shifted
            if using_rope:
                rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(
                    key_states, cos[: self.window_length], sin[: self.window_length]
                )
                if partial_rotation_size is not None:
                    keys_to_keep, keys_pass = (
                        keys_to_keep[..., :partial_rotation_size],
                        keys_to_keep[..., partial_rotation_size:],
                    )
                keys_to_keep = self._apply_key_rotary_pos_emb(keys_to_keep, rerotation_cos, rerotation_sin)
                if partial_rotation_size is not None:
                    keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1)

            # Concatenate sink tokens, shifted & rotated tokens (if needed), and new tokens
            sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens]
            self.key_cache[layer_idx] = torch.cat([sink_keys, keys_to_keep, key_states], dim=-2)

            sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens]
            values_to_keep = self.value_cache[layer_idx][
                :, :, -self.window_length + self.num_sink_tokens + value_states.shape[-2] :
            ]
            self.value_cache[layer_idx] = torch.cat([sink_values, values_to_keep, value_states], dim=-2)

        return self.key_cache[layer_idx], self.value_cache[layer_idx]
    

从 GPT2 、 Baichuan2 和 LLaMA 的源码中可以看到 KV Cache 核心代码的实现就几行并不复杂,但是带来的收益却挺大。

结语

本文简要分析了 KV Cache 原理、源码以及计算量和显存占用,这是一种典型的通过空间换时间(计算)的技术,虽然并不复杂,但是现在基本上是仅编码器Transformer架构生成大语言模型必备优化技术。

前沿技术新闻资讯

一篇大模型Agent工具使用全面研究综述

2025-1-22 20:07:22

前沿技术新闻资讯

通义星尘个性化大模型相关技术与应用

2025-1-22 22:02:52

0 条回复 A文章作者 M管理员
    暂无讨论,说说你的看法吧
购物车
优惠劵
搜索