第 10 章 进阶实战


文档摘要

第 10 章 进阶实战 基础流程掌握了,本章讨论工程优化与进阶方向:教师缓存的内部原理、温度与 α 的调参实验、更高级的中间层蒸馏、以及把蒸馏迁移到其他模型族的思路。 10.1 教师输出缓存的工程细节 第 7 章提到过教师缓存能加速训练,本节深入它的实现。 10.1.1 为什么要缓存 教师是冻结的,对同一个样本的前向结果永远不变。如果不缓存,每个 epoch 都要把所有样本喂给教师跑一遍前向——这是巨大的浪费,尤其当教师很大时(比如 gpt2-medium、gpt2-large)。 缓存后,训练时教师前向变成「磁盘查表」,省去全部教师计算。 10.1.2 分片存储:避免单文件过大 一个样本的教师输出是 的 float 张量。

第 10 章 进阶实战

基础流程掌握了,本章讨论工程优化与进阶方向:教师缓存的内部原理、温度与 α 的调参实验、更高级的中间层蒸馏、以及把蒸馏迁移到其他模型族的思路。

10.1 教师输出缓存的工程细节

第 7 章提到过教师缓存能加速训练,本节深入它的实现。

10.1.1 为什么要缓存

教师是冻结的,对同一个样本的前向结果永远不变。如果不缓存,每个 epoch 都要把所有样本喂给教师跑一遍前向——这是巨大的浪费,尤其当教师很大时(比如 gpt2-medium、gpt2-large)。

缓存后,训练时教师前向变成「磁盘查表」,省去全部教师计算。

10.1.2 分片存储:避免单文件过大

一个样本的教师输出是 [block_size, vocab_size] 的 float 张量。以 block_size=128、vocab=50257 为例,单个样本约 128 × 50257 × 4字节 ≈ 25MB。几万个样本就是几百 GB——不可能存成一个大文件。

解决方案是分片:每 N 个样本存一个 .pt 文件(本项目默认 N=4096):

class TeacherCache: def __init__(self, cache_dir, shard_size=4096): self.cache_dir = Path(cache_dir) self.shard_size = shard_size def _shard_path(self, shard_id): return self.cache_dir / f"shard_{shard_id:06d}.pt" def build(self, teacher, dataloader, device): buffers = {} with torch.no_grad(): for x, y, idx in dataloader: logits = teacher(input_ids=x.to(device)).logits.cpu() for i, sample_idx in enumerate(idx.tolist()): shard_id = sample_idx // self.shard_size buffers.setdefault(shard_id, {})[sample_idx] = logits[i] # 满 shard 立即落盘 if len(buffers[shard_id]) >= self.shard_size: self._flush(buffers, shard_id) # 剩余的也落盘 for sid in list(buffers.keys()): self._flush(buffers, sid)

关键设计:

  • 满 shard 即落盘:避免内存堆积,常驻内存只有一个分片。
  • 存 CPU 张量:避免显存爆炸。
  • 按 idx 索引:训练时按 batch 的 idx 取出对应 logits。

10.1.3 训练时查表

def fetch(self, idx_tensor, device): out = [] for i in idx_tensor.tolist(): shard_id = i // self.shard_size if shard_id != self._loaded_shard_id: self._loaded_shard = torch.load(self._shard_path(shard_id)) self._loaded_shard_id = shard_id out.append(self._loaded_shard[i]) return torch.stack(out, dim=0).to(device)

用一个简单的「当前已加载分片」缓存,避免同一个分片被反复读取。

10.1.4 缓存的权衡

维度 不缓存 缓存
训练速度 每步跑教师前向 查表,快很多
磁盘占用 几乎为 0 大(见 10.1.2)
首次启动 直接训练 多一次预跑
适用场景 短训练、教师小 长训练、教师大

决策建议max_iters 较小(如 < 2000)或教师很小时,缓存收益不明显,反增加磁盘负担;max_iters 大(如 ≥ 5000)且教师较大时,缓存能显著加速。本项目的 --cache-teacher 开关让你灵活选择。

10.1.5 缓存构建必须不打乱

构建缓存时用的 DataLoader 必须 shuffle=False,保证 idx 与缓存内容一一对应。构建完成后,训练用的 DataLoader 可以随意打乱——因为查表是按 idx 的,与顺序无关。这是为什么数据集要返回 idx 的根本原因。

10.2 调参实验:温度 T 与权重 α

温度和 α 是影响蒸馏效果的两个核心超参。本节给出调参思路。

10.2.1 温度 T 的影响

T 值 教师分布 蒸馏效果
T=1 尖锐,暗知识少 接近硬标签,蒸馏增益小
T=2~4 适中,暗知识浮现 通常最佳区间
T>8 过于平坦 教师信息被稀释,学生难学

推荐先从 T=2 开始,然后尝试 T=4、T=8,观察 alignment_kl 和学生 PPL 的变化。

10.2.2 α 权重的影响

α 值 偏向 适用
α=0.7~1.0 重硬标签 教师软标签噪声大、或学生容量很小
α=0.3~0.5 重软标签(蒸馏为主) 通常最佳区间
α=0~0.3 几乎纯蒸馏 教师很强、数据少时

10.2.3 一个对照实验方案

做一组正交实验,固定其他参数,只变 T 和 α:

