第 5 章 教师与学生模型


文档摘要

第 5 章 教师与学生模型 蒸馏有两位主角:教师提供"软标签",学生负责学习。本章讲解如何构建这两类模型,以及为什么教师必须被冻结、学生必须从零开始。 5.1 蒸馏中的两个角色 两者的核心差异: 维度 | 教师 | 学生 权重来源 | 预训练(来自 Hub) | 随机初始化 参数量 | 约 124M | 约 10M 是否更新 | 冻结,绝不更新 | 可训练,是优化对象 作用 | 提供软标签 logits | 学习模仿教师 前向开销 | 每步都要跑(除非缓存) | 每步都要跑 5.2 为什么不手写 Transformer 本项目不手写 Transformer,而是直接用 HuggingFace 提供的 。

第 5 章 教师与学生模型

蒸馏有两位主角:教师提供"软标签",学生负责学习。本章讲解如何构建这两类模型,以及为什么教师必须被冻结、学生必须从零开始。

5.1 蒸馏中的两个角色

两者的核心差异:

维度 教师 学生
权重来源 预训练(来自 Hub) 随机初始化
参数量 约 124M 约 10M
是否更新 冻结,绝不更新 可训练,是优化对象
作用 提供软标签 logits 学习模仿教师
前向开销 每步都要跑(除非缓存) 每步都要跑

5.2 为什么不手写 Transformer

本项目不手写 Transformer,而是直接用 HuggingFace transformers 提供的 GPT2LMHeadModel。原因有三:

  1. 工业级稳定:transformers 的 GPT2 实现经过海量用户验证,数值稳定、边界处理完善。
  2. 规模可配:通过 GPT2Config 可以自由控制层数、头数、维度,从迷你版到标准版都能表达。
  3. 聚焦蒸馏本身:把精力放在蒸馏流程上,而非重复造轮子。

这也是本教程与「从零手写 Transformer」类教程的定位差异——我们假设你已经了解 Transformer 大致结构,这里专注工程化与蒸馏。

5.3 构建教师:加载预训练模型

教师用 from_pretrained 从 HuggingFace Hub 加载预训练权重:

from transformers import GPT2LMHeadModel def build_teacher(cfg): print(f"加载教师模型: {cfg.teacher_name}") teacher = GPT2LMHeadModel.from_pretrained(cfg.teacher_name) teacher.eval() # 关闭 dropout for p in teacher.parameters(): p.requires_grad = False # 冻结全部参数 return teacher

两个关键操作:

5.3.1 teacher.eval()

把模型切到推理模式,关闭 dropout 和 batchnorm 的训练行为。教师只是"出题",不需要训练时的随机性。

5.3.2 冻结参数

for p in teacher.parameters(): p.requires_grad = False

把所有参数的 requires_grad 设为 False,意味着反向传播时不会为教师计算梯度,优化器也不会更新它。这是蒸馏的核心约定——教师是固定的参照物

即便忘了冻结,训练循环里教师的前向也会包在 torch.no_grad() 里(见第 7 章),双重保险避免教师被更新。但显式冻结更清晰,也省显存(不存中间激活用于反传)。

5.4 构建学生:从零初始化

学生用 GPT2Config 自定义架构,从零初始化:

from transformers import GPT2LMHeadModel, GPT2Config def build_student(cfg): # 用项目配置构造 HuggingFace 配置 hf_config = GPT2Config(**cfg.to_student_gpt2_kwargs()) # 从零初始化(权重随机) student = GPT2LMHeadModel(hf_config) return student

注意这里没有用 from_pretrained,而是 GPT2LMHeadModel(hf_config) 直接构造。两者的区别:

方式 权重 用途
GPT2LMHeadModel(config) 随机初始化 从零训练(学生、研究)
GPT2LMHeadModel.from_pretrained(name) 加载预训练 微调、当教师

学生的随机初始化意味着它一开始什么都不知道——这正是蒸馏的起点:让一张白纸的学生,通过模仿教师,快速获得语言能力。

5.5 教师与学生的架构对比

默认配置下,两者的规模差距显著:

架构项 教师(标准 gpt2) 学生(默认配置) 比例
层数 12 2 6:1
注意力头数 12 4 3:1
嵌入维度 768 256 3:1
参数量 ≈ 124M ≈ 10M ≈ 12:1
上下文长度 1024 128(可配)

这个压缩比(约 12:1)既保证了明显的体积优势,又不至于小到无法学习。你可以通过命令行轻松调整学生规模:

