第 7 章 训练主循环


文档摘要

第 7 章 训练主循环 所有零件都备齐了——数据、教师、学生、损失。本章把它们组装成一个完整的训练循环,并实现本项目最有特色的几个工程能力:损失分解日志、师生对比快照、教师输出缓存、断点续训。 7.1 训练循环的全貌 本章按这个流程逐步拆解。 7.2 初始化:配置、种子、设备 训练入口的前三步是「打地基」: 7.2.1 随机种子:可复现性的基础 固定三个随机源(Python、NumPy、PyTorch),让相同配置 + 相同种子能复现完全一致的结果。这是做实验对比的前提——否则你无法判断效果差异来自参数还是来自随机性。 7.2.2 设备选择 默认 ,有 GPU 用 GPU,否则退回 CPU。 7.3 加载数据与模型 这两步分别在第 4、5 章讲过,这里只是组装。

第 7 章 训练主循环

所有零件都备齐了——数据、教师、学生、损失。本章把它们组装成一个完整的训练循环,并实现本项目最有特色的几个工程能力:损失分解日志、师生对比快照、教师输出缓存、断点续训。

7.1 训练循环的全貌

本章按这个流程逐步拆解。

7.2 初始化:配置、种子、设备

训练入口的前三步是「打地基」:

def train(): args = parse_args() # 解析命令行 cfg = DistillConfig() apply_args_to_config(args, cfg) # 覆盖默认配置 set_seed(cfg.seed) # 固定随机种子 device = select_device(cfg.device)

7.2.1 随机种子:可复现性的基础

import random, numpy as np, torch def set_seed(seed): random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed_all(seed)

固定三个随机源(Python、NumPy、PyTorch),让相同配置 + 相同种子能复现完全一致的结果。这是做实验对比的前提——否则你无法判断效果差异来自参数还是来自随机性。

7.2.2 设备选择

def select_device(name): if name == "auto": return torch.device("cuda" if torch.cuda.is_available() else "cpu") return torch.device(name)

默认 auto,有 GPU 用 GPU,否则退回 CPU。

7.3 加载数据与模型

# 数据:训练集 + 验证集 text = get_dataset() train_loader, eval_loader, train_dataset = build_train_eval_loaders( text=text, block_size=cfg.block_size, batch_size=cfg.batch_size, num_workers=cfg.num_workers, eval_ratio=cfg.eval_ratio, ) # 教师(冻结)+ 学生(可训练) teacher = build_teacher(cfg) teacher.to(device) student = build_student(cfg) student.to(device)

这两步分别在第 4、5 章讲过,这里只是组装。注意教师和学生都要 .to(device),否则一个在 GPU 一个在 CPU 会导致数据搬运混乱。

7.4 优化器与学习率调度

7.4.1 AdamW 优化器

from torch.optim import AdamW optimizer = AdamW( student.parameters(), # 只优化学生! lr=cfg.learning_rate, betas=cfg.betas, weight_decay=cfg.weight_decay, )

关键点:optimizer 只接收 student.parameters()。教师被排除在外,即使忘了冻结也不会被更新。这是「只训练学生」的最终保证。

7.4.2 余弦退火学习率调度

学习率曲线采用「线性预热 + 余弦退火」两段式:

import math from torch.optim.lr_scheduler import LambdaLR def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, min_lr_ratio=0.1): def lr_lambda(current_step): # 1) 预热:线性升温 if current_step < num_warmup_steps: return float(current_step) / max(1, num_warmup_steps) # 2) 余弦退火:平滑下降 progress = (current_step - num_warmup_steps) / max( 1, num_training_steps - num_warmup_steps ) cosine = 0.5 * (1.0 + math.cos(math.pi * progress)) return min_lr_ratio + (1.0 - min_lr_ratio) * cosine return LambdaLR(optimizer, lr_lambda)

两个阶段的作用:

阶段 学习率变化 为什么
预热(warmup) 0 → 峰值 初始权重随机,大学习率会破坏训练;先小步试探
余弦退火 峰值 → min_lr 平滑下降,后期精调,避免在最优解附近震荡

这是 GPT-2/3 训练的标准配方,对蒸馏同样适用。

7.5 核心训练循环

这是整个项目的「心脏」:

