ZeRO论文简读

| 笔记 | 1990 | 5分钟 | 训练加速系统AI论文

之前看了李沐老师的ZeRO和参数服务器的讲解讲解,没有太看懂。

重新自己学一下。

另外今天gpt5出了,好好问问他。

ZeRO的出发点是:百亿、万亿参数大模型的训练瓶颈不是算力,而是显存。

即使多卡训练,单卡的显存限制会导致模型根本放不下。

1 前置知识

1.1 数据并行(Data Parallel, DP)

每张卡

  • 一份完整的模型参数(weights)
  • 一份完整的优化器状态(例如 Adam 的动量、方差)
  • 一份完整的梯度(在反向传播时生成)

怎么并行

  1. 不同 GPU 各自拿不同的 mini-batch 数据,独立做前向和反向。
  2. 反向结束后,通过 all-reduce 把各自的梯度求平均(同步)。
  3. 每张卡用平均后的梯度更新自己那份完整参数。

特点

  • 算法简单,计算效率高。
  • 缺点:显存占用大,冗余严重——每张卡都放全套参数+状态。

简单说:每张卡并行计算不同的batch,每一轮结束后求平均(同步)

这很显然会有一些同步开销。有几个常用的做法:

  1. 通信与计算重叠

    反向传播中,算出一层的梯度后就开始通信。把通信和计算并行。

  2. 参数服务器

    参数存在服务器里,谁先完成谁更新参数。但是会引入“梯度延迟”,获取到的参数并不是最新的,影响收敛质量。

  3. 负载均衡、流水优化等:减小短板效应,让各个卡的用时尽量一致。

1.2 模型并行(Model Parallel, MP)

每张卡

  • 只放部分参数(例如某一层的一部分权重,或某个矩阵的分块)
  • 对应部分的中间激活(activations)

怎么并行

  1. 前向传播时,不同 GPU 负责不同的计算块,计算结果需要互相通信。
  2. 反向传播时,同理,需要通信传递梯度。

特点

  • 显存压力小(因为每张卡只存部分模型)。
  • 缺点:计算粒度小、通信频繁,跨节点时性能下降明显。

每个卡存模型的一部分。

前向传播得从模型的第一个部分开始计算,将计算结果发送给下一张卡。

反向传播从最后一个部分开始计算梯度,将结果传回上一张卡。

模型并行并没有加速,频繁的通信显然降低了计算效率。

2 ZeRO

根据前置知识,现有方案要么增加显存浪费来提升计算效率,要么降低计算效率减少显存浪费。

所以ZeRO的目的就是:在不降低计算/通信效率的前提下,极大减少显存占用,让模型规模可以按设备数量线性扩展。

2.1 显存占用的构成

显存消耗分为两大类:模型状态(Model States)剩余状态(Residual States)

2.1.1 模型状态(Model States)

这是训练时最大的显存消耗来源,主要包括以下几部分:

  1. 参数(Parameters)
    • 用来存储模型权重。
    • 在混合精度下,通常有 fp16 副本(参与前向/反向计算)和 fp32 副本(优化器更新用)。
    • 约占:6Ψ 字节(2Ψ fp16 + 4Ψ fp32)。
  2. 梯度(Gradients)
    • 反向传播时产生。
    • 通常是 fp16 格式,占用约 2Ψ 字节
  3. 优化器状态(Optimizer States)
    • 例如 Adam 要存两份额外的张量:动量(momentum)和方差(variance)。
    • 通常是 fp32,每份 4Ψ 字节,总共 8Ψ 字节
    • 这是显存中最大的一块。

👉 合计下来,在 混合精度 Adam 下,单卡需要大约 16Ψ 字节,远大于仅存权重(fp16 参数)的需求。

举例:GPT-2(15 亿参数,Ψ=1.5B) → 约 24 GB 显存

2.1.2 剩余状态(Residual States)

在模型状态之外,还存在三类“隐形开销”:

  1. 激活(Activations)

    • 前向传播时保存,以便反向传播使用。

    • 占用量与 batch size × 序列长度 × 隐藏维度 × 层数 成正比。

    • 例如:GPT-2(1.5B 参数,batch=32,seq=1024) → 60 GB 激活显存

      使用激活检查点(activation checkpointing)可降到约 8 GB,但需要增加 33% 重算开销。

    • 大模型(100B 参数)即使用了检查点,激活也可能仍需 60 GB

  2. 临时缓冲区(Temporary Buffers)

    • 用于梯度 all-reduce、归一化等操作时,把梯度拼成一个大 buffer 提升通信效率。
    • 常常是 fp32 格式。
    • 举例:1.5B 参数模型 → 单个临时缓冲区可达 6 GB
  3. 显存碎片化(Memory Fragmentation)

    • 由于不同生命周期张量(长存的 vs 临时的)交织分配,可能出现“碎片”。
    • 结果:即使有 30% 以上显存空闲,也可能因找不到连续大块内存而 OOM。