第 5 章 训练循环 · 下:逐行讲解


文档摘要

第 5 章 训练循环 · 下:逐行讲解 接上篇。本篇深挖四个重难点:余弦退火的数学、AdamW 的设计、断点续训的细节、训练循环里 9 个容易踩坑的细节。 5.8 学习率调度器:余弦退火(重难点) 工作原理 接收一个函数 ,每步把基础学习率乘以这个函数的返回值。所以这个函数返回的是「当前学习率相对峰值的比例」,范围 [0, 1]: 数学拆解 设 , , ,当前步 : 阶段 1:预热( ), ,从 0 线性升到 1。 是兜底:万一 warmupiters=0,除以 0 会崩,用 1 兜底变成 0/1=0。 阶段 2:退火( ),逐步推导: :退火进度 ,从 0 到 1。 :从 到 。 :从 到 。 :起点 ,终点 。 代入 :终点 lrratio = 0.1 ✓,与配置一致。

第 5 章 训练循环 · 下:逐行讲解

接上篇。本篇深挖四个重难点:余弦退火的数学、AdamW 的设计、断点续训的细节、训练循环里 9 个容易踩坑的细节。

5.8 学习率调度器:余弦退火(重难点)

工作原理

LambdaLR 接收一个函数 lr_lambda(step),每步把基础学习率乘以这个函数的返回值。所以这个函数返回的是「当前学习率相对峰值的比例」,范围 [0, 1]:

实际 lr = optimizer.lr × lr_lambda(step) = 3e-4 × lr_lambda(step)

数学拆解

w = warmup_itersT = max_itersr = min_lr_ratio,当前步 s

