第 9 章 推理与采样生成 学生训练好了,是时候让它"开口说话"了。本章实现自回归文本生成,讲解温度采样与 top-k 截断如何控制生成质量,并提供交互式命令行。 9.1 自回归生成的本质 语言模型的生成是自回归的:每次预测一个 token,把它拼回输入,再预测下一个,循环往复。 核心循环只有三步:前向 → 取最后一个位置的分布 → 采样/贪心选一个 token → 拼回。 9.2 生成器的封装 把生成逻辑封装成一个类,对外暴露简洁接口: 注意它从 checkpoint 自动恢复配置(和评估模块一样的逻辑),所以调用者不需要手动传学生架构。 9.3 核心生成循环 逐步拆解几个关键设计。 9.3.
学生训练好了,是时候让它"开口说话"了。本章实现自回归文本生成,讲解温度采样与 top-k 截断如何控制生成质量,并提供交互式命令行。
语言模型的生成是自回归的:每次预测一个 token,把它拼回输入,再预测下一个,循环往复。
核心循环只有三步:前向 → 取最后一个位置的分布 → 采样/贪心选一个 token → 拼回。
把生成逻辑封装成一个类,对外暴露简洁接口:
import torch import torch.nn.functional as F class StudentGenerator: def __init__(self, checkpoint_path, cfg=None, device=None): # 加载学生模型 + 编码器 self.cfg = cfg or self._load_cfg_from_ckpt(checkpoint_path) self.encoder = get_encoder() self.device = device or torch.device( "cuda" if torch.cuda.is_available() else "cpu" ) self.block_size = self.cfg.block_size self.model = load_student(checkpoint_path, self.cfg, str(self.device))
注意它从 checkpoint 自动恢复配置(和评估模块一样的逻辑),所以调用者不需要手动传学生架构。
@torch.no_grad() def generate(self, prompt, max_new_tokens=100, temperature=0.8, top_k=40): self.model.eval() ids = self.encoder.encode(prompt) x = torch.tensor([ids], dtype=torch.long, device=self.device) for _ in range(max_new_tokens): # ① 只保留最近 block_size 个 token,防超长 x_cond = x if x.size(1) <= self.block_size else x[:, -self.block_size:] logits = self.model(input_ids=x_cond).logits logits = logits[:, -1, :] / max(temperature, 1e-3) # 取最后一位 + 温度 # ② top-k 截断 if top_k is not None and top_k > 0: v, _ = torch.topk(logits, min(top_k, logits.size(-1))) logits[logits < v[:, [-1]]] = -float("inf") # ③ 采样 or 贪心 if temperature is not None and temperature > 0: probs = F.softmax(logits, dim=-1) next_id = torch.multinomial(probs, num_samples=1) else: next_id = logits.argmax(dim=-1, keepdim=True) # ④ 拼回输入 x = torch.cat([x, next_id], dim=1) return self.encoder.decode(x[0].tolist())
逐步拆解几个关键设计。
logits = logits[:, -1, :] / temperature
logits[:, -1, :] 是序列最后一个位置的预测分布——它代表「在当前所有 token 之后,下一个 token 应该是什么」。整个序列前向只为了得到这一个位置。
x_cond = x if x.size(1) <= self.block_size else x[:, -self.block_size:]
生成时输入会越来越长,可能超过模型的上下文长度 block_size。这里用滑动窗口:只保留最近的 block_size 个 token。虽然理论上会"忘掉"最早的内容,但对续写场景影响不大。
torch.no_grad()整个生成函数用 @torch.no_grad() 装饰,不构建计算图。生成时只前向不反传,省显存也更快。
生成时的 temperature(注意:这和蒸馏温度是不同概念!)控制输出的随机程度:
logits = logits / temperature
| 温度 | 效果 | 适用 |
|---|---|---|
temperature → 0 |
退化为贪心(总取 argmax) | 确定、重复、保守 |
temperature = 1.0 |
原始分布采样 | 平衡 |
temperature > 1.0 |
分布变平,更随机 | 多样、有创造性,但可能乱 |
概念辨析:蒸馏温度(训练时软化教师)和生成温度(推理时控制采样)虽然数学形式一样,但目的完全不同。蒸馏温度是为了让软标签更平坦以暴露暗知识;生成温度是为了控制生成文本的多样性。别混淆。
纯按概率采样有个问题:长尾词(概率极低的词)偶尔会被抽到,导致生成出不连贯的内容。Top-k 截断只保留概率最高的 k 个词,其余直接置为 -inf(概率归零):
if top_k > 0: v, _ = torch.topk(logits, min(top_k, logits.size(-1))) logits[logits < v[:, [-1]]] = -float("inf")
逻辑:
torch.topk 找出最高 k 个值中的最小值 v[:, [-1]](即第 k 大的值)。-inf,softmax 后这些位置概率为 0。| top_k | 效果 |
|---|---|
top_k = 1 |
等同贪心 |
top_k = 40(默认) |
平衡多样性与连贯性 |
top_k = vocab_size |
不截断,纯按原分布采样 |
if temperature > 0: # 采样 probs = F.softmax(logits, dim=-1) next_id = torch.multinomial(probs, num_samples=1) else: # 贪心 next_id = logits.argmax(dim=-1, keepdim=True)
| 策略 | 特点 | 缺点 |
|---|---|---|
| 贪心(temperature ≤ 0) | 确定性,每次结果一样 | 容易陷入重复循环 |
| 采样(temperature > 0) | 多样、有创造性 | 偶有不连贯 |
实践中,温度采样 + top-k 截断是最常用的组合,既保证多样性又过滤低质词。
把生成器包成命令行工具,支持单次生成和交互模式:
def main(): args = parse_args() gen = StudentGenerator(args.checkpoint, device=device) if args.prompt: # 单次生成 text = gen.generate( args.prompt, args.max_new_tokens, args.temperature, args.top_k ) print(text) else: # 交互模式 while True: prompt = input("提示> ").strip() if prompt.lower() in ("quit", "exit"): break if prompt: gen.generate_and_print(prompt, ...)
使用方式:
# 单次生成 python inference.py --prompt "ROMEO:" --max-new-tokens 200 # 交互模式(不传 prompt) python inference.py # 然后依次输入提示,回车生成
交互模式适合快速试玩不同提示,观察学生模型的语言能力。
蒸馏出来的学生,生成质量取决于多个因素:
| 因素 | 影响 |
|---|---|
| 学生容量(层数/维度) | 越大语言能力越强,但越慢 |
| 蒸馏温度 T | 影响学生学到的分布平滑度 |
| 训练步数 | 越多越收敛,但要注意过拟合 |
| block_size | 越大长程依赖越好,但显存越大 |
| 生成温度 + top_k | 控制多样性与连贯性的平衡 |
一个小经验:蒸馏学生虽然参数少,但因为有教师软标签的引导,生成质量通常明显优于「同等大小、纯硬标签训练」的模型。这正是蒸馏的价值所在。
即使是小模型,合适的提示也能显著改善生成:
# 用剧本格式提示,引导学生生成剧本风格 python inference.py --prompt "ROMEO:\nBut soft, what light" # 用诗句开头,引导生成诗歌 python inference.py --prompt "Shall I compare thee" # 贪心解码,得到最"稳妥"的续写 python inference.py --prompt "To be, or" --temperature 0.01
由于学生主要在莎士比亚数据上蒸馏,用莎翁风格的提示能得到最连贯的续写。
动手实验:用同一个提示,分别用 --temperature 0.01(近贪心)和 --temperature 1.2(高随机)生成,对比输出的连贯性与多样性,直观感受生成温度的作用。
下一站:基础流程全部讲完了。在《第 10 章 进阶实战》中,我们讨论教师缓存的内部原理、温度与 α 的调参实验、以及如何把蒸馏扩展到中间层对齐等更高级的形式。