第 7 章 训练主循环 所有零件都备齐了——数据、教师、学生、损失。本章把它们组装成一个完整的训练循环,并实现本项目最有特色的几个工程能力:损失分解日志、师生对比快照、教师输出缓存、断点续训。 7.1 训练循环的全貌 本章按这个流程逐步拆解。 7.2 初始化:配置、种子、设备 训练入口的前三步是「打地基」: 7.2.1 随机种子:可复现性的基础 固定三个随机源(Python、NumPy、PyTorch),让相同配置 + 相同种子能复现完全一致的结果。这是做实验对比的前提——否则你无法判断效果差异来自参数还是来自随机性。 7.2.2 设备选择 默认 ,有 GPU 用 GPU,否则退回 CPU。 7.3 加载数据与模型 这两步分别在第 4、5 章讲过,这里只是组装。
所有零件都备齐了——数据、教师、学生、损失。本章把它们组装成一个完整的训练循环,并实现本项目最有特色的几个工程能力:损失分解日志、师生对比快照、教师输出缓存、断点续训。
本章按这个流程逐步拆解。
训练入口的前三步是「打地基」:
def train(): args = parse_args() # 解析命令行 cfg = DistillConfig() apply_args_to_config(args, cfg) # 覆盖默认配置 set_seed(cfg.seed) # 固定随机种子 device = select_device(cfg.device)
import random, numpy as np, torch def set_seed(seed): random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed_all(seed)
固定三个随机源(Python、NumPy、PyTorch),让相同配置 + 相同种子能复现完全一致的结果。这是做实验对比的前提——否则你无法判断效果差异来自参数还是来自随机性。
def select_device(name): if name == "auto": return torch.device("cuda" if torch.cuda.is_available() else "cpu") return torch.device(name)
默认 auto,有 GPU 用 GPU,否则退回 CPU。
# 数据:训练集 + 验证集 text = get_dataset() train_loader, eval_loader, train_dataset = build_train_eval_loaders( text=text, block_size=cfg.block_size, batch_size=cfg.batch_size, num_workers=cfg.num_workers, eval_ratio=cfg.eval_ratio, ) # 教师(冻结)+ 学生(可训练) teacher = build_teacher(cfg) teacher.to(device) student = build_student(cfg) student.to(device)
这两步分别在第 4、5 章讲过,这里只是组装。注意教师和学生都要 .to(device),否则一个在 GPU 一个在 CPU 会导致数据搬运混乱。
from torch.optim import AdamW optimizer = AdamW( student.parameters(), # 只优化学生! lr=cfg.learning_rate, betas=cfg.betas, weight_decay=cfg.weight_decay, )
关键点:optimizer 只接收 student.parameters()。教师被排除在外,即使忘了冻结也不会被更新。这是「只训练学生」的最终保证。
学习率曲线采用「线性预热 + 余弦退火」两段式:
import math from torch.optim.lr_scheduler import LambdaLR def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, min_lr_ratio=0.1): def lr_lambda(current_step): # 1) 预热:线性升温 if current_step < num_warmup_steps: return float(current_step) / max(1, num_warmup_steps) # 2) 余弦退火:平滑下降 progress = (current_step - num_warmup_steps) / max( 1, num_training_steps - num_warmup_steps ) cosine = 0.5 * (1.0 + math.cos(math.pi * progress)) return min_lr_ratio + (1.0 - min_lr_ratio) * cosine return LambdaLR(optimizer, lr_lambda)
两个阶段的作用:
| 阶段 | 学习率变化 | 为什么 |
|---|---|---|
| 预热(warmup) | 0 → 峰值 | 初始权重随机,大学习率会破坏训练;先小步试探 |
| 余弦退火 | 峰值 → min_lr | 平滑下降,后期精调,避免在最优解附近震荡 |
这是 GPT-2/3 训练的标准配方,对蒸馏同样适用。
这是整个项目的「心脏」:
student.train() running_total = running_ce = running_kd = 0.0 step = 0 data_iter = iter(train_loader) while step < cfg.max_iters: # 取一个批次(epoch 用尽则重启迭代器) try: x, y, idx = next(data_iter) except StopIteration: data_iter = iter(train_loader) x, y, idx = next(data_iter) x, y, idx = x.to(device), y.to(device), idx.to(device) # ① 教师前向(不计算梯度) if use_cache: teacher_logits = teacher_cache.fetch(idx, device) else: with torch.no_grad(): teacher_logits = teacher(input_ids=x).logits # ② 学生前向 student_logits = student(input_ids=x).logits # ③ 蒸馏损失(含 CE + KL 分解) total_loss, ce_loss, kd_loss = distillation_loss( student_logits, teacher_logits, y, cfg.temperature, cfg.alpha ) # ④ 反向传播 + 梯度裁剪 + 更新 optimizer.zero_grad(set_to_none=True) total_loss.backward() torch.nn.utils.clip_grad_norm_(student.parameters(), cfg.grad_clip) optimizer.step() scheduler.step() # ⑤ 统计与日志(见 7.6) ...
逐步解释关键点:
no_gradwith torch.no_grad(): teacher_logits = teacher(input_ids=x).logits
torch.no_grad() 让教师前向不构建计算图,既省显存又省算力。配合第 5 章的参数冻结,双保险确保教师不被更新。
set_to_none=Trueoptimizer.zero_grad(set_to_none=True)
把梯度设为 None 而非 0,省一次全零写入,略快一点。这是 PyTorch 推荐的现代写法。
torch.nn.utils.clip_grad_norm_(student.parameters(), cfg.grad_clip)
把所有梯度的整体范数裁剪到 grad_clip(默认 1.0)以内。训练初期学生权重随机,偶尔会出现梯度爆炸,裁剪能稳定训练。这是 Transformer 训练的标配。
try: x, y, idx = next(data_iter) except StopIteration: data_iter = iter(train_loader) # epoch 用尽,重新开始 x, y, idx = next(data_iter)
本项目按「步数」(max_iters)而非「epoch」控制训练长度。当数据集迭代完一个 epoch,自动重启迭代器继续下一个 epoch,直到步数达标。这种模式在数据量大、想精确控制训练步数时很常用。
每 log_iter 步打印一次三项损失的均值:
running_total += total_loss.item() running_ce += ce_loss.item() running_kd += kd_loss.item() step += 1 if step % cfg.log_iter == 0: avg_total = running_total / cfg.log_iter avg_ce = running_ce / cfg.log_iter avg_kd = running_kd / cfg.log_iter print( f"[train] step {step}/{cfg.max_iters} | " f"total={avg_total:.4f} ce={avg_ce:.4f} kd={avg_kd:.4f} | " f"lr={scheduler.get_last_lr()[0]:.2e}" ) running_total = running_ce = running_kd = 0.0
输出示例:
[train] step 10/5000 | total=5.43 ce=4.89 kd=5.97 | lr=2.95e-04
为什么要分三项打印?因为它们反映了蒸馏的不同侧面:
| 指标 | 反映 | 健康趋势 |
|---|---|---|
total |
总优化目标 | 持续下降 |
ce |
学生对真实标签的拟合 | 下降,但不能为 0(过拟合) |
kd |
学生对教师分布的模仿 | 下降,趋近 0 说明学得像 |
如果 ce 下降但 kd 不降,说明学生在学硬标签却没学教师,蒸馏没起作用——可能是温度太低或 α 太高。这种诊断能力,正是损失分解的价值。
每 eval_iter 步,在验证集上做一次快速评估,对比师生表现:
if cfg.eval_iter > 0 and step % cfg.eval_iter == 0: metrics = quick_eval(student, teacher, eval_loader, device) print( f"[eval] step {step} | " f"student ppl={metrics['student_ppl']:.2f} top1={metrics['student_top1']:.3f} | " f"teacher ppl={metrics['teacher_ppl']:.2f} top1={metrics['teacher_top1']:.3f} | " f"kd_kl={metrics['kd_kl']:.4f}" )
这个快照让你训练中就能看到「学生离教师还有多远」,而不必等训练结束。quick_eval 内部只跑前 20 个 batch,避免评估拖慢训练。
详细的评估指标含义与对比表,在《第 8 章 评估与对比》展开。
这是本项目的一个重要加速手段。当训练步数较多(比如 ≥ 5000 步)时,反复跑教师前向很浪费——教师冻结了,对同一个样本的输出永远不变。于是我们可以预跑一遍教师,把每个样本的 logits 缓存到磁盘,训练时直接查表。
启用方式:
python train.py --cache-teacher
缓存的代价是磁盘空间(每个样本要存 block_size × vocab_size 个 float),收益是训练时不再跑教师前向,速度提升明显(尤其教师较大时)。详细的缓存设计与权衡,在《第 10 章 进阶实战》深入。
每 save_iter 步保存一次:
def save_checkpoint(student, optimizer, scheduler, cfg, step, loss, checkpoint_dir): path = Path(checkpoint_dir) / f"student_step{step}.pt" torch.save({ "step": step, "student_state_dict": student.state_dict(), "optimizer_state_dict": optimizer.state_dict(), "scheduler_state_dict": scheduler.state_dict(), "loss": loss, "distill_config": cfg.__dict__, # 配置也存进去! }, path)
注意保存了四样东西:学生权重、优化器状态、调度器状态、配置。前三者用于断点续训,配置用于加载时重建架构。
训练可能因各种原因中断(断电、OOM、手动停止)。通过 --resume 可以从中断点无缝继续:
if args.resume: ckpt = torch.load(args.resume, map_location=device) student.load_state_dict(ckpt["student_state_dict"]) optimizer.load_state_dict(ckpt["optimizer_state_dict"]) scheduler.load_state_dict(ckpt["scheduler_state_dict"]) start_step = ckpt["step"] + 1
续训要恢复的不仅是模型权重,还有优化器和调度器状态——否则学习率曲线会重置,动量信息丢失,续训效果会打折。
# 从第 1000 步的 checkpoint 续训 python train.py --resume checkpoints/student_step1000.pt
训练结束额外保存一个 student_final.pt,供评估和推理直接使用:
final_path = Path(cfg.checkpoint_dir) / "student_final.pt" torch.save({ "step": step, "student_state_dict": student.state_dict(), "distill_config": cfg.__dict__, }, final_path)
最终模型只存权重和配置,不存优化器状态——因为它只用来推理/评估,不再训练。
跑完一次默认训练,你会看到这样的日志(精简):
[train] 使用设备: cuda [model] 加载教师模型: gpt2 [model] 教师参数量: 124.44M (已冻结) [model] 学生配置: 2 层 / 4 头 / 256 维,参数量 10.12M [train] 开始蒸馏: 共 5000 步, T=2.0, alpha=0.5, cache=off 蒸馏: 2%|█▎ | 100/5000 [00:30<24:30] [train] step 10/5000 | total=5.43 ce=4.89 kd=5.97 | lr=2.95e-04 ... [eval] step 1000 | student ppl=45.32 top1=0.342 | teacher ppl=28.15 top1=0.411 | kd_kl=0.0823 ... [train] step 5000/5000 | total=3.21 ce=2.98 kd=3.45 | lr=3.00e-05 [train] checkpoint 已保存: checkpoints/student_step5000.pt [train] 蒸馏训练完成。 [train] 最终学生模型已保存: checkpoints/student_final.pt
读这串日志时关注三点:
--resume 无缝继续。动手实验:跑一次极小配置训练(--max-iters 50 --log-iter 5),观察 total/ce/kd 三项的变化趋势。然后把 --alpha 改成 1.0 再跑一次,对比 kd 项的变化(应被忽略,因为 α=1 等于纯 CE)。
下一站:训练完了,怎么知道效果好不好?在《第 8 章 评估与对比》中,我们用困惑度、Top-1 准确率、分布 KL,给学生和教师做一次全面体检。