ZeRO论文简读
之前看了李沐老师的ZeRO和参数服务器的讲解讲解,没有太看懂。
重新自己学一下。
另外今天gpt5出了,好好问问他。
ZeRO的出发点是:百亿、万亿参数大模型的训练瓶颈不是算力,而是显存。
即使多卡训练,单卡的显存限制会导致模型根本放不下。
1 前置知识
1.1 数据并行(Data Parallel, DP)
每张卡:
- 一份完整的模型参数(weights)
- 一份完整的优化器状态(例如 Adam 的动量、方差)
- 一份完整的梯度(在反向传播时生成)
怎么并行:
- 不同 GPU 各自拿不同的 mini-batch 数据,独立做前向和反向。
- 反向结束后,通过 all-reduce 把各自的梯度求平均(同步)。
- 每张卡用平均后的梯度更新自己那份完整参数。
特点:
- 算法简单,计算效率高。
- 缺点:显存占用大,冗余严重——每张卡都放全套参数+状态。
简单说:每张卡并行计算不同的batch,每一轮结束后求平均(同步)
这很显然会有一些同步开销。有几个常用的做法:
-
通信与计算重叠
反向传播中,算出一层的梯度后就开始通信。把通信和计算并行。
-
参数服务器
参数存在服务器里,谁先完成谁更新参数。但是会引入“梯度延迟”,获取到的参数并不是最新的,影响收敛质量。
-
负载均衡、流水优化等:减小短板效应,让各个卡的用时尽量一致。
1.2 模型并行(Model Parallel, MP)
每张卡:
- 只放部分参数(例如某一层的一部分权重,或某个矩阵的分块)
- 对应部分的中间激活(activations)
怎么并行:
- 前向传播时,不同 GPU 负责不同的计算块,计算结果需要互相通信。
- 反向传播时,同理,需要通信传递梯度。
特点:
- 显存压力小(因为每张卡只存部分模型)。
- 缺点:计算粒度小、通信频繁,跨节点时性能下降明显。
每个卡存模型的一部分。
前向传播得从模型的第一个部分开始计算,将计算结果发送给下一张卡。
反向传播从最后一个部分开始计算梯度,将结果传回上一张卡。
模型并行并没有加速,频繁的通信显然降低了计算效率。
2 ZeRO
根据前置知识,现有方案要么增加显存浪费来提升计算效率,要么降低计算效率减少显存浪费。
所以ZeRO的目的就是:在不降低计算/通信效率的前提下,极大减少显存占用,让模型规模可以按设备数量线性扩展。
2.1 显存占用的构成
显存消耗分为两大类:模型状态(Model States) 和 剩余状态(Residual States)。
2.1.1 模型状态(Model States)
这是训练时最大的显存消耗来源,主要包括以下几部分:
- 参数(Parameters)
- 用来存储模型权重。
- 在混合精度下,通常有 fp16 副本(参与前向/反向计算)和 fp32 副本(优化器更新用)。
- 约占:6Ψ 字节(2Ψ fp16 + 4Ψ fp32)。
- 梯度(Gradients)
- 反向传播时产生。
- 通常是 fp16 格式,占用约 2Ψ 字节。
- 优化器状态(Optimizer States)
- 例如 Adam 要存两份额外的张量:动量(momentum)和方差(variance)。
- 通常是 fp32,每份 4Ψ 字节,总共 8Ψ 字节。
- 这是显存中最大的一块。
👉 合计下来,在 混合精度 Adam 下,单卡需要大约 16Ψ 字节,远大于仅存权重(fp16 参数)的需求。
举例:GPT-2(15 亿参数,Ψ=1.5B) → 约 24 GB 显存。
2.1.2 剩余状态(Residual States)
在模型状态之外,还存在三类“隐形开销”:
-
激活(Activations)
-
前向传播时保存,以便反向传播使用。
-
占用量与 batch size × 序列长度 × 隐藏维度 × 层数 成正比。
-
例如:GPT-2(1.5B 参数,batch=32,seq=1024) → 60 GB 激活显存,
使用激活检查点(activation checkpointing)可降到约 8 GB,但需要增加 33% 重算开销。
-
大模型(100B 参数)即使用了检查点,激活也可能仍需 60 GB。
-
-
临时缓冲区(Temporary Buffers)
- 用于梯度 all-reduce、归一化等操作时,把梯度拼成一个大 buffer 提升通信效率。
- 常常是 fp32 格式。
- 举例:1.5B 参数模型 → 单个临时缓冲区可达 6 GB。
-
显存碎片化(Memory Fragmentation)
- 由于不同生命周期张量(长存的 vs 临时的)交织分配,可能出现“碎片”。
- 结果:即使有 30% 以上显存空闲,也可能因找不到连续大块内存而 OOM。