第 10 章 进阶实战 基础流程掌握了,本章讨论工程优化与进阶方向:教师缓存的内部原理、温度与 α 的调参实验、更高级的中间层蒸馏、以及把蒸馏迁移到其他模型族的思路。 10.1 教师输出缓存的工程细节 第 7 章提到过教师缓存能加速训练,本节深入它的实现。 10.1.1 为什么要缓存 教师是冻结的,对同一个样本的前向结果永远不变。如果不缓存,每个 epoch 都要把所有样本喂给教师跑一遍前向——这是巨大的浪费,尤其当教师很大时(比如 gpt2-medium、gpt2-large)。 缓存后,训练时教师前向变成「磁盘查表」,省去全部教师计算。 10.1.2 分片存储:避免单文件过大 一个样本的教师输出是 的 float 张量。
基础流程掌握了,本章讨论工程优化与进阶方向:教师缓存的内部原理、温度与 α 的调参实验、更高级的中间层蒸馏、以及把蒸馏迁移到其他模型族的思路。
第 7 章提到过教师缓存能加速训练,本节深入它的实现。
教师是冻结的,对同一个样本的前向结果永远不变。如果不缓存,每个 epoch 都要把所有样本喂给教师跑一遍前向——这是巨大的浪费,尤其当教师很大时(比如 gpt2-medium、gpt2-large)。
缓存后,训练时教师前向变成「磁盘查表」,省去全部教师计算。
一个样本的教师输出是 [block_size, vocab_size] 的 float 张量。以 block_size=128、vocab=50257 为例,单个样本约 128 × 50257 × 4字节 ≈ 25MB。几万个样本就是几百 GB——不可能存成一个大文件。
解决方案是分片:每 N 个样本存一个 .pt 文件(本项目默认 N=4096):
class TeacherCache: def __init__(self, cache_dir, shard_size=4096): self.cache_dir = Path(cache_dir) self.shard_size = shard_size def _shard_path(self, shard_id): return self.cache_dir / f"shard_{shard_id:06d}.pt" def build(self, teacher, dataloader, device): buffers = {} with torch.no_grad(): for x, y, idx in dataloader: logits = teacher(input_ids=x.to(device)).logits.cpu() for i, sample_idx in enumerate(idx.tolist()): shard_id = sample_idx // self.shard_size buffers.setdefault(shard_id, {})[sample_idx] = logits[i] # 满 shard 立即落盘 if len(buffers[shard_id]) >= self.shard_size: self._flush(buffers, shard_id) # 剩余的也落盘 for sid in list(buffers.keys()): self._flush(buffers, sid)
关键设计:
def fetch(self, idx_tensor, device): out = [] for i in idx_tensor.tolist(): shard_id = i // self.shard_size if shard_id != self._loaded_shard_id: self._loaded_shard = torch.load(self._shard_path(shard_id)) self._loaded_shard_id = shard_id out.append(self._loaded_shard[i]) return torch.stack(out, dim=0).to(device)
用一个简单的「当前已加载分片」缓存,避免同一个分片被反复读取。
| 维度 | 不缓存 | 缓存 |
|---|---|---|
| 训练速度 | 每步跑教师前向 | 查表,快很多 |
| 磁盘占用 | 几乎为 0 | 大(见 10.1.2) |
| 首次启动 | 直接训练 | 多一次预跑 |
| 适用场景 | 短训练、教师小 | 长训练、教师大 |
决策建议:
max_iters较小(如 < 2000)或教师很小时,缓存收益不明显,反增加磁盘负担;max_iters大(如 ≥ 5000)且教师较大时,缓存能显著加速。本项目的--cache-teacher开关让你灵活选择。
构建缓存时用的 DataLoader 必须 shuffle=False,保证 idx 与缓存内容一一对应。构建完成后,训练用的 DataLoader 可以随意打乱——因为查表是按 idx 的,与顺序无关。这是为什么数据集要返回 idx 的根本原因。
温度和 α 是影响蒸馏效果的两个核心超参。本节给出调参思路。
| T 值 | 教师分布 | 蒸馏效果 |
|---|---|---|
| T=1 | 尖锐,暗知识少 | 接近硬标签,蒸馏增益小 |
| T=2~4 | 适中,暗知识浮现 | 通常最佳区间 |
| T>8 | 过于平坦 | 教师信息被稀释,学生难学 |
推荐先从 T=2 开始,然后尝试 T=4、T=8,观察 alignment_kl 和学生 PPL 的变化。
| α 值 | 偏向 | 适用 |
|---|---|---|
| α=0.7~1.0 | 重硬标签 | 教师软标签噪声大、或学生容量很小 |
| α=0.3~0.5 | 重软标签(蒸馏为主) | 通常最佳区间 |
| α=0~0.3 | 几乎纯蒸馏 | 教师很强、数据少时 |
做一组正交实验,固定其他参数,只变 T 和 α:
# 基线:纯硬标签 python train.py --alpha 1.0 --max-iters 3000 # 蒸馏:不同温度 python train.py --temperature 2.0 --alpha 0.5 --max-iters 3000 python train.py --temperature 4.0 --alpha 0.5 --max-iters 3000 python train.py --temperature 8.0 --alpha 0.5 --max-iters 3000 # 蒸馏:不同权重 python train.py --temperature 4.0 --alpha 0.3 --max-iters 3000 python train.py --temperature 4.0 --alpha 0.7 --max-iters 3000
每组训完用评估命令评估(见《附录 B 命令速查》),把生成的评估报告汇总成表,就能找到最佳组合。这种控制变量法是调参的金标准。
学生越大越容易学,但压缩收益越小。可以尝试不同规模:
# 极致压缩(约 3M 参数) python train.py --student-n-layer 1 --student-n-embd 128 --student-n-head 2 # 默认(约 10M 参数) python train.py # 较大学生(约 30M 参数) python train.py --student-n-layer 4 --student-n-embd 512 --student-n-head 8
绘制「参数量 vs PPL」曲线,找到帕累托最优——即用最少参数达到可接受 PPL 的点。这是模型压缩研究的核心问题。
Logits 蒸馏只对齐「最终输出」。更激进的做法是对齐「中间层」——让学生的隐藏状态、注意力直接模仿教师。代表方法有 TinyBERT、MiniLM。
由于师生层数可能不同(如教师 12 层、学生 2 层),需要做「层映射」——比如把教师的每 6 层对应学生 1 层。对齐损失通常用 MSE:
def hidden_state_loss(student_hidden, teacher_hidden): # student_hidden, teacher_hidden: [batch, seq, dim] # 维度可能不同,需先投影到相同维度 return F.mse_loss(student_hidden, teacher_hidden)
中间层蒸馏可以和 Logits 蒸馏叠加:
L = α·CE + β·KD_logits + γ·(中间层 MSE 之和)
通常能进一步提升压缩比下的性能,代价是实现更复杂(需处理层映射、维度投影)。这是 Logits 蒸馏掌握后的自然进阶方向。
如果教师是一个黑盒 API(如商业大模型),你拿不到 logits,只有生成的文本,怎么办?答案是序列级蒸馏:
核心思想:让教师先「生成」大量伪数据,学生再在这些伪数据上做普通语言建模训练。虽然拿不到 logits,但教师的生成文本本身就携带了它的"知识"。
这种方法的优点是不要求教师白盒,缺点是需要先花成本生成大量数据。当教师是 GPT-4 等闭源模型时,这是唯一可行的蒸馏路径。
本项目的蒸馏框架(数据集、损失、训练循环)是模型无关的。迁移到其他模型族(BERT、LLaMA、Qwen 等)主要改动:
| 改动点 | 从 GPT2 改到 BERT | 从 GPT2 改到 LLaMA |
|---|---|---|
| 模型类 | GPT2LMHeadModel → BertForMaskedLM |
→ LlamaForCausalLM |
| 配置类 | GPT2Config → BertConfig |
→ LlamaConfig |
| 分词器 | tiktoken → bert tokenizer | → LLaMA tokenizer |
| 任务 | 因果 LM → 掩码 LM | 因果 LM(类似) |
| shift 对齐 | 需要 | 需要 |
损失函数、训练循环、评估逻辑基本可以复用——这正是「配置即代码」+ 模块化设计的收益。
掌握基础蒸馏后,可以往这些方向扩展:
| 方向 | 技术 | 收益 |
|---|---|---|
| 混合精度训练 | torch.cuda.amp | 显存减半、速度加倍 |
| 分布式训练 | DDP / FSDP | 多卡加速、训更大模型 |
| 参数高效微调 | LoRA / QLoRA | 教师也可用低秩适配 |
| 数据并行缓存 | 多进程预跑教师 | 缓存构建加速 |
| 在线评估可视化 | TensorBoard / W&B | 实时看损失曲线 |
这些属于通用深度学习工程能力,不属于蒸馏特有,但能让蒸馏实验效率大幅提升。建议在跑通基础流程后逐步引入。
| 误区 | 纠正 |
|---|---|
| 温度越高越好 | 过高会稀释教师信息,T=2~4 通常最佳 |
| α 越小(越重蒸馏)越好 | α 太小会忽视真实标签,学生可能学歪 |
| 学生越大蒸馏收益越大 | 学生越接近教师,蒸馏相对增益越小 |
| 必须开教师缓存 | 短训练反而不划算,按场景决策 |
| 蒸馏能超越教师 | 不能。蒸馏只是把教师知识压缩,上限是教师性能 |
动手实验:设计并跑一组温度 × α 的正交实验(至少 4 组),把评估结果汇总成表,找出本项目在你机器上的最佳超参组合。这是把理论转化为实战能力的最好练习。
至此,基础篇与进阶篇全部完成。接下来是三个附录:《附录 A 术语表》汇总全部概念,《附录 B 命令速查》列出常用命令,《附录 C 常见问题》帮你排查报错。