GPU 显存优化深度指南:KV Cache 管理与实践
AI InfraGPU显存优化KV Cache
GPU 显存是大模型推理的核心瓶颈。本文深入解析显存优化技术,让你的 GPU 跑得更快、更省显存。
推理显存去哪了?
典型显存占用
| 组件 | 7B 模型 (FP16) | 13B 模型 (FP16) | 70B 模型 (FP16) |
|---|---|---|---|
| 模型权重 | 14GB | 26GB | 140GB |
| KV Cache | 动态 (1-16GB) | 动态 (2-32GB) | 动态 (10-80GB) |
| 激活值 | 1-2GB | 2-4GB | 8-16GB |
| 其他 | 1GB | 1GB | 2GB |
结论:KV Cache 是显存的最大变数!
KV Cache 详解
什么是 KV Cache?
Transformer 自回归生成时,需要重复计算之前所有 token 的 Key 和 Value:
# 无 Cache:每次都重新计算(O(n²) 复杂度)
def forward_no_cache(q, k, v):
for i in range(len(k)):
output += attention(q, k[i], v[i]) # 重复计算!
return output
# 有 Cache:只计算新 token(O(n) 复杂度)
def forward_with_cache(q, k_cache, v_cache, new_k, new_v):
k_cache.append(new_k) # 追加新 token 的 K
v_cache.append(new_v) # 追加新 token 的 V
return attention(q, k_cache, v_cache)
显存计算
KV Cache 大小公式:
KV Cache = 2 × batch_size × seq_len × num_layers × hidden_size × dtype_size
以 LLaMA-2-7B 为例:
- num_layers = 32
- hidden_size = 4096
- seq_len = 4096
- batch_size = 1
- FP16 = 2 bytes
Cache = 2 × 1 × 4096 × 32 × 4096 × 2
= 2 GB per sequence
显存优化技术
1. PagedAttention (vLLM)
将 KV Cache 分页管理,按需分配:
# vLLM 配置
from vllm import EngineArgs, AsyncLLMEngine
engine_args = EngineArgs(
model="meta-llama/Llama-2-7b-hf",
max_num_seqs=256, # 最大并发序列数
max_model_len=4096, # 最大序列长度
block_size=16, # Block 大小
gpu_memory_utilization=0.9, # 显存利用率
)
效果:并发提升 2-4 倍
2. Flash Attention
将注意力计算优化为 IO 密集型:
# 使用 Flash Attention
from flash_attn import FlashAttention
attn = FlashAttention()
output = attn(q, k, v, causal=True)
原理:
- 避免 Materialize 大型中间矩阵
- 将显存复杂度从 O(N²) 降到 O(N)
- 利用 GPU 带宽进行tiled 计算
效果:
- 显存减少 5-10 倍
- 速度提升 2-3 倍
3. KV Cache 量化
将 KV Cache 从 FP16 量化到 INT8/INT4:
# vLLM 启用 KV Cache 量化
from vllm import LLM
llm = LLM(
model="meta-llama/Llama-2-7b-hf",
kv_cache_dtype="fp8", # KV Cache 量化
)
效果:
| 量化方式 | KV Cache 显存 | 精度损失 |
|---|---|---|
| FP16 | 100% (baseline) | 0% |
| FP8 | 50% | <0.5% |
| INT8 | 37.5% | <1% |
| INT4 | 18.75% | <2% |
4. MQA / GQA
多头注意力优化:
MHA (Multi-Head Attention):
K: num_heads × head_dim
V: num_heads × head_dim
参数: num_heads × head_dim × 3
GQA (Grouped-Query Attention):
K: num_groups × head_dim
V: num_groups × head_dim
参数: num_groups × head_dim × 3 (G < num_heads)
MQA (Multi-Query Attention):
K: 1 × head_dim
V: 1 × head_dim
参数: head_dim × 3 (最少)
# 使用 GQA 的模型
# LLaMA 2 70B 使用 GQA (num_groups=8)
# 大幅减少 KV Cache 显存
5. Prefix Caching
前缀复用:
# SGLang 启用前缀缓存
from sglang import sgl, sgl_gen
@sgl
def chat_system():
sgl"""
You are a helpful AI assistant.
Current date: 2024-03-23
"""
# 多个请求共享 system prompt
# System prompt 的 KV Cache 完全复用
实战:显存计算与配置
场景:8 张 A100 80GB 部署 LLaMA-2-70B
from vllm import EngineArgs, AsyncLLMEngine
# 计算显存需求
# 模型权重: 140GB (FP16)
# 显存总量: 8 × 80GB = 640GB
# 留足 KV Cache 空间
engine_args = EngineArgs(
model="meta-llama/Llama-2-70b-hf",
tensor_parallel_size=4, # 4 卡并行
pipeline_parallel_size=2, # 2 路流水线
max_model_len=4096,
max_num_seqs=64, # 最大并发
gpu_memory_utilization=0.85, # 85% 利用率
kv_cache_dtype="fp8", # KV 量化
enforce_eager=False, # CUDA Graph 优化
)
实际性能:
- 吞吐量:~40 tok/s
- 并发:64 请求
- 显存:~500GB / 640GB
显存不足怎么办?
方案对比
| 方案 | 改动难度 | 效果 | 适用场景 |
|---|---|---|---|
| 量化模型 | 低 | 显存 -70% | 通用 |
| KV Cache 量化 | 低 | 显存 -50% | 生产部署 |
| 减少 max_model_len | 低 | 显存线性减少 | 短文本场景 |
| GQA/MQA | 高(换模型) | 显存 -50%+ | 新项目 |
| 多卡并行 | 中 | 显存线性减少 | 大模型 |
快速解决方案
# 1. 使用量化模型(推荐)
# INT4 量化,7B 模型只需 4GB 显存
ollama run llama2:7b-q4_0
# 2. vLLM 量化部署
lmdeploy lite quantize model --w-bit 4
# 3. 减少序列长度
# max_model_len=2048 显存减半
监控显存
# Python 监控显存
import torch
def get_gpu_memory():
allocated = torch.cuda.memory_allocated() / 1024**3
reserved = torch.cuda.memory_reserved() / 1024**3
return {
"allocated_gb": allocated,
"reserved_gb": reserved,
}
# 定期打印
import time
while True:
print(get_gpu_memory())
time.sleep(1)
总结
显存优化的核心思路:
- 减少模型权重:量化(INT4)
- 减少 KV Cache:PagedAttention + 量化
- 减少激活值:Flash Attention
- 并行扩展:TP/PP 多卡
| 技术 | 显存减少 | 速度影响 |
|---|---|---|
| PagedAttention | 30-50% | 提升 |
| KV 量化 (FP8) | 50% | 提升 |
| Flash Attention | 5-10x | 提升 2-3x |
| GQA | 50%+ | 略降 |
| INT4 权重 | 75% | 略降 |
掌握这些技术,你就能高效利用每一 GB 显存!
下期预告:MLC-LLM 端侧部署实践