# 更大的学生(接近教师的一半) python train.py --student-n-layer 6 --student-n-embd 512 --student-n-head 8 # 更小的学生(极致压缩) python train.py --student-n-layer 1 --student-n-embd 128 --student-n-head 2

学生规模的权衡:越大越容易学(逼近教师性能),但压缩收益越小;越小压缩越狠,但性能损失越大。默认值是一个经过实践验证的平衡点。

5.6 加载训练好的学生

训练完成后,学生权重存在 checkpoint 里。评估和推理时需要把它加载回来。本项目支持多种 checkpoint 格式:

import torch def load_student(checkpoint_path, cfg=None, device=None): if cfg is None: cfg = DistillConfig() if os.path.isdir(checkpoint_path): # 格式一:transformers 风格目录 model = GPT2LMHeadModel.from_pretrained(checkpoint_path) else: # 格式二:单个权重文件 model = build_student(cfg) state = torch.load(checkpoint_path, map_location="cpu") # 兼容多种保存键名 if isinstance(state, dict): if "student_state_dict" in state: state = state["student_state_dict"] elif "model_state_dict" in state: state = state["model_state_dict"] model.load_state_dict(state, strict=False) if device: model.to(device) model.eval() return model

注意几个健壮性设计:

  • 兼容多种键名:本项目存的是 student_state_dict,但为了兼容可能存了 model_state_dict 的旧 checkpoint,两个键名都尝试。
  • strict=False:允许少量键不匹配(比如词表 embed 层),避免因细微差异导致加载失败。
  • 优先从 checkpoint 读配置:在《第 8 章》我们会看到,加载时会优先用 checkpoint 内保存的 distill_config 来重建学生架构,保证架构一致。

5.7 参数量统计

实现一个通用的参数量统计函数,方便对比:

def count_parameters(model): """统计可训练参数数量。""" return sum(p.numel() for p in model.parameters() if p.requires_grad)

注意它统计的是 requires_grad=True 的参数。对于冻结的教师,这个值会是 0——这是一个很好的自检手段:如果统计教师参数量得到非零值,说明冻结没生效。

实际使用:

teacher = build_teacher(cfg) student = build_student(cfg) print(f"教师可训练参数: {count_parameters(teacher)}") # 应为 0 print(f"学生可训练参数: {count_parameters(student)}") # 约 10M

5.8 一个完整的构建示例

把教师和学生都构建出来,对比它们的输出:

import torch from config import DistillConfig from model import build_teacher, build_student, count_parameters cfg = DistillConfig() # 构建两个模型 teacher = build_teacher(cfg) student = build_student(cfg) print(f"教师可训练参数: {count_parameters(teacher)/1e6:.2f}M (应为 0)") print(f"学生可训练参数: {count_parameters(student)/1e6:.2f}M") # 同一输入,看两者输出形状一致(都是 [batch, seq, vocab]) dummy = torch.randint(0, cfg.vocab_size, (2, 8)) with torch.no_grad(): t_logits = teacher(input_ids=dummy).logits s_logits = student(input_ids=dummy).logits print(f"教师 logits 形状: {t_logits.shape}") # [2, 8, 50257] print(f"学生 logits 形状: {s_logits.shape}") # [2, 8, 50257]

可以看到,尽管参数量差 12 倍,两者的输入输出接口完全一致——这正是 Logits 蒸馏能成立的前提:师生输出形状相同,才能逐元素对齐做 KL

本章小结

  • 教师用 from_pretrained 加载预训练权重,必须 eval() 并冻结所有参数。
  • 学生用 GPT2Config 自定义架构,从零随机初始化,是唯一的训练对象。
  • 默认压缩比约 12:1(124M → 10M),可通过命令行灵活调整。
  • count_parameters 对冻结模型应返回 0,是冻结生效的快速自检。
  • 师生输出接口一致([batch, seq, vocab]),这是 Logits 蒸馏的前提。

动手实验:构建教师和学生,对比它们在同一输入上的 logits 分布(用 softmax 后看 top-5 token 概率)。你会直观看到教师的分布更"合理",学生的是随机的——这就是蒸馏要弥合的差距。

下一站:模型就位,进入全教程的代码核心。在《第 6 章 蒸馏损失函数》中,我们将逐行实现软硬标签组合损失,并彻底搞懂 shift 对齐和 KL 方向这两个最容易踩的坑。


发布者: 作者: 转发
评论区 (0)
U