第 9 章 推理与采样生成


文档摘要

第 9 章 推理与采样生成 学生训练好了,是时候让它"开口说话"了。本章实现自回归文本生成,讲解温度采样与 top-k 截断如何控制生成质量,并提供交互式命令行。 9.1 自回归生成的本质 语言模型的生成是自回归的:每次预测一个 token,把它拼回输入,再预测下一个,循环往复。 核心循环只有三步:前向 → 取最后一个位置的分布 → 采样/贪心选一个 token → 拼回。 9.2 生成器的封装 把生成逻辑封装成一个类,对外暴露简洁接口: 注意它从 checkpoint 自动恢复配置(和评估模块一样的逻辑),所以调用者不需要手动传学生架构。 9.3 核心生成循环 逐步拆解几个关键设计。 9.3.

第 9 章 推理与采样生成

学生训练好了,是时候让它"开口说话"了。本章实现自回归文本生成,讲解温度采样与 top-k 截断如何控制生成质量,并提供交互式命令行。

9.1 自回归生成的本质

语言模型的生成是自回归的:每次预测一个 token,把它拼回输入,再预测下一个,循环往复。

核心循环只有三步:前向 → 取最后一个位置的分布 → 采样/贪心选一个 token → 拼回

9.2 生成器的封装

把生成逻辑封装成一个类,对外暴露简洁接口:

import torch import torch.nn.functional as F class StudentGenerator: def __init__(self, checkpoint_path, cfg=None, device=None): # 加载学生模型 + 编码器 self.cfg = cfg or self._load_cfg_from_ckpt(checkpoint_path) self.encoder = get_encoder() self.device = device or torch.device( "cuda" if torch.cuda.is_available() else "cpu" ) self.block_size = self.cfg.block_size self.model = load_student(checkpoint_path, self.cfg, str(self.device))

注意它从 checkpoint 自动恢复配置(和评估模块一样的逻辑),所以调用者不需要手动传学生架构。

9.3 核心生成循环

@torch.no_grad() def generate(self, prompt, max_new_tokens=100, temperature=0.8, top_k=40): self.model.eval() ids = self.encoder.encode(prompt) x = torch.tensor([ids], dtype=torch.long, device=self.device) for _ in range(max_new_tokens): # ① 只保留最近 block_size 个 token,防超长 x_cond = x if x.size(1) <= self.block_size else x[:, -self.block_size:] logits = self.model(input_ids=x_cond).logits logits = logits[:, -1, :] / max(temperature, 1e-3) # 取最后一位 + 温度 # ② top-k 截断 if top_k is not None and top_k > 0: v, _ = torch.topk(logits, min(top_k, logits.size(-1))) logits[logits < v[:, [-1]]] = -float("inf") # ③ 采样 or 贪心 if temperature is not None and temperature > 0: probs = F.softmax(logits, dim=-1) next_id = torch.multinomial(probs, num_samples=1) else: next_id = logits.argmax(dim=-1, keepdim=True) # ④ 拼回输入 x = torch.cat([x, next_id], dim=1) return self.encoder.decode(x[0].tolist())

逐步拆解几个关键设计。

9.3.1 取最后一个位置

logits = logits[:, -1, :] / temperature

logits[:, -1, :] 是序列最后一个位置的预测分布——它代表「在当前所有 token 之后,下一个 token 应该是什么」。整个序列前向只为了得到这一个位置。

9.3.2 滑动窗口防超长

x_cond = x if x.size(1) <= self.block_size else x[:, -self.block_size:]

生成时输入会越来越长,可能超过模型的上下文长度 block_size。这里用滑动窗口:只保留最近的 block_size 个 token。虽然理论上会"忘掉"最早的内容,但对续写场景影响不大。

9.3.3 torch.no_grad()

整个生成函数用 @torch.no_grad() 装饰,不构建计算图。生成时只前向不反传,省显存也更快。

9.4 温度采样:控制随机性

生成时的 temperature(注意:这和蒸馏温度是不同概念!)控制输出的随机程度:

logits = logits / temperature
温度 效果 适用
temperature → 0 退化为贪心(总取 argmax) 确定、重复、保守
temperature = 1.0 原始分布采样 平衡
temperature > 1.0 分布变平,更随机 多样、有创造性,但可能乱

