第 6 章 推理与采样(inference.py)


文档摘要

第 6 章 推理与采样(inference.py) 本章目标:搞懂自回归生成 + temperature + top-k 采样的实现细节。 6.1 章节地图 inference.py 干三件事: 6.2 完整源码 6.3 GPTGenerator:推理封装类 两个关键设置 禁用自动求导。推理时不需要反向传播,禁用后: 不存中间激活,显存占用大幅降低(可省 5-10 倍)。 速度提升(少了构建计算图的开销)。 关闭 dropout,保证推理结果确定性(除了采样本身)。 💡 这两个加在一起是推理的标配。忘任何一个都有问题:忘 nograd 显存爆;忘 eval 推理结果每次不同。

第 6 章 推理与采样(inference.py)

本章目标:搞懂自回归生成 + temperature + top-k 采样的实现细节。

6.1 章节地图

inference.py 干三件事:

┌─────────────────────────────────────────────────────────┐ │ 1. GPTGenerator 类 封装模型 + 编码器 + generate() │ │ 2. generate() 实现 自回归循环 + 三种采样策略 │ │ 3. load_for_inference 智能加载 checkpoint │ └─────────────────────────────────────────────────────────┘

6.2 完整源码

""" inference.py ============ 模型推理模块。 主要功能: 1. load_for_inference(): 加载训练好的模型与配置,返回可用于推理的对象。 2. generate(): 实现带 temperature 与 top_k 采样的自回归文本生成。 3. main(): 命令行交互入口,便于在终端快速测试。 """ import argparse from typing import Optional, Tuple import torch import torch.nn.functional as F from transformers import GPT2LMHeadModel from config import GPTConfig from dataset import get_encoder from model import build_model class GPTGenerator: """封装模型与编码器,提供统一的生成接口。""" def __init__( self, model: GPT2LMHeadModel, gpt_config: GPTConfig, device: Optional[torch.device] = None, ): self.model = model self.config = gpt_config self.encoder = get_encoder() self.device = device or torch.device( "cuda" if torch.cuda.is_available() else "cpu" ) self.model.to(self.device) self.model.eval() @torch.no_grad() def generate( self, prompt: str, max_new_tokens: int = 100, temperature: float = 0.8, top_k: int = 40, ) -> str: """ 自回归生成文本。 参数: prompt: 输入提示文本。 max_new_tokens: 最多生成的新 token 数量。 temperature: 温度系数。越高越随机,越低越确定;<=0 视为贪心。 top_k: 只在概率最高的 top_k 个 token 中采样;<=0 表示不限制。 返回: 完整生成文本(包含 prompt)。 """ if not prompt: # 无 prompt 时使用 <|endoftext|> 作为起始 token。 input_ids = torch.tensor( [[self.config.vocab_size - 1]], dtype=torch.long, device=self.device ) else: tokens = self.encoder.encode(prompt) input_ids = torch.tensor([tokens], dtype=torch.long, device=self.device) block_size = self.config.block_size for _ in range(max_new_tokens): # 截取最后 block_size 个 token 作为上下文,避免超过模型最大长度。 cond = input_ids if input_ids.size(1) <= block_size else input_ids[:, -block_size:] logits = self.model(input_ids=cond).logits # (1, T, vocab) next_logits = logits[:, -1, :] # 仅取最后一个位置 (1, vocab) # ---------- temperature ---------- if temperature <= 0: # 贪心解码:直接取 argmax。 next_token = torch.argmax(next_logits, dim=-1, keepdim=True) else: next_logits = next_logits / temperature # ---------- top_k ---------- if top_k > 0: top_k = min(top_k, next_logits.size(-1)) # 取概率最高的 top_k 个 logits,其余设为 -inf 以屏蔽。 top_values, top_indices = torch.topk(next_logits, top_k, dim=-1) mask = torch.full_like(next_logits, float("-inf")) mask.scatter_(1, top_indices, top_values) next_logits = mask probs = F.softmax(next_logits, dim=-1) next_token = torch.multinomial(probs, num_samples=1) # 拼接到序列末尾,继续下一轮预测。 input_ids = torch.cat([input_ids, next_token], dim=1) # 遇到结束符提前终止。 if next_token.item() == self.config.vocab_size - 1 and not prompt: break # 解码为文本。 output_ids = input_ids[0].tolist() return self.encoder.decode(output_ids) # -------------------------------------------------------------------------- # 加载入口 # -------------------------------------------------------------------------- def load_for_inference( checkpoint_path: str, device: Optional[str] = None, gpt_config: Optional[GPTConfig] = None, ) -> GPTGenerator: """ 加载 checkpoint 并返回 GPTGenerator。 参数: checkpoint_path: 权重文件路径或 transformers 目录。 device: 推理设备,None 表示自动选择。 gpt_config: 模型配置;若 checkpoint 中包含则会优先使用其配置。 """ import os target_device = torch.device( device or ("cuda" if torch.cuda.is_available() else "cpu") ) if os.path.isdir(checkpoint_path): # transformers 风格目录。 model = GPT2LMHeadModel.from_pretrained(checkpoint_path) if gpt_config is None: gpt_config = GPTConfig() else: ckpt = torch.load(checkpoint_path, map_location=target_device) # 优先使用 checkpoint 中保存的配置。 if "gpt_config" in ckpt: cfg = GPTConfig() cfg.__dict__.update(ckpt["gpt_config"]) gpt_config = cfg elif gpt_config is None: gpt_config = GPTConfig() model = build_model(gpt_config) state = ckpt.get("model_state_dict", ckpt) model.load_state_dict(state, strict=False) return GPTGenerator(model=model, gpt_config=gpt_config, device=target_device) # -------------------------------------------------------------------------- # 命令行入口 # -------------------------------------------------------------------------- def parse_args() -> argparse.Namespace: p = argparse.ArgumentParser(description="GPT 推理脚本") p.add_argument( "--checkpoint", type=str, default="checkpoints/gpt_final.pt", help="checkpoint 文件路径或 transformers 模型目录", ) p.add_argument("--prompt", type=str, default="", help="输入提示文本") p.add_argument("--max-tokens", type=int, default=100, help="最多生成的新 token 数") p.add_argument("--temperature", type=float, default=0.8, help="采样温度") p.add_argument("--top-k", type=int, default=40, help="top-k 采样的 k 值") p.add_argument("--device", type=str, default=None, help="推理设备") return p.parse_args() def main(): args = parse_args() generator = load_for_inference(args.checkpoint, device=args.device) print(f"[inference] 已加载: {args.checkpoint}") print(f"[inference] prompt: {args.prompt!r}") text = generator.generate( prompt=args.prompt, max_new_tokens=args.max_tokens, temperature=args.temperature, top_k=args.top_k, ) print("\n===== 生成结果 =====") print(text) if __name__ == "__main__": main()