student.train() running_total = running_ce = running_kd = 0.0 step = 0 data_iter = iter(train_loader) while step < cfg.max_iters: # 取一个批次(epoch 用尽则重启迭代器) try: x, y, idx = next(data_iter) except StopIteration: data_iter = iter(train_loader) x, y, idx = next(data_iter) x, y, idx = x.to(device), y.to(device), idx.to(device) # ① 教师前向(不计算梯度) if use_cache: teacher_logits = teacher_cache.fetch(idx, device) else: with torch.no_grad(): teacher_logits = teacher(input_ids=x).logits # ② 学生前向 student_logits = student(input_ids=x).logits # ③ 蒸馏损失(含 CE + KL 分解) total_loss, ce_loss, kd_loss = distillation_loss( student_logits, teacher_logits, y, cfg.temperature, cfg.alpha ) # ④ 反向传播 + 梯度裁剪 + 更新 optimizer.zero_grad(set_to_none=True) total_loss.backward() torch.nn.utils.clip_grad_norm_(student.parameters(), cfg.grad_clip) optimizer.step() scheduler.step() # ⑤ 统计与日志(见 7.6) ...

逐步解释关键点:

7.5.1 教师前向必须 no_grad

with torch.no_grad(): teacher_logits = teacher(input_ids=x).logits

torch.no_grad() 让教师前向不构建计算图,既省显存又省算力。配合第 5 章的参数冻结,双保险确保教师不被更新。

7.5.2 set_to_none=True

optimizer.zero_grad(set_to_none=True)

把梯度设为 None 而非 0,省一次全零写入,略快一点。这是 PyTorch 推荐的现代写法。

7.5.3 梯度裁剪

torch.nn.utils.clip_grad_norm_(student.parameters(), cfg.grad_clip)

把所有梯度的整体范数裁剪到 grad_clip(默认 1.0)以内。训练初期学生权重随机,偶尔会出现梯度爆炸,裁剪能稳定训练。这是 Transformer 训练的标配。

7.5.4 无限迭代器模式

try: x, y, idx = next(data_iter) except StopIteration: data_iter = iter(train_loader) # epoch 用尽,重新开始 x, y, idx = next(data_iter)

本项目按「步数」(max_iters)而非「epoch」控制训练长度。当数据集迭代完一个 epoch,自动重启迭代器继续下一个 epoch,直到步数达标。这种模式在数据量大、想精确控制训练步数时很常用。

7.6 特色之一:损失分解日志

log_iter 步打印一次三项损失的均值:

running_total += total_loss.item() running_ce += ce_loss.item() running_kd += kd_loss.item() step += 1 if step % cfg.log_iter == 0: avg_total = running_total / cfg.log_iter avg_ce = running_ce / cfg.log_iter avg_kd = running_kd / cfg.log_iter print( f"[train] step {step}/{cfg.max_iters} | " f"total={avg_total:.4f} ce={avg_ce:.4f} kd={avg_kd:.4f} | " f"lr={scheduler.get_last_lr()[0]:.2e}" ) running_total = running_ce = running_kd = 0.0

输出示例:

[train] step 10/5000 | total=5.43 ce=4.89 kd=5.97 | lr=2.95e-04

为什么要分三项打印?因为它们反映了蒸馏的不同侧面:

指标 反映 健康趋势
total 总优化目标 持续下降
ce 学生对真实标签的拟合 下降,但不能为 0(过拟合)
kd 学生对教师分布的模仿 下降,趋近 0 说明学得像

如果 ce 下降但 kd 不降,说明学生在学硬标签却没学教师,蒸馏没起作用——可能是温度太低或 α 太高。这种诊断能力,正是损失分解的价值。

7.7 特色之二:师生对比快照评估

eval_iter 步,在验证集上做一次快速评估,对比师生表现:

if cfg.eval_iter > 0 and step % cfg.eval_iter == 0: metrics = quick_eval(student, teacher, eval_loader, device) print( f"[eval] step {step} | " f"student ppl={metrics['student_ppl']:.2f} top1={metrics['student_top1']:.3f} | " f"teacher ppl={metrics['teacher_ppl']:.2f} top1={metrics['teacher_top1']:.3f} | " f"kd_kl={metrics['kd_kl']:.4f}" )

这个快照让你训练中就能看到「学生离教师还有多远」,而不必等训练结束。quick_eval 内部只跑前 20 个 batch,避免评估拖慢训练。

详细的评估指标含义与对比表,在《第 8 章 评估与对比》展开。

7.8 特色之三:教师输出缓存

这是本项目的一个重要加速手段。当训练步数较多(比如 ≥ 5000 步)时,反复跑教师前向很浪费——教师冻结了,对同一个样本的输出永远不变。于是我们可以预跑一遍教师,把每个样本的 logits 缓存到磁盘,训练时直接查表。