概念辨析:蒸馏温度(训练时软化教师)和生成温度(推理时控制采样)虽然数学形式一样,但目的完全不同。蒸馏温度是为了让软标签更平坦以暴露暗知识;生成温度是为了控制生成文本的多样性。别混淆。

9.5 Top-k 截断:避免低质 token

纯按概率采样有个问题:长尾词(概率极低的词)偶尔会被抽到,导致生成出不连贯的内容。Top-k 截断只保留概率最高的 k 个词,其余直接置为 -inf(概率归零):

if top_k > 0: v, _ = torch.topk(logits, min(top_k, logits.size(-1))) logits[logits < v[:, [-1]]] = -float("inf")

逻辑:

  1. torch.topk 找出最高 k 个值中的最小值 v[:, [-1]](即第 k 大的值)。
  2. 把所有小于它的 logits 设为 -inf,softmax 后这些位置概率为 0。
  3. 之后的采样只在 top-k 范围内进行。
top_k 效果
top_k = 1 等同贪心
top_k = 40(默认) 平衡多样性与连贯性
top_k = vocab_size 不截断,纯按原分布采样

9.6 贪心 vs 采样的取舍

if temperature > 0: # 采样 probs = F.softmax(logits, dim=-1) next_id = torch.multinomial(probs, num_samples=1) else: # 贪心 next_id = logits.argmax(dim=-1, keepdim=True)
策略 特点 缺点
贪心(temperature ≤ 0) 确定性,每次结果一样 容易陷入重复循环
采样(temperature > 0) 多样、有创造性 偶有不连贯

实践中,温度采样 + top-k 截断是最常用的组合,既保证多样性又过滤低质词。

9.7 交互式命令行

把生成器包成命令行工具,支持单次生成和交互模式:

def main(): args = parse_args() gen = StudentGenerator(args.checkpoint, device=device) if args.prompt: # 单次生成 text = gen.generate( args.prompt, args.max_new_tokens, args.temperature, args.top_k ) print(text) else: # 交互模式 while True: prompt = input("提示> ").strip() if prompt.lower() in ("quit", "exit"): break if prompt: gen.generate_and_print(prompt, ...)

使用方式:

# 单次生成 python inference.py --prompt "ROMEO:" --max-new-tokens 200 # 交互模式(不传 prompt) python inference.py # 然后依次输入提示,回车生成

交互模式适合快速试玩不同提示,观察学生模型的语言能力。

9.8 生成质量的影响因素

蒸馏出来的学生,生成质量取决于多个因素:

因素 影响
学生容量(层数/维度) 越大语言能力越强,但越慢
蒸馏温度 T 影响学生学到的分布平滑度
训练步数 越多越收敛,但要注意过拟合
block_size 越大长程依赖越好,但显存越大
生成温度 + top_k 控制多样性与连贯性的平衡

一个小经验:蒸馏学生虽然参数少,但因为有教师软标签的引导,生成质量通常明显优于「同等大小、纯硬标签训练」的模型。这正是蒸馏的价值所在。

9.9 提示工程小贴士

即使是小模型,合适的提示也能显著改善生成:

# 用剧本格式提示,引导学生生成剧本风格 python inference.py --prompt "ROMEO:\nBut soft, what light" # 用诗句开头,引导生成诗歌 python inference.py --prompt "Shall I compare thee" # 贪心解码,得到最"稳妥"的续写 python inference.py --prompt "To be, or" --temperature 0.01

由于学生主要在莎士比亚数据上蒸馏,用莎翁风格的提示能得到最连贯的续写。

本章小结

  • 自回归生成 = 循环「前向 → 取末位分布 → 选 token → 拼回」。
  • 滑动窗口:只保留最近 block_size 个 token,防超长。
  • 生成温度:控制采样随机性,与蒸馏温度是不同概念。
  • Top-k 截断:过滤长尾低质词,提升连贯性。
  • 贪心 vs 采样:贪心确定但易重复,采样多样但偶尔不连贯。
  • 实践常用组合:温度采样(0.7~1.0)+ top-k(40)。

动手实验:用同一个提示,分别用 --temperature 0.01(近贪心)和 --temperature 1.2(高随机)生成,对比输出的连贯性与多样性,直观感受生成温度的作用。

下一站:基础流程全部讲完了。在《第 10 章 进阶实战》中,我们讨论教师缓存的内部原理、温度与 α 的调参实验、以及如何把蒸馏扩展到中间层对齐等更高级的形式。


发布者: 作者: 转发
评论区 (0)
U