第 6 章 蒸馏损失函数(数学核心) 这是全教程代码的灵魂。本章逐行实现蒸馏损失,并彻底讲透三个最容易踩坑的细节:shift 对齐、KL 散度的方向、T² 缩放。如果你只读一章代码章节,就读这一章。 6.1 损失函数的全貌 回顾第 1 章的蒸馏总损失公式: 我们要实现的函数签名如下,返回三元组(总损失、CE、KL),方便训练时分别监控: 下面逐个拆解。在动手写之前,必须先解决一个前置问题:shift 对齐。 6.2 前置坑之一:shift 对齐 这是 GPT2 类自回归模型最容易让人困惑的细节。 问题描述 GPT2 的语言建模约定是:用位置 t 的输入 token,去预测位置 t+1 的 token。所以模型输出的 ,对应的预测目标是「输入序列中第 t+1 个 token」。
这是全教程代码的灵魂。本章逐行实现蒸馏损失,并彻底讲透三个最容易踩坑的细节:shift 对齐、KL 散度的方向、T² 缩放。如果你只读一章代码章节,就读这一章。
回顾第 1 章的蒸馏总损失公式:
L_total = α · CE(student, 硬标签) + (1 - α) · T² · KL(teacher_soft ‖ student_soft)
我们要实现的函数签名如下,返回三元组(总损失、CE、KL),方便训练时分别监控:
def distillation_loss(student_logits, teacher_logits, labels, temperature, alpha): ce_loss = compute_ce_loss(student_logits, labels) kd_loss = compute_kd_loss(student_logits, teacher_logits, temperature) total_loss = alpha * ce_loss + (1.0 - alpha) * kd_loss return total_loss, ce_loss, kd_loss
下面逐个拆解。在动手写之前,必须先解决一个前置问题:shift 对齐。
这是 GPT2 类自回归模型最容易让人困惑的细节。
GPT2 的语言建模约定是:用位置 t 的输入 token,去预测位置 t+1 的 token。所以模型输出的 logits[..., t, :],对应的预测目标是「输入序列中第 t+1 个 token」。
输入 x: [t0, t1, t2, t3] ← 4 个 token 位置: 0 1 2 3 logits: [L0, L1, L2, L3] ← L_i 是用 x[i] 预测下一个的分布 │ │ │ │ ▼ ▼ ▼ ▼ 预测目标: t1 t2 t3 (t4 不在序列里)
也就是说,logits[..., :-1, :](去掉最后一个)才和 labels[..., 1:](去掉第一个)一一对应。
def shift_logits_labels(logits, labels): """对齐 GPT2 的 next-token 预测。""" return logits[..., :-1, :].contiguous(), labels[..., 1:].contiguous()
如果忘了 shift,CE 和 KL 都会错位——你在用「预测 t1 的分布」去和「真实 t0」比较,损失计算完全错误,但不会报错(形状对得上),这种 bug 极其隐蔽,会让训练效果大打折扣却难以察觉。
经验法则:只要你用
transformers的 GPT2/BERT 等 LM head 模型,又自己计算损失,就必须 shift。如果用模型自带的labels参数(让模型自己算 loss),它内部已经 shift 了,就不需要再手动 shift。本项目是手动算损失,所以必须 shift。
硬标签项和普通训练完全一样——让学生预测真实的下一个 token。
import torch.nn.functional as F def compute_ce_loss(student_logits, labels): pred_logits, gold_labels = shift_logits_labels(student_logits, labels) V = pred_logits.size(-1) return F.cross_entropy( pred_logits.view(-1, V), # 展平为 [batch*(T-1), V] gold_labels.view(-1), # 展平为 [batch*(T-1)] ignore_index=-100, # 忽略 padding 位置 )
要点:
cross_entropy 要求输入是 [N, C]、目标是 [N],所以把 batch 和序列维展平。ignore_index=-100:约定 -100 是忽略标记,这些位置不计入损失(本项目数据集没用 padding,但保留这个约定以防扩展)。这是蒸馏的灵魂。先看实现,再逐行解释:
def compute_kd_loss(student_logits, teacher_logits, temperature): # 1. shift 对齐(教师与学生都截末位) s_logits, _ = shift_logits_labels( student_logits, torch.zeros(student_logits.shape[:2], device=student_logits.device) ) t_logits, _ = shift_logits_labels( teacher_logits, torch.zeros(teacher_logits.shape[:2], device=teacher_logits.device) ) T = max(float(temperature), 1e-3) # 2. 教师的软化分布 soft_teacher = F.softmax(t_logits / T, dim=-1) # 3.学生的对数软化分布 log_soft_student = F.log_softmax(s_logits / T, dim=-1) # 4. KL 散度,乘 T² 补偿梯度 kd = F.kl_div(log_soft_student, soft_teacher, reduction="batchmean") * (T * T) return kd
log_softmax 而不是 softmax注意学生用的是 log_softmax(对数概率),教师用的是普通 softmax(概率)。这不是随意的,而是 F.kl_div 的接口要求——下面解释。
F.kl_div 的方向陷阱(前置坑之二)F.kl_div 的签名是:
F.kl_div(log_input, target, reduction=...)
它内部计算的是:
Σ target · (log(target) - log_input) = KL(target ‖ input)
也就是说,它算的是 KL(target ‖ input)。所以参数顺序至关重要:
| 我们想要的 | 第一个参数 (log_input) | 第二个参数 (target) |
|---|---|---|
KL(teacher ‖ student) |
学生的 log_softmax |
教师的 softmax |
我们希望「让学生逼近教师」,即最小化 KL(teacher ‖ student),所以:
F.kl_div(log_soft_student, soft_teacher, ...) # ✓ 正确 F.kl_div(soft_teacher, log_soft_student, ...) # ✗ 类型都不对(第一个要 log)
口诀:
kl_div第一个参数传「学生的 log 概率」,第二个传「教师的概率」。这是新手最容易传反的地方,务必记住。
reduction="batchmean" 的选择F.kl_div 的 reduction 有三种:
| 取值 | 行为 | 适用 |
|---|---|---|
"sum" |
所有元素求和 | 受 batch 大小影响,不便比较 |
"mean" |
所有元素求平均(含词表维) | 会把 KL 缩小 V 倍,不推荐 |
"batchmean" |
先求和再除以 batch 内 token 数 | 推荐,得到「平均每个 token 的 KL」 |
注意 "mean" 会除以 batch × seq × vocab,把 KL 缩小了 vocab 倍(GPT2 词表 50257),数值会变得极小,损失信号被淹没。所以蒸馏场景必须用 "batchmean"。
这里的「除以 token 数」实际上由
batchmean在 batch 维度取均值近似实现。严格来说应除以有效 token 数,但在固定序列长度下两者等价。
最后乘上 T * T:
kd = F.kl_div(...) * (T * T)
原理在第 1 章已详细推导:温度软化后梯度被压缩约 1/T,乘 T² 补偿,让软标签项的优化力度与硬标签项平衡。这一行绝不能漏,否则温度越高蒸馏信号越弱,等于没蒸馏。
把两项组合,加权求和:
def distillation_loss(student_logits, teacher_logits, labels, temperature, alpha): ce_loss = compute_ce_loss(student_logits, labels) kd_loss = compute_kd_loss(student_logits, teacher_logits, temperature) total_loss = alpha * ce_loss + (1.0 - alpha) * kd_loss return total_loss, ce_loss, kd_loss
返回三元组 (total, ce, kd) 是一个重要的工程设计:训练循环里可以同时记录三项,观察它们的相对变化,从而判断蒸馏是否健康(详见第 7 章)。
除了损失,我们还提供一个独立的「分布 KL」函数,用于评估阶段衡量学生学得有多像教师:
def distribution_kl(logits_a, logits_b, temperature=1.0): """计算两组 logits 的 softmax 分布的平均 KL(KL(a‖b))。""" T = max(float(temperature), 1e-3) soft_a = F.softmax(logits_a / T, dim=-1) log_soft_b = F.log_softmax(logits_b / T, dim=-1) kl = F.kl_div(log_soft_b, soft_a, reduction="batchmean") return kl.item()
它和 compute_kd_loss 的区别:
| 函数 | 用途 | 是否参与反向传播 | 温度 |
|---|---|---|---|
compute_kd_loss |
训练损失 | 是(返回 tensor) | 蒸馏温度 T |
distribution_kl |
评估指标 | 否(返回 float) | 默认 1.0(原始分布) |
评估时通常用 T=1(原始分布)来衡量,反映学生「真实输出」和教师差多远。这个函数在《第 8 章 评估》会用到。
不需要真实模型,用随机数据就能验证损失函数的正确性:
import torch from losses import distillation_loss, distribution_kl B, T_len, V = 2, 5, 1000 # batch=2, 序列长=5, 词表=1000 student_logits = torch.randn(B, T_len, V) teacher_logits = torch.randn(B, T_len, V) labels = torch.randint(0, V, (B, T_len)) total, ce, kd = distillation_loss( student_logits, teacher_logits, labels, temperature=2.0, alpha=0.5 ) print(f"total={total.item():.4f}, ce={ce.item():.4f}, kd={kd.item():.4f}") print(f"distribution_kl = {distribution_kl(student_logits, teacher_logits):.4f}")
几个验证点:
kd_same = distillation_loss( teacher_logits, teacher_logits, labels, temperature=2.0, alpha=0.5 )[2] print(f"学生=教师时的 kd: {kd_same.item():.6f}") # 应接近 0
total_alpha1, _, _ = distillation_loss( student_logits, teacher_logits, labels, temperature=2.0, alpha=1.0 ) print(f"alpha=1 的 total: {total_alpha1.item():.4f}") # 应等于 ce
total_alpha0, _, _ = distillation_loss( student_logits, teacher_logits, labels, temperature=2.0, alpha=0.0 ) print(f"alpha=0 的 total: {total_alpha0.item():.4f}") # 应等于 kd
这三个检查能帮你确认实现无误。本项目的代码已通过这些验证。
| 坑 | 表现 | 正确做法 |
|---|---|---|
| 忘 shift | CE、KL 全部错位,训练效果差但不报错 | logits 截末位、labels 截首位 |
| KL 方向传反 | 蒸馏方向错误,学生学歪 | kl_div(学生log_softmax, 教师softmax) |
| 漏乘 T² | 温度越高信号越弱 | KL 结果乘 T * T |
这三个坑是蒸馏实现的"三大陷阱",掉进任何一个都会让效果大打折扣却难以察觉。本章把它们都讲透了,并在代码中明确标注。
F.kl_div(学生 log_softmax, 教师 softmax),算的是 KL(教师‖学生)。"batchmean",避免被词表维度稀释。T*T,补偿温度对梯度的压缩。动手实验:用随机数据跑一遍损失函数,并做 6.7 节的三个健康性检查。确认「学生=教师时 KL≈0」「α=1 退化为 CE」「α=0 退化为 KL·T²」。
下一站:损失函数就位,所有零件齐了。在《第 7 章 训练主循环》中,我们把数据、模型、损失、优化器组装成一个完整的训练循环,并实现损失分解日志与断点续训。