阶段 1:预热(s < wlr_ratio = s / w,从 0 线性升到 1。max(1, num_warmup_steps) 是兜底:万一 warmup_iters=0,除以 0 会崩,用 1 兜底变成 0/1=0。

阶段 2:退火(s ≥ w,逐步推导:

  • progress:退火进度 (s - w) / (T - w),从 0 到 1。
  • cos(π·progress):从 cos(0)=1cos(π)=-1
  • cosine = 0.5·(1 + cos(π·progress)):从 0.5·2=10.5·0=0
  • lr_ratio = r + (1-r)·cosine:起点 r + (1-r)·1 = 1,终点 r + (1-r)·0 = r

代入 r=0.1:终点 lr_ratio = 0.1 ✓,与配置一致。

曲线形状

lr_ratio ↑ 1 │ ╭───╮ ← 峰值 (s=w 时达到 1) │ ╱ ╲ │ ╱ ╲ │ ╱ ╲ 0.1 │ ╱ ╰───╯ ← 终点 (s=T 时降到 r) │╱ └─────────────────────→ step 0 w T

钳制保护

min_lr_ratio = max(0.0, min(1.0, min_lr_ratio))

min_lr_ratio 钳制到 [0, 1],防止用户传 1.5 或 -0.5 导致曲线变形。这是防御式编程的体现。

5.9 优化器:AdamW

optimizer = AdamW( model.parameters(), lr=train_config.learning_rate, # 3e-4 betas=train_config.betas, # (0.9, 0.95) weight_decay=train_config.weight_decay, # 0.1 )

为什么是 AdamW 而不是 Adam

  • Adam 把 weight decay 混在梯度里(L2 正则),与自适应学习率耦合,效果打折。
  • AdamW 把 weight decay 解耦:先按 Adam 更新参数,再独立地按 w ← w - lr·wd·w 衰减权重。

这是 GPT-2/3 训练的标准做法(论文 Decoupled Weight Decay Regularization 证明解耦版效果更好)。

betas=(0.9, 0.95) 的含义

Adam 维护梯度的一阶矩(动量)和二阶矩(梯度平方的指数移动平均)两个估计:

m_t = beta1 * m_{t-1} + (1 - beta1) * g_t # 一阶矩 v_t = beta2 * v_{t-1} + (1 - beta2) * g_t² # 二阶矩 update = lr * m_t / (sqrt(v_t) + eps)
  • beta1=0.9:一阶矩衰减系数(标准值)。
  • beta2=0.95:二阶矩衰减系数。注意比标准 Adam 的 0.999 小很多——这是 GPT 论文的经验值,让二阶矩更快跟随梯度变化,适合语言模型稀疏梯度的特点。

weight_decay=0.1 怎么作用

AdamW 的解耦权重衰减每步额外执行 w = w - lr · wd · w,把权重往 0 拉,防止过拟合。注意 bias 和 LayerNorm 的 gamma/beta 通常不衰减(框架实现已处理),因为它们不该被往 0 拉。

5.10 断点续训

start_step = 0 if args.resume: ckpt = torch.load(args.resume, map_location=device) model.load_state_dict(ckpt["model_state_dict"]) optimizer.load_state_dict(ckpt["optimizer_state_dict"]) scheduler.load_state_dict(ckpt["scheduler_state_dict"]) start_step = ckpt["step"] + 1

关键认知

恢复训练不是只加载模型权重!优化器内部有动量、二阶矩估计(每个参数对应一组状态),调度器有当前步数。三者必须一起恢复

只恢复 后果
模型 优化器动量归零,等于从头攒动量,前期训练抖动
不恢复调度器 lr 从预热重新开始,与已训练步数不匹配
不恢复 step 进度条从 0 开始,checkpoint 命名冲突

续训的典型场景

  1. 训练中断:停电、OOM、Ctrl-C。--resume checkpoints/gpt_step2000.pt 从 2001 步继续。
  2. 继续训更长:原本 5000 步不够,想加 5000 步。改 --max-iters 10000 --resume .../gpt_step5000.pt
  3. 换硬件:在 A100 上训到一半,转到另一台同架构卡上继续。

5.11 训练循环逐行拆解(核心)

model.train() # ★ 训练模式:dropout 生效 running_loss = 0.0 n_batches = len(dataloader) pbar = tqdm(total=train_config.max_iters, initial=start_step, desc="训练", unit="step") step = start_step data_iter = iter(dataloader) while step < train_config.max_iters: # ---- 取数据:epoch 用完自动重启 ---- try: x, y = next(data_iter) except StopIteration: data_iter = iter(dataloader) x, y = next(data_iter) x = x.to(device, non_blocking=True) y = y.to(device, non_blocking=True) # ---- 前向 + 反向 ---- outputs = model(input_ids=x, labels=y) loss = outputs.loss optimizer.zero_grad(set_to_none=True) loss.backward() # ---- 梯度裁剪 ---- torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.grad_clip) # ---- 步进 ---- optimizer.step() scheduler.step() # ---- 日志 ---- running_loss += loss.item() step += 1 pbar.update(1) if step % train_config.log_iter == 0: avg_loss = running_loss / train_config.log_iter running_loss = 0.0 # ---- 定期保存 checkpoint ---- if step % train_config.save_iter == 0 or step == train_config.max_iters: save_checkpoint(...)

9 个细节逐个讲

model.train() vs model.eval():训练循环开头必须切 train(),dropout 才生效。一旦中间调过 eval()(比如做验证),回到训练时也要切回 train()。本项目无验证集,所以只在开头切一次。

try/except StopIteration 实现「无限数据」:DataLoader 迭代完一个 epoch 会抛 StopIteration,这里捕获后重建迭代器,实现「数据无限供应,按步数而非 epoch 控制训练」——这是 LM 训练的标准模式(语言模型数据量大,用步数更可控)。

non_blocking=True:配合 pin_memory=True(见第 3 章),让 CPU→GPU 数据拷贝异步进行。要小心:如果下一步计算立刻就要用这个数据,异步拷贝会被强制同步,效果归零。

model(input_ids=x, labels=y) 自动算 loss:传 labels 时框架会内部算交叉熵 loss。注意它内部对所有位置(除 pad)取平均,并做了「logits 去尾、labels 去头」的对齐(见第 4 章)。

zero_grad(set_to_none=True):PyTorch 2.0+ 推荐,把梯度置 None 而非填 0。好处:省内存、优化器检测到 grad is None 会跳过该参数的更新(如冻结层)、少一次遍历填 0。

clip_grad_norm_:防梯度爆炸:把所有参数梯度当作一个大向量,如果它的 L2 范数超过 max_norm(这里是 1.0),就按比例缩放到 max_norm。这是 Transformer 训练的「保险丝」。

total_norm = sqrt(sum(grad.norm()² for grad in all_grads)) if total_norm > max_norm: scale = max_norm / total_norm for grad in all_grads: grad *= scale

💡 注意是 clip_grad_norm_(按总范数),不是 clip_grad_value_(按每元素绝对值)。语言模型训练用前者,因为它保留梯度方向只缩放幅度。

optimizer.step() 在前、scheduler.step() 在后:顺序不能反。调度器要根据「这一步刚刚用过的 lr」记录状态。

loss.item() 会触发同步.item() 把 GPU 上的标量拷到 CPU,会强制 CPU/GPU 同步。频繁调用会拖慢训练。本项目每步都打进度条是为了观感,严格优化时可以改成每 N 步取一次。

⑨ 双触发条件的 checkpoint 保存if step % save_iter == 0 or step == max_itersor 后半句保证训练正常结束的最后一步一定保存,即使 max_iters 不是 save_iter 的整数倍。举例:max_iters=4999,没有 or step == max_iters 就不会保存最后一步。

5.12 动手实验

  1. 画学习率曲线:在训练循环里加一行每 100 步打印 scheduler.get_last_lr()[0],对照第 2 章描述的曲线是否一致。
  2. 断点续训验证:训练 100 步,记下 loss;中断,用 --resume 继续 100 步,对比与连续训 200 步的 loss 曲线是否吻合。
  3. 梯度裁剪对比:把 grad_clip 改成 0.01(过小),观察 loss 是否会变成 NaN 或停滞(梯度被过度压缩)。
  4. 优化器状态重要性:训练 200 步存 checkpoint;然后写个小脚本,加载 model_state_dict 不加载 optimizer_state_dict,继续训练,观察前几十步 loss 是否抖动更厉害。
  5. 思考题:为什么 optimizer.step() 必须在 scheduler.step() 之前?

5.13 小结(下篇)

  • 余弦退火 = 预热线性升温 + 余弦曲线平滑下降到 min_lr_ratioLambdaLR 用「相对峰值比例」函数控制。
  • AdamW 用 betas=(0.9, 0.95)(GPT 经验值,二阶矩更快跟随)+ weight_decay=0.1(解耦权重衰减)。
  • 断点续训必须同时恢复模型、优化器、调度器、step,缺一不可。
  • 训练循环 9 个细节:train 模式、StopIteration 重启、non_blocking、自动 loss、set_to_none、clip_grad_norm、step 顺序、loss.item 同步、双触发保存。

5.14 下一章

模型训完了,去《第 6 章 推理与采样》看怎么让它生成文本。


发布者: 作者: 转发
评论区 (0)
U