第 6 章 推理与采样(inference.py) 本章目标:搞懂自回归生成 + temperature + top-k 采样的实现细节。 6.1 章节地图 inference.py 干三件事: 6.2 完整源码 6.3 GPTGenerator:推理封装类 两个关键设置 禁用自动求导。推理时不需要反向传播,禁用后: 不存中间激活,显存占用大幅降低(可省 5-10 倍)。 速度提升(少了构建计算图的开销)。 关闭 dropout,保证推理结果确定性(除了采样本身)。 💡 这两个加在一起是推理的标配。忘任何一个都有问题:忘 nograd 显存爆;忘 eval 推理结果每次不同。
本章目标:搞懂自回归生成 + temperature + top-k 采样的实现细节。
inference.py 干三件事:
┌─────────────────────────────────────────────────────────┐ │ 1. GPTGenerator 类 封装模型 + 编码器 + generate() │ │ 2. generate() 实现 自回归循环 + 三种采样策略 │ │ 3. load_for_inference 智能加载 checkpoint │ └─────────────────────────────────────────────────────────┘
""" inference.py ============ 模型推理模块。 主要功能: 1. load_for_inference(): 加载训练好的模型与配置,返回可用于推理的对象。 2. generate(): 实现带 temperature 与 top_k 采样的自回归文本生成。 3. main(): 命令行交互入口,便于在终端快速测试。 """ import argparse from typing import Optional, Tuple import torch import torch.nn.functional as F from transformers import GPT2LMHeadModel from config import GPTConfig from dataset import get_encoder from model import build_model class GPTGenerator: """封装模型与编码器,提供统一的生成接口。""" def __init__( self, model: GPT2LMHeadModel, gpt_config: GPTConfig, device: Optional[torch.device] = None, ): self.model = model self.config = gpt_config self.encoder = get_encoder() self.device = device or torch.device( "cuda" if torch.cuda.is_available() else "cpu" ) self.model.to(self.device) self.model.eval() @torch.no_grad() def generate( self, prompt: str, max_new_tokens: int = 100, temperature: float = 0.8, top_k: int = 40, ) -> str: """ 自回归生成文本。 参数: prompt: 输入提示文本。 max_new_tokens: 最多生成的新 token 数量。 temperature: 温度系数。越高越随机,越低越确定;<=0 视为贪心。 top_k: 只在概率最高的 top_k 个 token 中采样;<=0 表示不限制。 返回: 完整生成文本(包含 prompt)。 """ if not prompt: # 无 prompt 时使用 <|endoftext|> 作为起始 token。 input_ids = torch.tensor( [[self.config.vocab_size - 1]], dtype=torch.long, device=self.device ) else: tokens = self.encoder.encode(prompt) input_ids = torch.tensor([tokens], dtype=torch.long, device=self.device) block_size = self.config.block_size for _ in range(max_new_tokens): # 截取最后 block_size 个 token 作为上下文,避免超过模型最大长度。 cond = input_ids if input_ids.size(1) <= block_size else input_ids[:, -block_size:] logits = self.model(input_ids=cond).logits # (1, T, vocab) next_logits = logits[:, -1, :] # 仅取最后一个位置 (1, vocab) # ---------- temperature ---------- if temperature <= 0: # 贪心解码:直接取 argmax。 next_token = torch.argmax(next_logits, dim=-1, keepdim=True) else: next_logits = next_logits / temperature # ---------- top_k ---------- if top_k > 0: top_k = min(top_k, next_logits.size(-1)) # 取概率最高的 top_k 个 logits,其余设为 -inf 以屏蔽。 top_values, top_indices = torch.topk(next_logits, top_k, dim=-1) mask = torch.full_like(next_logits, float("-inf")) mask.scatter_(1, top_indices, top_values) next_logits = mask probs = F.softmax(next_logits, dim=-1) next_token = torch.multinomial(probs, num_samples=1) # 拼接到序列末尾,继续下一轮预测。 input_ids = torch.cat([input_ids, next_token], dim=1) # 遇到结束符提前终止。 if next_token.item() == self.config.vocab_size - 1 and not prompt: break # 解码为文本。 output_ids = input_ids[0].tolist() return self.encoder.decode(output_ids) # -------------------------------------------------------------------------- # 加载入口 # -------------------------------------------------------------------------- def load_for_inference( checkpoint_path: str, device: Optional[str] = None, gpt_config: Optional[GPTConfig] = None, ) -> GPTGenerator: """ 加载 checkpoint 并返回 GPTGenerator。 参数: checkpoint_path: 权重文件路径或 transformers 目录。 device: 推理设备,None 表示自动选择。 gpt_config: 模型配置;若 checkpoint 中包含则会优先使用其配置。 """ import os target_device = torch.device( device or ("cuda" if torch.cuda.is_available() else "cpu") ) if os.path.isdir(checkpoint_path): # transformers 风格目录。 model = GPT2LMHeadModel.from_pretrained(checkpoint_path) if gpt_config is None: gpt_config = GPTConfig() else: ckpt = torch.load(checkpoint_path, map_location=target_device) # 优先使用 checkpoint 中保存的配置。 if "gpt_config" in ckpt: cfg = GPTConfig() cfg.__dict__.update(ckpt["gpt_config"]) gpt_config = cfg elif gpt_config is None: gpt_config = GPTConfig() model = build_model(gpt_config) state = ckpt.get("model_state_dict", ckpt) model.load_state_dict(state, strict=False) return GPTGenerator(model=model, gpt_config=gpt_config, device=target_device) # -------------------------------------------------------------------------- # 命令行入口 # -------------------------------------------------------------------------- def parse_args() -> argparse.Namespace: p = argparse.ArgumentParser(description="GPT 推理脚本") p.add_argument( "--checkpoint", type=str, default="checkpoints/gpt_final.pt", help="checkpoint 文件路径或 transformers 模型目录", ) p.add_argument("--prompt", type=str, default="", help="输入提示文本") p.add_argument("--max-tokens", type=int, default=100, help="最多生成的新 token 数") p.add_argument("--temperature", type=float, default=0.8, help="采样温度") p.add_argument("--top-k", type=int, default=40, help="top-k 采样的 k 值") p.add_argument("--device", type=str, default=None, help="推理设备") return p.parse_args() def main(): args = parse_args() generator = load_for_inference(args.checkpoint, device=args.device) print(f"[inference] 已加载: {args.checkpoint}") print(f"[inference] prompt: {args.prompt!r}") text = generator.generate( prompt=args.prompt, max_new_tokens=args.max_tokens, temperature=args.temperature, top_k=args.top_k, ) print("\n===== 生成结果 =====") print(text) if __name__ == "__main__": main()
class GPTGenerator: def __init__(self, model, gpt_config, device=None): self.model = model self.config = gpt_config self.encoder = get_encoder() self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu") self.model.to(self.device) self.model.eval() # ★ 推理模式:关 dropout
@torch.no_grad()@torch.no_grad() def generate(self, ...):
禁用自动求导。推理时不需要反向传播,禁用后:
model.eval()self.model.eval()
关闭 dropout,保证推理结果确定性(除了采样本身)。
💡 这两个加在一起是推理的标配。忘任何一个都有问题:忘 no_grad 显存爆;忘 eval 推理结果每次不同。
不直接写函数,而是包一个 GPTGenerator 类,好处:
app.py)能持有一个 generator 实例反复调用。generate(prompt, ...),不关心内部细节。# 编码 prompt if not prompt: input_ids = torch.tensor([[self.config.vocab_size - 1]], ...) # <|endoftext|> else: tokens = self.encoder.encode(prompt) input_ids = torch.tensor([tokens], ...) for _ in range(max_new_tokens): # ① 截断到 block_size cond = input_ids if input_ids.size(1) <= block_size else input_ids[:, -block_size:] # ② 前向,取最后一位的 logits logits = self.model(input_ids=cond).logits next_logits = logits[:, -1, :] # ③ 采样(见 6.5) next_token = ... # ④ 拼接 input_ids = torch.cat([input_ids, next_token], dim=1) # ⑤ 遇到结束符提前终止 if next_token.item() == self.config.vocab_size - 1 and not prompt: break return self.encoder.decode(input_ids[0].tolist())
cond = input_ids if input_ids.size(1) <= block_size else input_ids[:, -block_size:]
模型最大上下文是 block_size。生成过程中序列越来越长,一旦超过 block_size,必须截掉前面的,只保留最后 block_size 个 token。
input_ids[:, -block_size:] 是切片,负索引表示从末尾往前数。
⚠️ 这种朴素实现的性能问题:每生成一个 token 都要把整段
cond(最长 block_size)重新前向一遍,复杂度 O(block_size²)。生产中用 KV Cache(缓存每层 K/V 避免重算),HF 的model.generate()默认就用。本项目手写循环是为了教学清晰。
logits = self.model(input_ids=cond).logits # shape (1, T, vocab) next_logits = logits[:, -1, :] # shape (1, vocab)
自回归生成时,我们只关心**「给定当前全部 token,下一个 token 是什么」**。位置 -1(最后一个 token)的 logits 就是这个预测。
前面位置的 logits 在生成场景下是「废物」(它们预测的是「已经知道的下一个」),丢弃即可。
见 6.5 节,三种策略。
input_ids = torch.cat([input_ids, next_token], dim=1)
把新 token 拼到序列末尾,下一轮循环它就成了已知上下文。
if next_token.item() == self.config.vocab_size - 1 and not prompt: break
只有「无 prompt 的自由生成」(即从 <|endoftext|> 起头)才会在生成出 <|endoftext|> 时停止。
为什么给了 prompt 时不提前停?因为模型可能在中间误输出 <|endoftext|>(实际生成经常出现),提前停会让结果很短。
generate() 根据参数走不同分支:
temperature <= 0)if temperature <= 0: next_token = torch.argmax(next_logits, dim=-1, keepdim=True)
每步取概率最大的 token。确定性强(同输入永远同输出),但容易陷入重复循环:
prompt: "To be, or" 输出: "To be, or not to be not to be not to be not to be..."
temperature > 0)next_logits = next_logits / temperature
温度的数学本质:softmax(logits/T)。
直觉:温度「平滑」或「锐化」概率分布。低温保守确定,高温奔放随机。
logits = [2.0, 1.0, 0.5] T=0.1: softmax([20, 10, 5]) → [0.9999, 0.0001, 0.0000] ← 几乎必选第一个 T=0.5: softmax([4, 2, 1]) → [0.86, 0.12, 0.04] T=1.0: softmax([2, 1, 0.5]) → [0.59, 0.22, 0.13] ← 偶尔选其他 T=2.0: softmax([1, 0.5, 0.25]) → [0.39, 0.28, 0.21] ← 经常选其他 T=10: softmax([0.2, 0.1, 0.05]) → [0.30, 0.27, 0.25] ← 接近均匀
top_k > 0)if top_k > 0: top_k = min(top_k, next_logits.size(-1)) top_values, top_indices = torch.topk(next_logits, top_k, dim=-1) mask = torch.full_like(next_logits, float("-inf")) mask.scatter_(1, top_indices, top_values) next_logits = mask
只保留概率最高的 K 个 token,其余 logits 设为 -inf(softmax 后概率为 0),再从这 K 个里按概率采样。
纯温度采样的痛点:长尾。即使温度合适,词表里仍有大量「明显不对」的 token(如罕见的拼写错误),它们概率很小但偶尔被抽中会让生成崩坏。Top-K 把这些「无希望」的候选直接踢掉。
词表 50257 个 token 的概率分布(示意) 排序后:[0.30, 0.20, 0.15, 0.10, 0.05, ... 一堆 0.001 ...] top_k=40:只在前 40 个里采样,其余强制为 0 ▲ │ 这些是「合理候选」 ────────────┼────────────────────── │ │ 这些是「噪声」,丢弃
top_k = min(top_k, next_logits.size(-1))
钳制:万一用户传 top_k=99999,但词表只有 50257,取 min 防越界。
top_values, top_indices = torch.topk(next_logits, top_k, dim=-1)
torch.topk 返回两个张量:
top_values:最大的 K 个值本身top_indices:它们在原 tensor 里的位置mask = torch.full_like(next_logits, float("-inf")) mask.scatter_(1, top_indices, top_values)
构造一个全 -inf 的 mask,用 scatter_ 把 top K 个位置填回原值。效果:top K 位置的 logits 不变,其他位置变 -inf。
probs = F.softmax(next_logits, dim=-1)
softmax:-inf 位置的 probability 是 0(因为 exp(-inf) = 0),top K 位置正常分布。
torch.multinomial 采样next_token = torch.multinomial(probs, num_samples=1)
multinomial 按 probs 给的概率分布有放回抽样 1 次。这正是「按概率采一个 token」的标准实现。
| 组合 | 行为 |
|---|---|
temperature=0 |
贪心,确定但易重复 |
temperature=1, top_k=0 |
纯随机采样,混乱 |
temperature=0.8, top_k=40(默认) |
推荐:有变化但不离谱 |
temperature=0.3, top_k=10 |
保守、聚焦、更连贯 |
temperature=1.2, top_k=80 |
多样、发散、易跑题 |
调参经验:
def load_for_inference(checkpoint_path, device=None, gpt_config=None): import os target_device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu")) if os.path.isdir(checkpoint_path): # 格式 A:transformers 风格目录 model = GPT2LMHeadModel.from_pretrained(checkpoint_path) if gpt_config is None: gpt_config = GPTConfig() else: # 格式 B:本项目 .pt 文件 ckpt = torch.load(checkpoint_path, map_location=target_device) # 优先用 checkpoint 里保存的配置(重要!) if "gpt_config" in ckpt: cfg = GPTConfig() cfg.__dict__.update(ckpt["gpt_config"]) gpt_config = cfg elif gpt_config is None: gpt_config = GPTConfig() model = build_model(gpt_config) state = ckpt.get("model_state_dict", ckpt) model.load_state_dict(state, strict=False) return GPTGenerator(model=model, gpt_config=gpt_config, device=target_device)
最关键的几行:
if "gpt_config" in ckpt: cfg = GPTConfig() cfg.__dict__.update(ckpt["gpt_config"]) gpt_config = cfg
亮点:从 checkpoint 自动恢复配置。训练时你改了 n_layer=12,checkpoint 里就存了 n_layer=12;推理时不需要再传 --n-layer 12,加载器自己读出来。
这是 train.py 把配置存进 checkpoint 的回报(第 5 章 save_checkpoint 存了 gpt_config.__dict__)。
与 model.py:load_model 类似(第 4 章):
| 格式 | 加载方式 |
|---|---|
| 目录 | from_pretrained(dir) |
单文件 .pt |
读 config → build_model → load_state_dict |
def main(): args = parse_args() generator = load_for_inference(args.checkpoint, device=args.device) print(f"[inference] 已加载: {args.checkpoint}") print(f"[inference] prompt: {args.prompt!r}") text = generator.generate( prompt=args.prompt, max_new_tokens=args.max_tokens, temperature=args.temperature, top_k=args.top_k, ) print("\n===== 生成结果 =====") print(text)
# 默认配置 python inference.py --prompt "To be, or" # 贪心(确定) python inference.py --prompt "ROMEO:" --temperature 0 # 多样化 python inference.py --prompt "ROMEO:" --temperature 1.0 --top-k 50 --max-tokens 200 # 指定 checkpoint python inference.py --checkpoint checkpoints/gpt_step2000.pt --prompt "To be, or"
| 参数 | 类型 | 默认值 | 说明 |
|---|---|---|---|
--checkpoint |
str | checkpoints/gpt_final.pt |
权重路径或 transformers 目录 |
--prompt |
str | "" |
输入提示文本 |
--max-tokens |
int | 100 | 最多生成的新 token 数 |
--temperature |
float | 0.8 | 采样温度 |
--top-k |
int | 40 | top-k 采样的 k 值 |
--device |
str | - | 推理设备 |
不通过命令行,直接 import:
from inference import load_for_inference gen = load_for_inference("checkpoints/gpt_final.pt") # 多次生成 for _ in range(3): text = gen.generate( prompt="To be, or", max_new_tokens=100, temperature=0.8, top_k=40, ) print(text) print("---")
这是 app.py 的用法(第 7 章)。
确定性 vs 随机:同一个 prompt 跑 5 次:
--temperature 0:观察输出是否完全一致(贪心确定)。--temperature 1.0 --top-k 0:观察输出是否几乎每次不同且混乱。温度扫描:固定 prompt,跑这些温度,对比效果:
python inference.py --prompt "ROMEO:" --temperature 0.2 python inference.py --prompt "ROMEO:" --temperature 0.5 python inference.py --prompt "ROMEO:" --temperature 0.8 python inference.py --prompt "ROMEO:" --temperature 1.2
长 prompt 测试:给一个超过 block_size 的长 prompt,观察生成时是否报错(应该不会,因为有截断)。
思考题:如何实现 top-p(nucleus)采样?
进阶:用 HF 的 model.generate() 对比本项目的手写循环:
model = ... # 加载好的 GPT2LMHeadModel out = model.generate( input_ids=..., max_new_tokens=100, do_sample=True, temperature=0.8, top_k=40, use_cache=True, # KV cache,比本项目快很多 )
GPTGenerator 封装模型+编码器+设备,提供 generate() 接口。@torch.no_grad() + model.eval() 是推理标配(省显存 + 关 dropout)。temperature<=0):确定,易重复temperature>0):softmax(logits/T),控制随机性temperature=0.8, top_k=40 是安全选择。load_for_inference 智能从 checkpoint 恢复配置,无需手动传架构。推理逻辑通了,去 第 7 章 Web UI 把它包成浏览器可访问的界面。