6.3 GPTGenerator:推理封装类

class GPTGenerator: def __init__(self, model, gpt_config, device=None): self.model = model self.config = gpt_config self.encoder = get_encoder() self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu") self.model.to(self.device) self.model.eval() # ★ 推理模式:关 dropout

两个关键设置

@torch.no_grad()

@torch.no_grad() def generate(self, ...):

禁用自动求导。推理时不需要反向传播,禁用后:

  • 不存中间激活,显存占用大幅降低(可省 5-10 倍)。
  • 速度提升(少了构建计算图的开销)。

model.eval()

self.model.eval()

关闭 dropout,保证推理结果确定性(除了采样本身)。

💡 这两个加在一起是推理的标配。忘任何一个都有问题:忘 no_grad 显存爆;忘 eval 推理结果每次不同。

为什么封装成类

不直接写函数,而是包一个 GPTGenerator 类,好处:

  1. 状态封装:模型、编码器、设备、配置都存在 self 上,不用每次都传一遍。
  2. 复用友好:Web UI(app.py)能持有一个 generator 实例反复调用。
  3. 接口清晰:外部只调 generate(prompt, ...),不关心内部细节。

6.4 自回归生成主循环

# 编码 prompt if not prompt: input_ids = torch.tensor([[self.config.vocab_size - 1]], ...) # <|endoftext|> else: tokens = self.encoder.encode(prompt) input_ids = torch.tensor([tokens], ...) for _ in range(max_new_tokens): # ① 截断到 block_size cond = input_ids if input_ids.size(1) <= block_size else input_ids[:, -block_size:] # ② 前向,取最后一位的 logits logits = self.model(input_ids=cond).logits next_logits = logits[:, -1, :] # ③ 采样(见 6.5) next_token = ... # ④ 拼接 input_ids = torch.cat([input_ids, next_token], dim=1) # ⑤ 遇到结束符提前终止 if next_token.item() == self.config.vocab_size - 1 and not prompt: break return self.encoder.decode(input_ids[0].tolist())

