第 6 章 蒸馏损失函数(数学核心)


文档摘要

第 6 章 蒸馏损失函数(数学核心) 这是全教程代码的灵魂。本章逐行实现蒸馏损失,并彻底讲透三个最容易踩坑的细节:shift 对齐、KL 散度的方向、T² 缩放。如果你只读一章代码章节,就读这一章。 6.1 损失函数的全貌 回顾第 1 章的蒸馏总损失公式: 我们要实现的函数签名如下,返回三元组(总损失、CE、KL),方便训练时分别监控: 下面逐个拆解。在动手写之前,必须先解决一个前置问题:shift 对齐。 6.2 前置坑之一:shift 对齐 这是 GPT2 类自回归模型最容易让人困惑的细节。 问题描述 GPT2 的语言建模约定是:用位置 t 的输入 token,去预测位置 t+1 的 token。所以模型输出的 ,对应的预测目标是「输入序列中第 t+1 个 token」。

第 6 章 蒸馏损失函数(数学核心)

这是全教程代码的灵魂。本章逐行实现蒸馏损失,并彻底讲透三个最容易踩坑的细节:shift 对齐、KL 散度的方向、T² 缩放。如果你只读一章代码章节,就读这一章。

6.1 损失函数的全貌

回顾第 1 章的蒸馏总损失公式:

L_total = α · CE(student, 硬标签) + (1 - α) · T² · KL(teacher_soft ‖ student_soft)

我们要实现的函数签名如下,返回三元组(总损失、CE、KL),方便训练时分别监控:

def distillation_loss(student_logits, teacher_logits, labels, temperature, alpha): ce_loss = compute_ce_loss(student_logits, labels) kd_loss = compute_kd_loss(student_logits, teacher_logits, temperature) total_loss = alpha * ce_loss + (1.0 - alpha) * kd_loss return total_loss, ce_loss, kd_loss

下面逐个拆解。在动手写之前,必须先解决一个前置问题:shift 对齐

6.2 前置坑之一:shift 对齐

这是 GPT2 类自回归模型最容易让人困惑的细节。

问题描述

GPT2 的语言建模约定是:用位置 t 的输入 token,去预测位置 t+1 的 token。所以模型输出的 logits[..., t, :],对应的预测目标是「输入序列中第 t+1 个 token」。

输入 x: [t0, t1, t2, t3] ← 4 个 token 位置: 0 1 2 3 logits: [L0, L1, L2, L3] ← L_i 是用 x[i] 预测下一个的分布 │ │ │ │ ▼ ▼ ▼ ▼ 预测目标: t1 t2 t3 (t4 不在序列里)

也就是说,logits[..., :-1, :](去掉最后一个)才和 labels[..., 1:](去掉第一个)一一对应。

实现

def shift_logits_labels(logits, labels): """对齐 GPT2 的 next-token 预测。""" return logits[..., :-1, :].contiguous(), labels[..., 1:].contiguous()
  • logits 截掉最后一位(因为最后一个位置没有对应的「下一个 token」标签)。
  • labels 截掉第一位(因为第一位是序列开头,没有 logits 去预测它)。

为什么重要

如果忘了 shift,CE 和 KL 都会错位——你在用「预测 t1 的分布」去和「真实 t0」比较,损失计算完全错误,但不会报错(形状对得上),这种 bug 极其隐蔽,会让训练效果大打折扣却难以察觉。

经验法则:只要你用 transformers 的 GPT2/BERT 等 LM head 模型,又自己计算损失,就必须 shift。如果用模型自带的 labels 参数(让模型自己算 loss),它内部已经 shift 了,就不需要再手动 shift。本项目是手动算损失,所以必须 shift。

6.3 硬标签项:交叉熵 CE

硬标签项和普通训练完全一样——让学生预测真实的下一个 token。