# 基线:纯硬标签 python train.py --alpha 1.0 --max-iters 3000 # 蒸馏:不同温度 python train.py --temperature 2.0 --alpha 0.5 --max-iters 3000 python train.py --temperature 4.0 --alpha 0.5 --max-iters 3000 python train.py --temperature 8.0 --alpha 0.5 --max-iters 3000 # 蒸馏:不同权重 python train.py --temperature 4.0 --alpha 0.3 --max-iters 3000 python train.py --temperature 4.0 --alpha 0.7 --max-iters 3000

每组训完用评估命令评估(见《附录 B 命令速查》),把生成的评估报告汇总成表,就能找到最佳组合。这种控制变量法是调参的金标准。

10.3 学生规模的压缩比实验

学生越大越容易学,但压缩收益越小。可以尝试不同规模:

# 极致压缩(约 3M 参数) python train.py --student-n-layer 1 --student-n-embd 128 --student-n-head 2 # 默认(约 10M 参数) python train.py # 较大学生(约 30M 参数) python train.py --student-n-layer 4 --student-n-embd 512 --student-n-head 8

绘制「参数量 vs PPL」曲线,找到帕累托最优——即用最少参数达到可接受 PPL 的点。这是模型压缩研究的核心问题。

10.4 进阶方向:中间层蒸馏

Logits 蒸馏只对齐「最终输出」。更激进的做法是对齐「中间层」——让学生的隐藏状态、注意力直接模仿教师。代表方法有 TinyBERT、MiniLM。

10.4.1 原理

由于师生层数可能不同(如教师 12 层、学生 2 层),需要做「层映射」——比如把教师的每 6 层对应学生 1 层。对齐损失通常用 MSE:

def hidden_state_loss(student_hidden, teacher_hidden): # student_hidden, teacher_hidden: [batch, seq, dim] # 维度可能不同,需先投影到相同维度 return F.mse_loss(student_hidden, teacher_hidden)

10.4.2 与 Logits 蒸馏的关系

中间层蒸馏可以和 Logits 蒸馏叠加:

L = α·CE + β·KD_logits + γ·(中间层 MSE 之和)

通常能进一步提升压缩比下的性能,代价是实现更复杂(需处理层映射、维度投影)。这是 Logits 蒸馏掌握后的自然进阶方向。

10.5 进阶方向:序列级蒸馏(黑盒教师)

如果教师是一个黑盒 API(如商业大模型),你拿不到 logits,只有生成的文本,怎么办?答案是序列级蒸馏

核心思想:让教师先「生成」大量伪数据,学生再在这些伪数据上做普通语言建模训练。虽然拿不到 logits,但教师的生成文本本身就携带了它的"知识"。

这种方法的优点是不要求教师白盒,缺点是需要先花成本生成大量数据。当教师是 GPT-4 等闭源模型时,这是唯一可行的蒸馏路径。

10.6 迁移到其他模型族

本项目的蒸馏框架(数据集、损失、训练循环)是模型无关的。迁移到其他模型族(BERT、LLaMA、Qwen 等)主要改动:

改动点 从 GPT2 改到 BERT 从 GPT2 改到 LLaMA
模型类 GPT2LMHeadModelBertForMaskedLM LlamaForCausalLM
配置类 GPT2ConfigBertConfig LlamaConfig
分词器 tiktoken → bert tokenizer → LLaMA tokenizer
任务 因果 LM → 掩码 LM 因果 LM(类似)
shift 对齐 需要 需要

损失函数、训练循环、评估逻辑基本可以复用——这正是「配置即代码」+ 模块化设计的收益。

10.7 工程化扩展方向

掌握基础蒸馏后,可以往这些方向扩展:

方向 技术 收益
混合精度训练 torch.cuda.amp 显存减半、速度加倍
分布式训练 DDP / FSDP 多卡加速、训更大模型
参数高效微调 LoRA / QLoRA 教师也可用低秩适配
数据并行缓存 多进程预跑教师 缓存构建加速
在线评估可视化 TensorBoard / W&B 实时看损失曲线

这些属于通用深度学习工程能力,不属于蒸馏特有,但能让蒸馏实验效率大幅提升。建议在跑通基础流程后逐步引入。

10.8 常见调参误区

误区 纠正
温度越高越好 过高会稀释教师信息,T=2~4 通常最佳
α 越小(越重蒸馏)越好 α 太小会忽视真实标签,学生可能学歪
学生越大蒸馏收益越大 学生越接近教师,蒸馏相对增益越小
必须开教师缓存 短训练反而不划算,按场景决策
蒸馏能超越教师 不能。蒸馏只是把教师知识压缩,上限是教师性能

本章小结

  • 教师缓存:分片存储避免单文件过大,满 shard 即落盘,查表时按 idx 取。适合长训练 + 大教师。
  • 调参实验:温度 T 通常 24 最佳,α 通常 0.30.5 最佳;用控制变量法做正交实验。
  • 压缩比实验:绘制参数量-PPL 曲线找帕累托最优。
  • 中间层蒸馏:对齐隐藏状态/注意力,可与 Logits 蒸馏叠加,进一步提升压缩比。
  • 序列级蒸馏:应对黑盒教师,用教师生成伪数据训练学生。
  • 模型族迁移:损失与训练循环可复用,主要换模型/配置/分词器。

动手实验:设计并跑一组温度 × α 的正交实验(至少 4 组),把评估结果汇总成表,找出本项目在你机器上的最佳超参组合。这是把理论转化为实战能力的最好练习。

至此,基础篇与进阶篇全部完成。接下来是三个附录:《附录 A 术语表》汇总全部概念,《附录 B 命令速查》列出常用命令,《附录 C 常见问题》帮你排查报错。


发布者: 作者: 转发
评论区 (0)
U