第 5 章 训练循环(train.py)· 上:流程总览与完整源码


文档摘要

第 5 章 训练循环(train.py)· 上:流程总览与完整源码 本章最重,分上下两篇。本篇先看训练循环的全景与完整源码,下篇逐行讲解。 5.1 训练流程全景 函数的 8 个阶段: 5.2 完整源码 💡 上面为可读性精简了 / (详见项目源码),核心逻辑完整保留。 5.3 8 个阶段速读 阶段 | 关键函数/调用 | 工程要点 1. 配置 | + | 命令行覆盖 dataclass 默认值 2. 种子/设备 | + | 固定 4 个随机源;auto 检测 CUDA 3. 数据 | + | 见第 3 章 4. 模型 | + | 见第 4 章 5. 优化器 | + | betas=(0.9, 0.95) GPT 经验值 6. 续训 | + 3 个 | 模型+优化器+调度器必须一起恢复 7.

第 5 章 训练循环(train.py)· 上:流程总览与完整源码

本章最重,分上下两篇。本篇先看训练循环的全景与完整源码,下篇逐行讲解。

5.1 训练流程全景

train() 函数的 8 个阶段:

1. 解析参数,覆盖默认配置 2. 设置随机种子 + 选择设备 3. 加载数据 + 构建 DataLoader 4. 构建模型 + 迁到设备 5. 构建优化器 (AdamW) + 调度器 (余弦退火) 6.(可选)从 checkpoint 恢复 7. 训练循环(前向 → 反向 → 裁剪 → 步进 → 日志 → 存盘) 8. 训练结束保存 final 模型

5.2 完整源码

