GPU 显存优化深度指南:KV Cache 管理与实践

AI InfraGPU显存优化KV Cache

GPU 显存是大模型推理的核心瓶颈。本文深入解析显存优化技术,让你的 GPU 跑得更快、更省显存。

推理显存去哪了?

典型显存占用

组件7B 模型 (FP16)13B 模型 (FP16)70B 模型 (FP16)
模型权重14GB26GB140GB
KV Cache动态 (1-16GB)动态 (2-32GB)动态 (10-80GB)
激活值1-2GB2-4GB8-16GB
其他1GB1GB2GB

结论:KV Cache 是显存的最大变数!


KV Cache 详解

什么是 KV Cache?

Transformer 自回归生成时,需要重复计算之前所有 token 的 Key 和 Value:

# 无 Cache:每次都重新计算(O(n²) 复杂度)
def forward_no_cache(q, k, v):
    for i in range(len(k)):
        output += attention(q, k[i], v[i])  # 重复计算!
    return output

# 有 Cache:只计算新 token(O(n) 复杂度)
def forward_with_cache(q, k_cache, v_cache, new_k, new_v):
    k_cache.append(new_k)  # 追加新 token 的 K
    v_cache.append(new_v)  # 追加新 token 的 V
    return attention(q, k_cache, v_cache)

显存计算

KV Cache 大小公式:

KV Cache = 2 × batch_size × seq_len × num_layers × hidden_size × dtype_size

以 LLaMA-2-7B 为例:

  • num_layers = 32
  • hidden_size = 4096
  • seq_len = 4096
  • batch_size = 1
  • FP16 = 2 bytes
Cache = 2 × 1 × 4096 × 32 × 4096 × 2
      = 2 GB per sequence

显存优化技术

1. PagedAttention (vLLM)

将 KV Cache 分页管理,按需分配:

# vLLM 配置
from vllm import EngineArgs, AsyncLLMEngine

engine_args = EngineArgs(
    model="meta-llama/Llama-2-7b-hf",
    max_num_seqs=256,      # 最大并发序列数
    max_model_len=4096,    # 最大序列长度
    block_size=16,         # Block 大小
    gpu_memory_utilization=0.9,  # 显存利用率
)

效果:并发提升 2-4 倍


2. Flash Attention

将注意力计算优化为 IO 密集型:

# 使用 Flash Attention
from flash_attn import FlashAttention

attn = FlashAttention()
output = attn(q, k, v, causal=True)

原理:

  • 避免 Materialize 大型中间矩阵
  • 将显存复杂度从 O(N²) 降到 O(N)
  • 利用 GPU 带宽进行tiled 计算

效果:

  • 显存减少 5-10 倍
  • 速度提升 2-3 倍

3. KV Cache 量化

将 KV Cache 从 FP16 量化到 INT8/INT4:

# vLLM 启用 KV Cache 量化
from vllm import LLM

llm = LLM(
    model="meta-llama/Llama-2-7b-hf",
    kv_cache_dtype="fp8",  # KV Cache 量化
)

效果:

量化方式KV Cache 显存精度损失
FP16100% (baseline)0%
FP850%<0.5%
INT837.5%<1%
INT418.75%<2%

4. MQA / GQA

多头注意力优化:

MHA (Multi-Head Attention):
  K: num_heads × head_dim
  V: num_heads × head_dim
  参数: num_heads × head_dim × 3

GQA (Grouped-Query Attention):
  K: num_groups × head_dim
  V: num_groups × head_dim
  参数: num_groups × head_dim × 3 (G < num_heads)

MQA (Multi-Query Attention):
  K: 1 × head_dim
  V: 1 × head_dim
  参数: head_dim × 3 (最少)
# 使用 GQA 的模型
# LLaMA 2 70B 使用 GQA (num_groups=8)
# 大幅减少 KV Cache 显存

5. Prefix Caching

前缀复用:

# SGLang 启用前缀缓存
from sglang import sgl, sgl_gen

@sgl
def chat_system():
    sgl"""
    You are a helpful AI assistant.
    Current date: 2024-03-23
    """

# 多个请求共享 system prompt
# System prompt 的 KV Cache 完全复用

实战:显存计算与配置

场景:8 张 A100 80GB 部署 LLaMA-2-70B

from vllm import EngineArgs, AsyncLLMEngine

# 计算显存需求
# 模型权重: 140GB (FP16)
# 显存总量: 8 × 80GB = 640GB
# 留足 KV Cache 空间

engine_args = EngineArgs(
    model="meta-llama/Llama-2-70b-hf",
    tensor_parallel_size=4,        # 4 卡并行
    pipeline_parallel_size=2,      # 2 路流水线
    max_model_len=4096,
    max_num_seqs=64,               # 最大并发
    gpu_memory_utilization=0.85,   # 85% 利用率
    kv_cache_dtype="fp8",          # KV 量化
    enforce_eager=False,           # CUDA Graph 优化
)

实际性能:

  • 吞吐量:~40 tok/s
  • 并发:64 请求
  • 显存:~500GB / 640GB

显存不足怎么办?

方案对比

方案改动难度效果适用场景
量化模型显存 -70%通用
KV Cache 量化显存 -50%生产部署
减少 max_model_len显存线性减少短文本场景
GQA/MQA高(换模型)显存 -50%+新项目
多卡并行显存线性减少大模型

快速解决方案

# 1. 使用量化模型(推荐)
# INT4 量化,7B 模型只需 4GB 显存
ollama run llama2:7b-q4_0

# 2. vLLM 量化部署
lmdeploy lite quantize model --w-bit 4

# 3. 减少序列长度
# max_model_len=2048 显存减半

监控显存

# Python 监控显存
import torch

def get_gpu_memory():
    allocated = torch.cuda.memory_allocated() / 1024**3
    reserved = torch.cuda.memory_reserved() / 1024**3
    return {
        "allocated_gb": allocated,
        "reserved_gb": reserved,
    }

# 定期打印
import time
while True:
    print(get_gpu_memory())
    time.sleep(1)

总结

显存优化的核心思路:

  1. 减少模型权重:量化(INT4)
  2. 减少 KV Cache:PagedAttention + 量化
  3. 减少激活值:Flash Attention
  4. 并行扩展:TP/PP 多卡
技术显存减少速度影响
PagedAttention30-50%提升
KV 量化 (FP8)50%提升
Flash Attention5-10x提升 2-3x
GQA50%+略降
INT4 权重75%略降

掌握这些技术,你就能高效利用每一 GB 显存!


下期预告:MLC-LLM 端侧部署实践