5 步逐个讲

步骤 ①:截断到 block_size

cond = input_ids if input_ids.size(1) <= block_size else input_ids[:, -block_size:]

模型最大上下文是 block_size。生成过程中序列越来越长,一旦超过 block_size,必须截掉前面的,只保留最后 block_size 个 token

input_ids[:, -block_size:] 是切片,负索引表示从末尾往前数。

⚠️ 这种朴素实现的性能问题:每生成一个 token 都要把整段 cond(最长 block_size)重新前向一遍,复杂度 O(block_size²)。生产中用 KV Cache(缓存每层 K/V 避免重算),HF 的 model.generate() 默认就用。本项目手写循环是为了教学清晰。

步骤 ②:取最后一位的 logits

logits = self.model(input_ids=cond).logits # shape (1, T, vocab) next_logits = logits[:, -1, :] # shape (1, vocab)

自回归生成时,我们只关心**「给定当前全部 token,下一个 token 是什么」**。位置 -1(最后一个 token)的 logits 就是这个预测。

前面位置的 logits 在生成场景下是「废物」(它们预测的是「已经知道的下一个」),丢弃即可。

步骤 ③:采样

见 6.5 节,三种策略。

步骤 ④:拼接

input_ids = torch.cat([input_ids, next_token], dim=1)

把新 token 拼到序列末尾,下一轮循环它就成了已知上下文。

步骤 ⑤:结束符终止

if next_token.item() == self.config.vocab_size - 1 and not prompt: break

只有「无 prompt 的自由生成」(即从 <|endoftext|> 起头)才会在生成出 <|endoftext|> 时停止

为什么给了 prompt 时不提前停?因为模型可能在中间误输出 <|endoftext|>(实际生成经常出现),提前停会让结果很短。

6.5 三种采样策略(重难点)

generate() 根据参数走不同分支:

分支 A:贪心解码(temperature <= 0

if temperature <= 0: next_token = torch.argmax(next_logits, dim=-1, keepdim=True)

每步取概率最大的 token。确定性强(同输入永远同输出),但容易陷入重复循环:

prompt: "To be, or" 输出: "To be, or not to be not to be not to be not to be..."

分支 B:温度采样(temperature > 0

next_logits = next_logits / temperature

温度的数学本质:softmax(logits/T)。

  • T → 0⁺:分布越来越尖,趋近 argmax(贪心)。
  • T = 1:原分布不变。
  • T → ∞:分布越来越平,趋近均匀随机(乱码)。

直觉:温度「平滑」或「锐化」概率分布。低温保守确定,高温奔放随机。

温度的数学演示

logits = [2.0, 1.0, 0.5] T=0.1: softmax([20, 10, 5]) → [0.9999, 0.0001, 0.0000] ← 几乎必选第一个 T=0.5: softmax([4, 2, 1]) → [0.86, 0.12, 0.04] T=1.0: softmax([2, 1, 0.5]) → [0.59, 0.22, 0.13] ← 偶尔选其他 T=2.0: softmax([1, 0.5, 0.25]) → [0.39, 0.28, 0.21] ← 经常选其他 T=10: softmax([0.2, 0.1, 0.05]) → [0.30, 0.27, 0.25] ← 接近均匀

分支 C:Top-K 过滤(top_k > 0

if top_k > 0: top_k = min(top_k, next_logits.size(-1)) top_values, top_indices = torch.topk(next_logits, top_k, dim=-1) mask = torch.full_like(next_logits, float("-inf")) mask.scatter_(1, top_indices, top_values) next_logits = mask

Top-K 的作用

只保留概率最高的 K 个 token,其余 logits 设为 -inf(softmax 后概率为 0),再从这 K 个里按概率采样。

为什么要 top-k

纯温度采样的痛点:长尾。即使温度合适,词表里仍有大量「明显不对」的 token(如罕见的拼写错误),它们概率很小但偶尔被抽中会让生成崩坏。Top-K 把这些「无希望」的候选直接踢掉。

词表 50257 个 token 的概率分布(示意) 排序后:[0.30, 0.20, 0.15, 0.10, 0.05, ... 一堆 0.001 ...] top_k=40:只在前 40 个里采样,其余强制为 0 ▲ │ 这些是「合理候选」 ────────────┼────────────────────── │ │ 这些是「噪声」,丢弃

代码逐行

top_k = min(top_k, next_logits.size(-1))

钳制:万一用户传 top_k=99999,但词表只有 50257,取 min 防越界。

top_values, top_indices = torch.topk(next_logits, top_k, dim=-1)

torch.topk 返回两个张量:

  • top_values:最大的 K 个值本身
  • top_indices:它们在原 tensor 里的位置
mask = torch.full_like(next_logits, float("-inf")) mask.scatter_(1, top_indices, top_values)

构造一个全 -inf 的 mask,用 scatter_ 把 top K 个位置填回原值。效果:top K 位置的 logits 不变,其他位置变 -inf

probs = F.softmax(next_logits, dim=-1)

softmax:-inf 位置的 probability 是 0(因为 exp(-inf) = 0),top K 位置正常分布。

torch.multinomial 采样

next_token = torch.multinomial(probs, num_samples=1)

multinomialprobs 给的概率分布有放回抽样 1 次。这正是「按概率采一个 token」的标准实现。

三者组合的实际效果

组合 行为
temperature=0 贪心,确定但易重复
temperature=1, top_k=0 纯随机采样,混乱
temperature=0.8, top_k=40(默认) 推荐:有变化但不离谱
temperature=0.3, top_k=10 保守、聚焦、更连贯
temperature=1.2, top_k=80 多样、发散、易跑题

调参经验:

  • 写代码补全:低温度(0.2-0.4),需要确定性和正确性。
  • 写诗/创意:中温度(0.7-1.0),需要一点惊喜。
  • brainstorming:高温度(1.0-1.3),最大化多样性。
  • 几乎所有场景都开 top_k=40,防止采到离谱 token。

6.6 load_for_inference:智能加载

def load_for_inference(checkpoint_path, device=None, gpt_config=None): import os target_device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu")) if os.path.isdir(checkpoint_path): # 格式 A:transformers 风格目录 model = GPT2LMHeadModel.from_pretrained(checkpoint_path) if gpt_config is None: gpt_config = GPTConfig() else: # 格式 B:本项目 .pt 文件 ckpt = torch.load(checkpoint_path, map_location=target_device) # 优先用 checkpoint 里保存的配置(重要!) if "gpt_config" in ckpt: cfg = GPTConfig() cfg.__dict__.update(ckpt["gpt_config"]) gpt_config = cfg elif gpt_config is None: gpt_config = GPTConfig() model = build_model(gpt_config) state = ckpt.get("model_state_dict", ckpt) model.load_state_dict(state, strict=False) return GPTGenerator(model=model, gpt_config=gpt_config, device=target_device)

智能配置恢复

最关键的几行:

if "gpt_config" in ckpt: cfg = GPTConfig() cfg.__dict__.update(ckpt["gpt_config"]) gpt_config = cfg

亮点:从 checkpoint 自动恢复配置。训练时你改了 n_layer=12,checkpoint 里就存了 n_layer=12;推理时不需要再传 --n-layer 12,加载器自己读出来。

这是 train.py 把配置存进 checkpoint 的回报(第 5 章 save_checkpoint 存了 gpt_config.__dict__)。

两种 checkpoint 格式

model.py:load_model 类似(第 4 章):

格式 加载方式
目录 from_pretrained(dir)
单文件 .pt 读 config → build_modelload_state_dict

6.7 命令行入口

def main(): args = parse_args() generator = load_for_inference(args.checkpoint, device=args.device) print(f"[inference] 已加载: {args.checkpoint}") print(f"[inference] prompt: {args.prompt!r}") text = generator.generate( prompt=args.prompt, max_new_tokens=args.max_tokens, temperature=args.temperature, top_k=args.top_k, ) print("\n===== 生成结果 =====") print(text)

常用命令

# 默认配置 python inference.py --prompt "To be, or" # 贪心(确定) python inference.py --prompt "ROMEO:" --temperature 0 # 多样化 python inference.py --prompt "ROMEO:" --temperature 1.0 --top-k 50 --max-tokens 200 # 指定 checkpoint python inference.py --checkpoint checkpoints/gpt_step2000.pt --prompt "To be, or"

参数表

参数 类型 默认值 说明
--checkpoint str checkpoints/gpt_final.pt 权重路径或 transformers 目录
--prompt str "" 输入提示文本
--max-tokens int 100 最多生成的新 token 数
--temperature float 0.8 采样温度
--top-k int 40 top-k 采样的 k 值
--device str - 推理设备

6.8 Python 中调用

不通过命令行,直接 import:

from inference import load_for_inference gen = load_for_inference("checkpoints/gpt_final.pt") # 多次生成 for _ in range(3): text = gen.generate( prompt="To be, or", max_new_tokens=100, temperature=0.8, top_k=40, ) print(text) print("---")

这是 app.py 的用法(第 7 章)。

6.9 动手实验

  1. 确定性 vs 随机:同一个 prompt 跑 5 次:

    • --temperature 0:观察输出是否完全一致(贪心确定)。
    • --temperature 1.0 --top-k 0:观察输出是否几乎每次不同且混乱。
  2. 温度扫描:固定 prompt,跑这些温度,对比效果:

    python inference.py --prompt "ROMEO:" --temperature 0.2 python inference.py --prompt "ROMEO:" --temperature 0.5 python inference.py --prompt "ROMEO:" --temperature 0.8 python inference.py --prompt "ROMEO:" --temperature 1.2
  3. 长 prompt 测试:给一个超过 block_size 的长 prompt,观察生成时是否报错(应该不会,因为有截断)。

  4. 思考题:如何实现 top-p(nucleus)采样

    • 提示:不再固定 K,而是选概率累加 ≥ p 的最小集合。例如 p=0.9:把 token 按概率降序排列,从前往后累加,累加到 0.9 就停,只在这些里采样。
  5. 进阶:用 HF 的 model.generate() 对比本项目的手写循环:

    model = ... # 加载好的 GPT2LMHeadModel out = model.generate( input_ids=..., max_new_tokens=100, do_sample=True, temperature=0.8, top_k=40, use_cache=True, # KV cache,比本项目快很多 )

6.10 小结

  • GPTGenerator 封装模型+编码器+设备,提供 generate() 接口。
  • @torch.no_grad() + model.eval() 是推理标配(省显存 + 关 dropout)。
  • 自回归循环 5 步:截断 → 前向取末位 logits → 采样 → 拼接 → 检查结束符。
  • 三种采样:
    • 贪心(temperature<=0):确定,易重复
    • 温度(temperature>0):softmax(logits/T),控制随机性
    • Top-K:踢掉长尾噪声
  • 默认组合 temperature=0.8, top_k=40 是安全选择。
  • load_for_inference 智能从 checkpoint 恢复配置,无需手动传架构。

6.11 下一章

推理逻辑通了,去 第 7 章 Web UI 把它包成浏览器可访问的界面。


发布者: 作者: 转发
评论区 (0)
U