启用方式:

python train.py --cache-teacher

缓存的代价是磁盘空间(每个样本要存 block_size × vocab_size 个 float),收益是训练时不再跑教师前向,速度提升明显(尤其教师较大时)。详细的缓存设计与权衡,在《第 10 章 进阶实战》深入。

7.9 Checkpoint 保存与断点续训

7.9.1 定期保存

save_iter 步保存一次:

def save_checkpoint(student, optimizer, scheduler, cfg, step, loss, checkpoint_dir): path = Path(checkpoint_dir) / f"student_step{step}.pt" torch.save({ "step": step, "student_state_dict": student.state_dict(), "optimizer_state_dict": optimizer.state_dict(), "scheduler_state_dict": scheduler.state_dict(), "loss": loss, "distill_config": cfg.__dict__, # 配置也存进去! }, path)

注意保存了四样东西:学生权重、优化器状态、调度器状态、配置。前三者用于断点续训,配置用于加载时重建架构。

7.9.2 断点续训

训练可能因各种原因中断(断电、OOM、手动停止)。通过 --resume 可以从中断点无缝继续:

if args.resume: ckpt = torch.load(args.resume, map_location=device) student.load_state_dict(ckpt["student_state_dict"]) optimizer.load_state_dict(ckpt["optimizer_state_dict"]) scheduler.load_state_dict(ckpt["scheduler_state_dict"]) start_step = ckpt["step"] + 1

续训要恢复的不仅是模型权重,还有优化器和调度器状态——否则学习率曲线会重置,动量信息丢失,续训效果会打折。

# 从第 1000 步的 checkpoint 续训 python train.py --resume checkpoints/student_step1000.pt

7.9.3 最终模型

训练结束额外保存一个 student_final.pt,供评估和推理直接使用:

final_path = Path(cfg.checkpoint_dir) / "student_final.pt" torch.save({ "step": step, "student_state_dict": student.state_dict(), "distill_config": cfg.__dict__, }, final_path)

最终模型只存权重和配置,不存优化器状态——因为它只用来推理/评估,不再训练。

7.10 一次完整训练的日志解读

跑完一次默认训练,你会看到这样的日志(精简):

[train] 使用设备: cuda [model] 加载教师模型: gpt2 [model] 教师参数量: 124.44M (已冻结) [model] 学生配置: 2 层 / 4 头 / 256 维,参数量 10.12M [train] 开始蒸馏: 共 5000 步, T=2.0, alpha=0.5, cache=off 蒸馏: 2%|█▎ | 100/5000 [00:30<24:30] [train] step 10/5000 | total=5.43 ce=4.89 kd=5.97 | lr=2.95e-04 ... [eval] step 1000 | student ppl=45.32 top1=0.342 | teacher ppl=28.15 top1=0.411 | kd_kl=0.0823 ... [train] step 5000/5000 | total=3.21 ce=2.98 kd=3.45 | lr=3.00e-05 [train] checkpoint 已保存: checkpoints/student_step5000.pt [train] 蒸馏训练完成。 [train] 最终学生模型已保存: checkpoints/student_final.pt

读这串日志时关注三点:

  1. total/ce/kd 是否都在下降——健康的训练。
  2. student ppl 是否逐步逼近 teacher ppl——蒸馏在起作用。
  3. lr 曲线是否符合预热+退火——调度正常。

本章小结

  • 训练循环 = 无限迭代器取批次 → 教师 no_grad 前向 → 学生前向 → 蒸馏损失 → 反向+裁剪+更新。
  • 损失分解日志:打印 total/ce/kd 三项,能诊断蒸馏是否健康。
  • 师生对比快照:训练中每 eval_iter 步评估一次,观察学生逼近教师。
  • 教师缓存:预跑教师存 logits,训练时查表,省去重复前向。
  • 断点续训:保存权重+优化器+调度器+配置,--resume 无缝继续。
  • 优化器只接收学生参数,配合 no_grad 和冻结,三重保证只训练学生。

动手实验:跑一次极小配置训练(--max-iters 50 --log-iter 5),观察 total/ce/kd 三项的变化趋势。然后把 --alpha 改成 1.0 再跑一次,对比 kd 项的变化(应被忽略,因为 α=1 等于纯 CE)。

下一站:训练完了,怎么知道效果好不好?在《第 8 章 评估与对比》中,我们用困惑度、Top-1 准确率、分布 KL,给学生和教师做一次全面体检。


发布者: 作者: 转发
评论区 (0)
U