import torch.nn.functional as F def compute_ce_loss(student_logits, labels): pred_logits, gold_labels = shift_logits_labels(student_logits, labels) V = pred_logits.size(-1) return F.cross_entropy( pred_logits.view(-1, V), # 展平为 [batch*(T-1), V] gold_labels.view(-1), # 展平为 [batch*(T-1)] ignore_index=-100, # 忽略 padding 位置 )

要点:

  • 展平cross_entropy 要求输入是 [N, C]、目标是 [N],所以把 batch 和序列维展平。
  • ignore_index=-100:约定 -100 是忽略标记,这些位置不计入损失(本项目数据集没用 padding,但保留这个约定以防扩展)。

6.4 软标签项:KL 散度(核心)

这是蒸馏的灵魂。先看实现,再逐行解释:

def compute_kd_loss(student_logits, teacher_logits, temperature): # 1. shift 对齐(教师与学生都截末位) s_logits, _ = shift_logits_labels( student_logits, torch.zeros(student_logits.shape[:2], device=student_logits.device) ) t_logits, _ = shift_logits_labels( teacher_logits, torch.zeros(teacher_logits.shape[:2], device=teacher_logits.device) ) T = max(float(temperature), 1e-3) # 2. 教师的软化分布 soft_teacher = F.softmax(t_logits / T, dim=-1) # 3.学生的对数软化分布 log_soft_student = F.log_softmax(s_logits / T, dim=-1) # 4. KL 散度,乘 T² 补偿梯度 kd = F.kl_div(log_soft_student, soft_teacher, reduction="batchmean") * (T * T) return kd

6.4.1 为什么要用 log_softmax 而不是 softmax

注意学生用的是 log_softmax(对数概率),教师用的是普通 softmax(概率)。这不是随意的,而是 F.kl_div 的接口要求——下面解释。

6.4.2 F.kl_div 的方向陷阱(前置坑之二)

F.kl_div 的签名是:

F.kl_div(log_input, target, reduction=...)

它内部计算的是:

Σ target · (log(target) - log_input) = KL(target ‖ input)

也就是说,它算的是 KL(target ‖ input)。所以参数顺序至关重要:

我们想要的 第一个参数 (log_input) 第二个参数 (target)
KL(teacher ‖ student) 学生的 log_softmax 教师的 softmax

我们希望「让学生逼近教师」,即最小化 KL(teacher ‖ student),所以:

F.kl_div(log_soft_student, soft_teacher, ...) # ✓ 正确 F.kl_div(soft_teacher, log_soft_student, ...) # ✗ 类型都不对(第一个要 log)

口诀kl_div 第一个参数传「学生的 log 概率」,第二个传「教师的概率」。这是新手最容易传反的地方,务必记住。

6.4.3 reduction="batchmean" 的选择

F.kl_divreduction 有三种:

取值 行为 适用
"sum" 所有元素求和 受 batch 大小影响,不便比较
"mean" 所有元素求平均(含词表维) 会把 KL 缩小 V 倍,不推荐
"batchmean" 先求和再除以 batch 内 token 数 推荐,得到「平均每个 token 的 KL」

注意 "mean" 会除以 batch × seq × vocab,把 KL 缩小了 vocab 倍(GPT2 词表 50257),数值会变得极小,损失信号被淹没。所以蒸馏场景必须用 "batchmean"

这里的「除以 token 数」实际上由 batchmean 在 batch 维度取均值近似实现。严格来说应除以有效 token 数,但在固定序列长度下两者等价。

6.4.4 T² 缩放(前置坑之三)

最后乘上 T * T

kd = F.kl_div(...) * (T * T)

原理在第 1 章已详细推导:温度软化后梯度被压缩约 1/T,乘 T² 补偿,让软标签项的优化力度与硬标签项平衡。这一行绝不能漏,否则温度越高蒸馏信号越弱,等于没蒸馏。

6.5 组合成总损失

把两项组合,加权求和:

def distillation_loss(student_logits, teacher_logits, labels, temperature, alpha): ce_loss = compute_ce_loss(student_logits, labels) kd_loss = compute_kd_loss(student_logits, teacher_logits, temperature) total_loss = alpha * ce_loss + (1.0 - alpha) * kd_loss return total_loss, ce_loss, kd_loss

返回三元组 (total, ce, kd) 是一个重要的工程设计:训练循环里可以同时记录三项,观察它们的相对变化,从而判断蒸馏是否健康(详见第 7 章)。

6.6 辅助函数:分布 KL(用于评估)

除了损失,我们还提供一个独立的「分布 KL」函数,用于评估阶段衡量学生学得有多像教师:

def distribution_kl(logits_a, logits_b, temperature=1.0): """计算两组 logits 的 softmax 分布的平均 KL(KL(a‖b))。""" T = max(float(temperature), 1e-3) soft_a = F.softmax(logits_a / T, dim=-1) log_soft_b = F.log_softmax(logits_b / T, dim=-1) kl = F.kl_div(log_soft_b, soft_a, reduction="batchmean") return kl.item()

它和 compute_kd_loss 的区别:

函数 用途 是否参与反向传播 温度
compute_kd_loss 训练损失 是(返回 tensor) 蒸馏温度 T
distribution_kl 评估指标 否(返回 float) 默认 1.0(原始分布)

评估时通常用 T=1(原始分布)来衡量,反映学生「真实输出」和教师差多远。这个函数在《第 8 章 评估》会用到。

6.7 数值验证:跑通损失函数

不需要真实模型,用随机数据就能验证损失函数的正确性:

import torch from losses import distillation_loss, distribution_kl B, T_len, V = 2, 5, 1000 # batch=2, 序列长=5, 词表=1000 student_logits = torch.randn(B, T_len, V) teacher_logits = torch.randn(B, T_len, V) labels = torch.randint(0, V, (B, T_len)) total, ce, kd = distillation_loss( student_logits, teacher_logits, labels, temperature=2.0, alpha=0.5 ) print(f"total={total.item():.4f}, ce={ce.item():.4f}, kd={kd.item():.4f}") print(f"distribution_kl = {distribution_kl(student_logits, teacher_logits):.4f}")

健康性检查

几个验证点:

  1. 学生 = 教师时,KL 应接近 0
kd_same = distillation_loss( teacher_logits, teacher_logits, labels, temperature=2.0, alpha=0.5 )[2] print(f"学生=教师时的 kd: {kd_same.item():.6f}") # 应接近 0
  1. α=1 时退化为纯 CE
total_alpha1, _, _ = distillation_loss( student_logits, teacher_logits, labels, temperature=2.0, alpha=1.0 ) print(f"alpha=1 的 total: {total_alpha1.item():.4f}") # 应等于 ce
  1. α=0 时退化为纯 KL·T²
total_alpha0, _, _ = distillation_loss( student_logits, teacher_logits, labels, temperature=2.0, alpha=0.0 ) print(f"alpha=0 的 total: {total_alpha0.item():.4f}") # 应等于 kd

这三个检查能帮你确认实现无误。本项目的代码已通过这些验证。

6.8 三大坑总结

表现 正确做法
忘 shift CE、KL 全部错位,训练效果差但不报错 logits 截末位、labels 截首位
KL 方向传反 蒸馏方向错误,学生学歪 kl_div(学生log_softmax, 教师softmax)
漏乘 T² 温度越高信号越弱 KL 结果乘 T * T

这三个坑是蒸馏实现的"三大陷阱",掉进任何一个都会让效果大打折扣却难以察觉。本章把它们都讲透了,并在代码中明确标注。

本章小结

  • 蒸馏损失 = α·CE(硬标签)+ (1-α)·T²·KL(软标签),返回三元组便于监控。
  • shift 对齐:GPT2 的 logits 截末位、labels 截首位,保证 next-token 对齐。
  • KL 方向F.kl_div(学生 log_softmax, 教师 softmax),算的是 KL(教师‖学生)。
  • reduction:用 "batchmean",避免被词表维度稀释。
  • T² 缩放:KL 结果乘 T*T,补偿温度对梯度的压缩。
  • 损失函数可用随机数据独立验证,三个健康性检查确认实现无误。

动手实验:用随机数据跑一遍损失函数,并做 6.7 节的三个健康性检查。确认「学生=教师时 KL≈0」「α=1 退化为 CE」「α=0 退化为 KL·T²」。

下一站:损失函数就位,所有零件齐了。在《第 7 章 训练主循环》中,我们把数据、模型、损失、优化器组装成一个完整的训练循环,并实现损失分解日志与断点续训。


发布者: 作者: 转发
评论区 (0)
U