""" train.py ======== GPT 模型训练入口脚本。 完整流程: 1. 解析命令行参数,覆盖默认配置。 2. 设置随机种子,保证可复现。 3. 加载数据集并构建 DataLoader。 4. 构建模型,迁移到设备。 5. 构建 AdamW 优化器 + 余弦退火学习率调度器。 6. 训练循环:梯度裁剪、tqdm 进度条、定期日志、定期保存 checkpoint。 """ import argparse import math import os import random import time from pathlib import Path import numpy as np import torch from torch.optim import AdamW from torch.optim.lr_scheduler import LambdaLR from tqdm import tqdm from config import GPTConfig, TrainConfig from dataset import get_dataset, build_dataloader from model import build_model, count_parameters # -------------------------------------------------------------------------- # 工具函数 # -------------------------------------------------------------------------- def set_seed(seed: int) -> None: """固定 random / numpy / torch 的随机种子,保证实验可复现。""" random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) def select_device(name: str) -> torch.device: """根据配置字符串选择训练设备。""" if name == "auto": return torch.device("cuda" if torch.cuda.is_available() else "cpu") return torch.device(name) def get_cosine_schedule_with_warmup( optimizer: torch.optim.Optimizer, num_warmup_steps: int, num_training_steps: int, min_lr_ratio: float = 0.1, ) -> LambdaLR: """自定义"线性预热 + 余弦退火"学习率调度器。""" min_lr_ratio = max(0.0, min(1.0, min_lr_ratio)) def lr_lambda(current_step: int) -> float: # 1) 预热阶段:线性升温。 if current_step < num_warmup_steps: return float(current_step) / float(max(1, num_warmup_steps)) # 2) 余弦退火阶段:从 1.0 平滑下降到 min_lr_ratio。 progress = float(current_step - num_warmup_steps) / float( max(1, num_training_steps - num_warmup_steps) ) cosine = 0.5 * (1.0 + math.cos(math.pi * progress)) return min_lr_ratio + (1.0 - min_lr_ratio) * cosine return LambdaLR(optimizer, lr_lambda) def save_checkpoint( model, optimizer, scheduler, gpt_config, train_config, step, loss, checkpoint_dir, ) -> str: """将模型权重、优化器状态与配置保存到磁盘,返回保存路径。""" ckpt_dir = Path(checkpoint_dir) ckpt_dir.mkdir(parents=True, exist_ok=True) path = ckpt_dir / f"gpt_step{step}.pt" torch.save({ "step": step, "model_state_dict": model.state_dict(), "optimizer_state_dict": optimizer.state_dict(), "scheduler_state_dict": scheduler.state_dict(), "loss": loss, "gpt_config": gpt_config.__dict__, "train_config": train_config.__dict__, }, path) print(f"[train] checkpoint 已保存: {path}") return str(path) # -------------------------------------------------------------------------- # 主训练循环 # -------------------------------------------------------------------------- def train(): args = parse_args() gpt_config = GPTConfig() train_config = TrainConfig() apply_args_to_configs(args, gpt_config, train_config) set_seed(train_config.seed) device = select_device(train_config.device) print(f"[train] 使用设备: {device}") text = get_dataset() dataloader = build_dataloader( text=text, block_size=gpt_config.block_size, batch_size=train_config.batch_size, num_workers=train_config.num_workers, shuffle=True, ) model = build_model(gpt_config) model.to(device) print(f"[train] 模型可训练参数量: {count_parameters(model)/1e6:.2f}M") optimizer = AdamW( model.parameters(), lr=train_config.learning_rate, betas=train_config.betas, weight_decay=train_config.weight_decay, ) scheduler = get_cosine_schedule_with_warmup( optimizer=optimizer, num_warmup_steps=train_config.warmup_iters, num_training_steps=train_config.max_iters, min_lr_ratio=train_config.min_lr_ratio, ) start_step = 0 if args.resume: ckpt = torch.load(args.resume, map_location=device) model.load_state_dict(ckpt["model_state_dict"]) optimizer.load_state_dict(ckpt["optimizer_state_dict"]) scheduler.load_state_dict(ckpt["scheduler_state_dict"]) start_step = ckpt["step"] + 1 model.train() pbar = tqdm(total=train_config.max_iters, initial=start_step, desc="训练", unit="step") step = start_step data_iter = iter(dataloader) t0 = time.time() running_loss = 0.0 while step < train_config.max_iters: try: x, y = next(data_iter) except StopIteration: data_iter = iter(dataloader) x, y = next(data_iter) x = x.to(device, non_blocking=True) y = y.to(device, non_blocking=True) outputs = model(input_ids=x, labels=y) loss = outputs.loss optimizer.zero_grad(set_to_none=True) loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.grad_clip) optimizer.step() scheduler.step() running_loss += loss.item() step += 1 pbar.update(1) pbar.set_postfix({"loss": f"{loss.item():.4f}", "lr": f"{scheduler.get_last_lr()[0]:.2e}"}) if step % train_config.log_iter == 0: avg_loss = running_loss / train_config.log_iter elapsed = time.time() - t0 speed = (step - start_step) / max(elapsed, 1e-6) print(f"[train] step {step}/{train_config.max_iters} | avg_loss={avg_loss:.4f} | speed={speed:.2f} step/s") running_loss = 0.0 if step % train_config.save_iter == 0 or step == train_config.max_iters: save_checkpoint(model, optimizer, scheduler, gpt_config, train_config, step, loss.item(), train_config.checkpoint_dir) pbar.close() final_path = Path(train_config.checkpoint_dir) / "gpt_final.pt" torch.save({ "step": step, "model_state_dict": model.state_dict(), "gpt_config": gpt_config.__dict__, "train_config": train_config.__dict__, "loss": loss.item(), }, final_path) print(f"[train] 最终模型已保存: {final_path}") if __name__ == "__main__": train()

💡 上面为可读性精简了 parse_args / apply_args_to_configs(详见项目源码),核心逻辑完整保留。

5.3 8 个阶段速读

阶段 关键函数/调用 工程要点
1. 配置 parse_args + apply_args_to_configs 命令行覆盖 dataclass 默认值
2. 种子/设备 set_seed + select_device 固定 4 个随机源;auto 检测 CUDA
3. 数据 get_dataset + build_dataloader 见第 3 章
4. 模型 build_model + model.to(device) 见第 4 章
5. 优化器 AdamW(...) + get_cosine_schedule_with_warmup betas=(0.9, 0.95) GPT 经验值
6. 续训 torch.load + 3 个 load_state_dict 模型+优化器+调度器必须一起恢复
7. 循环 while step < max_iters: 9 个细节,见下篇
8. 终存 torch.save final 不存优化器(推理不需要)

5.4 下一篇

源码看完了,去 第 5 章 下:逐行讲解 看每个函数的细节、数学推导、9 个工程要点。

第 6 章 推理与采样 也可以。


发布者: 作者: 转发
评论区 (0)
U