nano-GPT 模型架构介绍


文档摘要

nano-GPT 模型架构介绍 本章节将围绕nano-GPT模型的代码实现,从核心特性、配置参数、关键组件、运行流程到优化推理,系统化的拆解模型架构,从原理+实现的角度,详细介绍 nano-GPT 模型架构。 如果对于相关知识没有太多基础,想要从基础整体了解,可以先看补充内容:大语言模型架构训练核心原理全解析。正文内容更偏向于其核心要点内容的说明。 一、 模型概述 nanoChat模型是轻量简化版 GPT,基于 Transformer 解码器架构,在保留核心 autoregressive 能力的同时,通过多项工程优化降低复杂度并提升推理效率,核心特性如下: 采用旋转位置编码(Rotary Embeddings),移除传统可学习位置编码,减少参数并增强长序列泛化性;

nano-GPT 模型架构介绍

本章节将围绕nano-GPT模型的代码实现,从核心特性、配置参数、关键组件、运行流程到优化推理,系统化的拆解模型架构,从原理+实现的角度,详细介绍 nano-GPT 模型架构。

如果对于相关知识没有太多基础,想要从基础整体了解,可以先看补充内容:大语言模型架构训练核心原理全解析。正文内容更偏向于其核心要点内容的说明。

一、 模型概述

nanoChat模型是轻量简化版 GPT,基于 Transformer 解码器架构,在保留核心 autoregressive 能力的同时,通过多项工程优化降低复杂度并提升推理效率,核心特性如下:

  • 采用旋转位置编码(Rotary Embeddings),移除传统可学习位置编码,减少参数并增强长序列泛化性;

  • 引入QK 归一化(QK Norm),对注意力层的查询(Q)和键(K)单独做归一化,稳定注意力计算;

  • 解耦词嵌入与输出层权重(Untied Weights),词嵌入层(wte)与语言模型头(lm_head)使用独立权重,提升训练灵活性;

  • MLP 层采用ReLU² 激活函数(ReLU 后平方),替代传统 GELU,简化计算同时保留非线性表达能力;

  • 使用无参数 RMS 归一化,所有归一化操作(norm 函数)无 learnable params,降低过拟合风险;

  • 全链路无偏置线性层,所有 Linear 层(如 c_q、c_k、c_proj)均不设 bias,减少参数总量;

  • 支持多查询注意力(MQA),通过重复键值头(KV Head)匹配查询头(Q Head)数量,大幅提升推理效率。

二、 核心配置(GPTConfig)

模型的所有可配置参数通过GPTConfig数据类定义,决定模型规模、结构与运行边界,参数说明如下表:

参数名 默认值 作用说明
sequence_len 1024 训练阶段支持的最大序列长度,推理时可通过 KV 缓存处理更长序列(需小于 rotary_seq_len)
vocab_size 50304 模型支持的词汇表大小,对应词嵌入层(wte)的输入维度
n_layer 12 Transformer 解码器层数,即 Block 模块的堆叠数量
n_head 6 注意力层的查询头(Q Head)数量,决定注意力的并行度
n_kv_head 6 注意力层的键值头(KV Head)数量,MQA 模式下需满足n_head % n_kv_head == 0
n_embd 768 词嵌入维度与 Transformer 层的隐藏状态维度,需满足n_embd % n_head == 0

三、 关键组件拆解

模型由「基础工具函数」和「核心模块」构成,各组件职责明确且复用性强,以下按依赖关系拆解:

3.1 基础工具函数

为核心模块提供通用计算能力,无状态且纯函数式设计:

  1. RMS 归一化(norm)
  • 功能:对输入张量的最后一维(特征维度)做 RMS 归一化,公式为 x / sqrt(mean(x²) + ε)

  • 特点:无 learnable params,仅依赖输入数据本身,避免参数冗余。

  1. 旋转位置编码应用(apply_rotary_emb)
  • 功能:将预计算的 cos/sin 矩阵应用到 Q/K 张量,实现相对位置编码;

  • 步骤:先将 Q/K 的最后一维(head_dim)拆分为两半,通过三角函数计算旋转后的值,再拼接回原维度。

  1. 键值头重复(repeat_kv)
  • 功能:在 MQA 模式下,将 KV Head 的输出重复n_head//n_kv_head次,匹配 Q Head 的数量;

  • 实现:通过张量扩展(expand)+ 重塑(reshape)完成,避免冗余计算。

3.2 核心模块

3.2.1 因果自注意力层(CausalSelfAttention)

模型的注意力核心,支持 MQA 与 KV 缓存,确保 autoregressive 特性(当前 token 仅关注历史 token):

  • 输入输出:输入为隐藏状态(x)、cos/sin 矩阵(cos_sin)、KV 缓存(kv_cache),输出为注意力层的特征映射;

  • 核心流程

  1. 投影:通过 c_q、c_k、c_v 将输入 x 分别投影为 Q、K、V 张量,并调整维度为(B, T, H, head_dim)

  2. 位置编码:对 Q、K 应用旋转编码,再分别做 QK 归一化;

  3. KV 缓存:若启用缓存,将当前 K、V 插入缓存,获取历史 + 当前的完整 KV 张量;

  4. MQA 适配:重复 KV Head 以匹配 Q Head 数量;

  5. 注意力计算:根据是否启用缓存 / 序列长度,选择因果掩码(is_causal=True)或自定义掩码,调用F.scaled_dot_product_attention完成计算;

  6. 输出投影:将注意力结果重塑后,通过 c_proj 投影回隐藏状态维度(n_embd)。

3.2.2 MLP 层(MLP)

注意力层之后的非线性变换模块,负责增强模型的特征表达能力:

  • 结构:两层线性层 + ReLU² 激活,维度变换为 n_embd → 4*n_embd → n_embd

  • 激活函数:先通过F.relu(x),再对结果平方(square()),替代传统 GELU,计算更简洁;

  • 无偏置:c_fc 和 c_proj 均不设 bias,减少参数并降低训练波动。

3.2.3 Transformer 块(Block)

模型的基本堆叠单元,由「注意力层 + MLP 层」构成,采用残差连接:

  • 流程:输入 x 先经过 norm 归一化,送入注意力层后与原 x 残差相加;再对结果做 norm 归一化,送入 MLP 层后再次残差相加;

  • 特点:归一化操作在进入子层前(Pre-Norm),避免梯度消失,提升训练稳定性。

3.2.4 嵌入与输出层

  • 词嵌入层(wte)nn.Embedding(vocab_size, n_embd),将 token ID 转换为高维向量,默认使用 bfloat16 精度,节省内存;

  • 语言模型头(lm_head)nn.Linear(n_embd, vocab_size),将 Transformer 输出的隐藏状态映射为词汇表概率分布,与 wte 权重解耦,且初始化时权重设为 0。

四、 模型整体流程

4.1 初始化流程(init + init_weights)

  1. 模块构建:在__init__中初始化词嵌入层(wte)、Transformer 块列表(h)、lm_head,并预计算旋转编码(cos/sin),存储为非持久化 buffer(不写入 checkpoint);

  2. 权重初始化

  • 线性层:采用正态分布初始化,标准差为 1/sqrt(fan_in) * min(1, sqrt(fan_out/fan_in))(参考论文 2310.17813);

  • 嵌入层:正态分布初始化(mean=0.0, std=1.0);

  • 特殊初始化:lm_head、注意力层 c_proj、MLP 层 c_proj 的权重初始化为 0,减少初始训练波动。

4.2 训练阶段前向传播(forward)

  1. 输入处理:接收 token ID(idx)和标签(targets),检查序列长度是否超过旋转编码缓存;

  2. 旋转编码截取:根据 KV 缓存位置(T0),截取当前序列所需的 cos/sin 片段;

  3. 词嵌入与归一化:idx 通过 wte 转换为嵌入向量,再经过 norm 归一化;

  4. Transformer 堆叠计算:将嵌入向量依次送入每个 Block,每个 Block 接收 cos_sin 和 kv_cache,输出更新后的隐藏状态;

  5. 输出与损失计算:隐藏状态经 norm 后,通过 lm_head 生成 logits,应用 softcap(15*tanh (logits/15))稳定值范围,最终计算交叉熵损失(targets 存在时)。

4.3 推理阶段生成(generate)

采用朴素 autoregressive 生成,支持温度采样(temperature)和 Top-K 采样:

  1. 输入准备:将初始 token 列表转换为张量,添加 batch 维度;

  2. 循环生成

  • 每次前向传播仅取最后一个 token 的 logits;

  • 若启用 Top-K,过滤掉概率低于 Top-K 最小值的 logits;

  • 温度采样:logits 除以温度后 softmax, multinomial 采样下一个 token;贪心采样(temperature=0):取 logits 最大值;

  • 拼接当前 token 与历史 token,重复至生成 max_tokens 个 token;

  1. 输出:逐个 yield 生成的 token(Python int 类型),支持流式输出。

五、 优化与工具链

5.1 参数优化(setup_optimizers)

采用分组优化策略,针对不同类型参数设置不同学习率:

  • AdamW 优化器:用于 lm_head(unembedding_lr)和 wte(embedding_lr),学习率按 1/sqrt(n_embd/768) 缩放(适配不同隐藏维度);

  • Muon 优化器:用于 Transformer 块中的线性层(matrix_lr),侧重矩阵参数的稳定更新;

  • 权重衰减(weight_decay)可配置,默认 0,减少过拟合。

5.2 混合精度支持

  • 词嵌入层(wte)和旋转编码(cos/sin)默认使用 bfloat16 精度,减少内存占用和计算耗时;

  • logits 计算时转换为 float32,确保损失计算精度。

六、 nano-GPT 的核心组件数学公式与代码映射

此部分内容基于前文模型架构框架,针对四大核心组件(旋转编码、RMS 归一化、因果自注意力层、MLP 层),我们详细拆解其数学原理并给出对于的代码实现。

6.1 旋转编码(Rotary Embeddings)

旋转编码是模型的位置编码方案,通过三角函数对查询(Q)和键(K)进行旋转,注入相对位置信息,核心分为 “预计算 cos/sin 矩阵” 和 “应用旋转编码” 两步。

6.1.1 数学原理与公式

6.1.1.1 预计算 cos/sin 矩阵

旋转编码的核心是为每个位置t(序列中的第t个 token)和每个通道i(头维度head_dim的第i个维度)计算旋转频率,再生成对应的cossin值。

  • 步骤 1:计算逆频率(inv_freq)

    逆频率决定了不同通道的旋转周期,公式为:

    \text{inv\_freq}[i] = \frac{1}{\text{base}^{(2i/\text{head\_dim})}}

    其中:

    • i:通道索引(取值范围[0, 1, ..., head_dim//2 - 1],因旋转按 “两个通道一组” 进行);

    • base:基础周期参数(代码中默认10000);

    • head_dim:单个注意力头的维度(由n_embd // n_head计算得出)。

  • 步骤 2:计算频率矩阵(freqs)

    频率矩阵是位置与通道的外积,公式为:

    \text{freqs}[t, i] = t \times \text{inv\_freq}[i]

    其中:

    • t:位置索引(取值范围[0, 1, ..., seq_len - 1]seq_len为预计算的序列长度,代码中为rotary_seq_len);

    • 矩阵形状:(seq_len, head_dim//2)(因每个通道组对应一个频率)。

  • 步骤 3:生成 cos/sin 矩阵

    对频率矩阵分别应用余弦和正弦函数,得到最终用于旋转的矩阵:

    \cos_{mat}[t, i] = \cos(\text{freqs}[t, i])

    \sin_{mat}[t, i] = \sin(\text{freqs}[t, i])

    最终扩展维度以适配 Q/K 张量形状:

    \cos = \cos_{mat} \rightarrow \text{reshape}(1, \text{seq\_len}, 1, \text{head\_dim//2}) \rightarrow \text{repeat}(1, 1, 1, 2)

    \sin = \sin_{mat} \rightarrow \text{reshape}(1, \text{seq\_len}, 1, \text{head\_dim//2}) \rightarrow \text{repeat}(1, 1, 1, 2)

    (注:repeat(1,1,1,2)是为了让每个通道组的cos/sin值覆盖两个通道,匹配head_dim维度)

6.1.1.2 应用旋转编码到 Q/K

对 Q/K 张量的每个位置 - 通道对进行旋转,公式基于复数旋转原理(实部和虚部分别对应两个通道):

设 Q 的某两个通道为q1(第2i维)和q2(第2i+1维),旋转后为q1'q2'

q1' = q1 \times \cos_{mat}[t, i] + q2 \times \sin_{mat}[t, i]

q2' = -q1 \times \sin_{mat}[t, i] + q2 \times \cos_{mat}[t, i]

K 的旋转公式与 Q 完全一致,仅输入张量替换为 K。

6.1.2 代码实现

6.1.2.1 预计算 cos/sin 矩阵(对应_precompute_rotary_embeddings函数)
def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None): # 1. 计算逆频率(对应公式:inv_freq[i] = 1/base^(2i/head_dim)) channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device) # i的范围:0,2,...head_dim-2(共head_dim//2个) inv_freq = 1.0 / (base ** (channel_range / head_dim)) # 按公式计算逆频率,形状:(head_dim//2,) # 2. 计算频率矩阵(对应公式:freqs[t,i] = t * inv_freq[i]) t = torch.arange(seq_len, dtype=torch.float32, device=device) # t的范围:0,1,...seq_len-1 freqs = torch.outer(t, inv_freq) # 外积计算,形状:(seq_len, head_dim//2) # 3. 生成cos/sin矩阵并扩展维度(对应公式:cos_mat = cos(freqs),sin_mat = sin(freqs)) cos, sin = freqs.cos(), freqs.sin() # 形状均为:(seq_len, head_dim//2) cos, sin = cos.bfloat16(), sin.bfloat16() # 转为bfloat16节省内存 # 扩展维度:(seq_len, head_dim//2) → (1, seq_len, 1, head_dim//2) → 最终匹配Q/K的(1, T, 1, head_dim) cos, sin = cos[None, :, None, :], sin[None, :, None, :] # 增加batch维和head维 return cos, sin
6.1.2.2 应用旋转编码到 Q/K(对应apply_rotary_emb函数)
def apply_rotary_emb(x, cos, sin): assert x.ndim == 4 # Q/K张量形状:(B, T, H, head_dim),B=批次,T=序列长,H=头数 d = x.shape[3] // 2 # 每个旋转组的通道数(即head_dim//2) # 拆分Q/K为两个通道组(对应公式中q1和q2的分离) x1, x2 = x[..., :d], x[..., d:] # x1: (B,T,H,d),x2: (B,T,H,d) # 应用旋转公式(对应q1'和q2'的计算) y1 = x1 * cos + x2 * sin # (B,T,H,d),对应q1' = q1*cos + q2*sin y2 = x1 * (-sin) + x2 * cos # (B,T,H,d),对应q2' = -q1*sin + q2*cos # 拼接回原维度(恢复head_dim) out = torch.cat([y1, y2], 3) # (B,T,H,head_dim) out = out.to(x.dtype) # 确保输出与输入 dtype 一致 return out

6.2 RMS 归一化(RMS Norm)

RMS 归一化是模型的归一化方案,无学习参数,通过对张量的 “均方根” 进行归一化,稳定训练过程。

6.2.1 数学原理与公式

RMS 归一化针对张量的最后一维(特征维度,如head_dimn_embd)计算,公式为:

\text{norm}(x) = \frac{x}{\sqrt{\text{mean}(x^2) + \epsilon}}

其中:

  • x:输入张量(如 Q/K 张量、隐藏状态张量);

  • mean(x^2):对 x 的最后一维计算元素平方的均值;

  • ε:数值稳定性参数(PyTorch 的F.rms_norm中默认1e-6,代码中未显式指定,使用默认值);

  • 核心特点:无γ(缩放因子)和β(偏移因子),完全基于输入数据的统计信息,减少参数冗余。

6.2.2 代码实现

对应norm函数,直接调用 PyTorch 原生函数实现公式:

def norm(x): # 纯函数式RMS归一化(对应公式:x / sqrt(mean(x²) + ε)) # x.ndim可灵活适配:如Q/K张量(4维)、隐藏状态张量(3维) # (x.size(-1),):指定对最后一维(特征维度)进行归一化,匹配公式中的“特征维度均值” return F.rms_norm(x, (x.size(-1),))

6.3 因果自注意力层(CausalSelfAttention)

因果自注意力是模型的核心组件,负责捕捉序列内的依赖关系,支持多查询注意力(MQA)和 KV 缓存,核心流程包括 “投影→旋转→MQA 适配→注意力计算→输出投影”。

6.3.1 数学原理与公式

6.3.1.1 投影(Q/K/V 生成)

通过线性层将输入隐藏状态x投影为 Q、K、V 张量,公式为:

Q = x \times W_q

K = x \times W_k

V = x \times W_v

其中:

  • W_q:查询投影权重矩阵,形状(n_embd, n_head × head_dim)

  • W_k:键投影权重矩阵,形状(n_embd, n_kv_head × head_dim)

  • W_v:值投影权重矩阵,形状(n_embd, n_kv_head × head_dim)

  • 投影后维度调整:将Q(B, T, n_head×head_dim)reshape 为(B, T, n_head, head_dim),K/V 同理(K/V 为(B, T, n_kv_head, head_dim))。

6.3.1.2 MQA 适配(KV Head 重复)

MQA 模式下,KV Head 数量n_kv_head少于 Q Head 数量n_head,需将 KV 张量重复r次(r = n_head // n_kv_head),公式为:

K_{\text{MQA}} = K \rightarrow \text{repeat}(1, 1, r, 1)

V_{\text{MQA}} = V \rightarrow \text{repeat}(1, 1, r, 1)

其中:

  • 重复维度为 “头数维度”(第 2 维,对应代码中的dim=1);

  • 目的:让每个 Q Head 都能对应到相同的 KV Head,减少计算量。

6.3.1.3 注意力分数与输出计算

注意力分数计算基于缩放点积注意力,同时加入因果掩码(确保当前 token 仅关注历史 token),公式为:

  1. 注意力分数

    \text{score} = \frac{Q \times K_{\text{MQA}}^T}{\sqrt{\text{head\_dim}}}

    $K_{\text{MQA}}^T$表示对 K 的 “序列长度维度” 转置,即(B, H, T, head_dim) → (B, H, head_dim, T)

  2. 因果掩码

    score矩阵应用下三角掩码(下三角及对角线为 1,上三角为 -∞),公式为:

    \text{score}_{\text{masked}} = \text{score} + M

    其中M为掩码矩阵,M[t1, t2] = 0(若t1 ≥ t2,即当前 tokent1可关注历史 tokent2),否则M[t1, t2] = -∞

  3. 注意力权重与输出

    \text{attn\_weight} = \text{softmax}(\text{score}_{\text{masked}})

    \text{attn\_out} = \text{attn\_weight} \times V_{\text{MQA}}

  4. 最终投影

    将多个头的输出拼接后,通过线性层投影回隐藏状态维度:

    \text{out} = \text{concat}(\text{attn\_out}) \times W_{\text{proj}}

    其中W_proj为输出投影权重矩阵,形状(n_head×head_dim, n_embd)

6.3.2 代码实现

对应CausalSelfAttention类的forward方法,按 “投影→旋转→MQA 适配→注意力计算→输出投影” 分步映射:

def forward(self, x, cos_sin, kv_cache): B, T, C = x.size() # B=批次,T=序列长,C=n_embd head_dim = self.head_dim # 单个头维度:n_embd // n_head # 1. 投影(对应公式Q=x×W_q,K=x×W_k,V=x×W_v) q = self.c_q(x).view(B, T, self.n_head, head_dim) # (B,T,n_head,head_dim):W_q对应self.c_q权重 k = self.c_k(x).view(B, T, self.n_kv_head, head_dim) # (B,T,n_kv_head,head_dim):W_k对应self.c_k权重 v = self.c_v(x).view(B, T, self.n_kv_head, head_dim) # (B,T,n_kv_head,head_dim):W_v对应self.c_v权重 # 2. 旋转编码与QK归一化(对应旋转公式+RMS公式) cos, sin = cos_sin q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) # 应用旋转编码 q, k = norm(q), norm(k) # QK归一化(对应RMS公式) # 调整维度:(B,T,H,D) → (B,H,T,D)(为注意力计算做准备,头维度提前) q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) # 3. KV缓存(补充历史KV,不影响数学公式,仅工程优化) if kv_cache is not None: k, v = kv_cache.insert_kv(self.layer_idx, k, v) Tq = q.size(2) # 当前查询序列长 Tk = k.size(2) # 历史+当前的键值序列长 # 4. MQA适配(对应KV重复公式:K_MQA = K.repeat(1,r,1,1)) nrep = self.n_head // self.n_kv_head # r = n_head / n_kv_head k, v = repeat_kv(k, nrep), repeat_kv(v, nrep) # K/V重复后形状:(B,n_head,Tk,head_dim) # 5. 注意力计算(对应score、mask、softmax、attn_out公式) if kv_cache is None or Tq == Tk: # 训练/全序列推理:因果掩码(is_causal=True),对应score_masked公式 y = F.scaled_dot_product_attention(q, k, v, is_causal=True) elif Tq == 1: # 单token推理:无掩码(仅关注所有历史KV) y = F.scaled_dot_product_attention(q, k, v, is_causal=False) else: # 多token推理:自定义因果掩码(前缀全关注+块内因果) attn_mask = torch.zeros((Tq, Tk), dtype=torch.bool, device=q.device) prefix_len = Tk - Tq attn_mask[:, :prefix_len] = True # 前缀(历史)全关注 attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), device=q.device)) # 块内因果 y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) # 应用自定义掩码 # 6. 输出投影(对应out=concat(attn_out)×W_proj公式) y = y.transpose(1, 2).contiguous().view(B, T, -1) # 拼接头输出:(B,H,T,D)→(B,T,n_head×D) y = self.c_proj(y) # 线性投影:W_proj对应self.c_proj权重,输出形状(B,T,n_embd) return y

6.4 MLP 层(含 ReLU²激活)

MLP 层负责对注意力层的输出进行非线性变换,增强模型的特征表达能力,核心是 “线性扩展→ReLU²激活→线性压缩” 三步。

6.4.1 数学原理与公式

MLP 层的维度变换为n_embd → 4×n_embd → n_embd,公式如下:

  1. 线性扩展:将输入隐藏状态映射到更高维度(4 倍n_embd):

    x_1 = x \times W_{\text{fc}}

    其中W_fc为扩展权重矩阵,形状(n_embd, 4×n_embd)

  2. ReLU²激活:先应用 ReLU 函数,再对结果平方,公式为:

    x_2 = \text{ReLU}(x_1)^2 = \max(x_1, 0)^2

    (注:ReLU² 替代传统 GELU,计算更简洁,同时保留非线性表达能力)

  3. 线性压缩:将高维特征映射回n_embd维度:

    \text{out} = x_2 \times W_{\text{proj}}

    其中W_proj为压缩权重矩阵,形状(4×n_embd, n_embd)

6.4.2 代码实现

对应MLP类的forward方法,直接映射上述三步公式:

def forward(self, x): # 1. 线性扩展(对应公式x1 = x×W_fc) x = self.c_fc(x) # W_fc对应self.c_fc权重,输入(B,T,n_embd)→输出(B,T,4×n_embd) # 2. ReLU²激活(对应公式x2 = max(x1,0)^2) x = F.relu(x).square() # F.relu(x)是max(x,0),.square()是平方操作 # 3. 线性压缩(对应公式out = x2×W_proj) x = self.c_proj(x) # W_proj对应self.c_proj权重,输入(B,T,4×n_embd)→输出(B,T,n_embd) return x

七、整体代码块详细讲解

接下来我们对于nanoChat的gpt.py代码进行分块详细解释【按代码逻辑结构展开】:

7.1 头部注释与依赖导入

""" GPT model (rewrite, a lot simpler) Notable features: - rotary embeddings (and no positional embeddings) # 旋转位置嵌入(无传统位置嵌入) - QK norm # Q和K的归一化 - untied weights for token embedding and lm_head # 词嵌入与输出头权重不共享 - relu^2 activation in MLP # MLP使用relu平方激活 - norm after token embedding # 词嵌入后做归一化 - no learnable params in rmsnorm # RMSNorm无学习参数 - no bias in linear layers # 线性层无偏置 - Multi-Query Attention (MQA) support for more efficient inference # 支持MQA提升推理效率 """ import math from functools import partial from dataclasses import dataclass import torch import torch.nn as nn import torch.nn.functional as F from nanochat.common import get_dist_info, print0 from nanochat.muon import Muon, DistMuon from nanochat.adamw import DistAdamW
  • 注释内容:概述模型核心特性,区别于传统GPT的设计(如旋转嵌入替代位置嵌入、MQA等)。
  • 依赖导入:包含PyTorch基础模块、数据类工具,以及自定义的分布式工具(get_dist_info)、优化器(MuonDistAdamW)等。

7.2 模型配置类(GPTConfig

@dataclass class GPTConfig: sequence_len: int = 1024 # 最大序列长度 vocab_size: int = 50304 # 词汇表大小 n_layer: int = 12 # Transformer层数 n_head: int = 6 # 查询头数量(Q头) n_kv_head: int = 6 # 键值头数量(K/V头,用于MQA/GQA) n_embd: int = 768 # 嵌入维度
  • 功能:用dataclass定义模型超参数,集中管理配置。
  • 关键参数n_headn_kv_head分离,支持Multi-Query Attention(MQA,n_kv_head=1)或Grouped-Query Attention(GQA,1 < n_kv_head < n_head),平衡性能与效率。

7.3 工具函数

7.3.1 RMS归一化(norm

def norm(x): # Purely functional rmsnorm with no learnable params return F.rms_norm(x, (x.size(-1),))
  • 功能:无学习参数的RMS归一化(区别于LayerNorm,无缩放/偏移参数),作用是稳定训练,归一化最后一个维度(特征维度)。

7.3.2 旋转位置嵌入(apply_rotary_emb

def apply_rotary_emb(x, cos, sin): assert x.ndim == 4 # multihead attention: (B, T, H, D) d = x.shape[3] // 2 # 旋转维度的一半 x1, x2 = x[..., :d], x[..., d:] # 分割最后一维为两部分 y1 = x1 * cos + x2 * sin # 旋转第一部分 y2 = x1 * (-sin) + x2 * cos # 旋转第二部分 out = torch.cat([y1, y2], 3) # 拼接回原维度 out = out.to(x.dtype) # 保证输出与输入 dtype 一致 return out
  • 功能:实现旋转位置嵌入(RoPE),为注意力添加相对位置信息。
  • 原理:将输入特征分割为两半,通过三角函数(cos/sin)对两部分进行旋转组合,使模型感知序列中token的相对位置。

7.4 因果自注意力模块(CausalSelfAttention

class CausalSelfAttention(nn.Module): def __init__(self, config, layer_idx): super().__init__() self.layer_idx = layer_idx # 层索引(用于KV缓存) self.n_head = config.n_head self.n_kv_head = config.n_kv_head self.n_embd = config.n_embd self.head_dim = self.n_embd // self.n_head # 每个头的维度 assert self.n_embd % self.n_head == 0 # 嵌入维度需被头数整除 assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0 # KV头需为Q头的约数(GQA/MQA) # 线性投影层(无偏置) self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False) # Q投影 self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False) # K投影 self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False) # V投影 self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False) # 注意力输出投影 def forward(self, x, cos_sin, kv_cache): B, T, C = x.size() # B: batch size, T: seq len, C: embedding dim # 投影得到Q、K、V q = self.c_q(x).view(B, T, self.n_head, self.head_dim) # (B, T, n_head, head_dim) k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim) # (B, T, n_kv_head, head_dim) v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim) # (B, T, n_kv_head, head_dim) # 应用旋转嵌入到Q和K cos, sin = cos_sin q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) q, k = norm(q), norm(k) # QK归一化(增强稳定性) # 转置维度:(B, T, H, D) → (B, H, T, D)(头维度提前,便于注意力计算) q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) # 处理KV缓存(推理时复用历史K/V,减少计算) if kv_cache is not None: k, v = kv_cache.insert_kv(self.layer_idx, k, v) # 将当前K/V插入缓存,获取完整K/V Tq = q.size(2) # 当前查询序列长度 Tk = k.size(2) # 总键值序列长度(缓存+当前) # 注意力计算(分场景处理因果掩码) enable_gqa = self.n_head != self.n_kv_head # 是否启用GQA(Q头与KV头数量不同) if kv_cache is None or Tq == Tk: # 训练模式(无缓存)或推理时序列长度匹配:直接用因果掩码 y = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa) elif Tq == 1: # 推理时单token生成:查询需关注所有历史键值(无因果掩码) y = F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa) else: # 推理时多token生成:前缀部分可全关注,当前块内用因果掩码 attn_mask = torch.zeros((Tq, Tk), dtype=torch.bool, device=q.device) prefix_len = Tk - Tq # 历史前缀长度 if prefix_len > 0: attn_mask[:, :prefix_len] = True # 允许关注前缀 attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool, device=q.device)) # 块内因果掩码 y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa) # 拼接头并投影回残差流 y = y.transpose(1, 2).contiguous().view(B, T, -1) # (B, H, T, D) → (B, T, H*D) = (B, T, C) y = self.c_proj(y) return y
  • 核心功能:实现因果自注意力(仅关注前文),支持MQA/GQA和KV缓存(推理优化)。
  • 关键细节
    • Q、K、V通过线性层投影,维度适配多头结构。
    • 旋转嵌入+QK归一化:增强位置感知和训练稳定性。
    • KV缓存:推理时复用历史键值,减少重复计算。
    • 分场景注意力掩码:训练用完整因果掩码,推理单token/多token时动态调整掩码。

7.5 多层感知器(MLP

class MLP(nn.Module): def __init__(self, config): super().__init__() self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False) # 升维 self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False) # 降维回残差流 def forward(self, x): x = self.c_fc(x) x = F.relu(x).square() # relu^2激活(比单纯relu更平滑) x = self.c_proj(x) return x
  • 功能:Transformer中的前馈网络,增强模型非线性表达能力。
  • 激活函数relu(x)^2(平方ReLU),相比标准ReLU更平滑,可能提升训练稳定性。

7.6 Transformer块(Block

class Block(nn.Module): def __init__(self, config, layer_idx): super().__init__() self.attn = CausalSelfAttention(config, layer_idx) # 注意力子模块 self.mlp = MLP(config) # MLP子模块 def forward(self, x, cos_sin, kv_cache): # 残差连接:输入 + 注意力输出(注意力前先归一化) x = x + self.attn(norm(x), cos_sin, kv_cache) # 残差连接:输入 + MLP输出(MLP前先归一化) x = x + self.mlp(norm(x)) return x
  • 功能:Transformer的基本单元,串联注意力和MLP,均使用残差连接+前置归一化(Norm before Attention/MLP)。

7.7 GPT主模型(GPT

7.7.1 初始化与参数初始化

class GPT(nn.Module): def __init__(self, config): super().__init__() self.config = config self.transformer = nn.ModuleDict({ "wte": nn.Embedding(config.vocab_size, config.n_embd), # 词嵌入表 "h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]), # Transformer块列表 }) self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) # 输出头(预测下一个token) # 预计算旋转嵌入(缓存足够长的序列,避免动态计算) self.rotary_seq_len = config.sequence_len * 10 # 超量预计算(10倍序列长度) head_dim = config.n_embd // config.n_head cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim) self.register_buffer("cos", cos, persistent=False) # 不保存到checkpoint self.register_buffer("sin", sin, persistent=False) def init_weights(self): self.apply(self._init_weights) # 特殊初始化:输出头和投影层置零 torch.nn.init.zeros_(self.lm_head.weight) for block in self.transformer.h: torch.nn.init.zeros_(block.mlp.c_proj.weight) torch.nn.init.zeros_(block.attn.c_proj.weight) # 重新初始化旋转嵌入(确保设备正确) head_dim = self.config.n_embd // self.config.n_head cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim) self.cos, self.sin = cos, sin # 词嵌入转bf16(节省内存) if self.transformer.wte.weight.device.type == "cuda": self.transformer.wte.to(dtype=torch.bfloat16) def _init_weights(self, module): if isinstance(module, nn.Linear): # 线性层初始化:根据输入维度调整标准差(参考论文https://arxiv.org/pdf/2310.17813) fan_out = module.weight.size(0) fan_in = module.weight.size(1) std = 1.0 / math.sqrt(fan_in) * min(1.0, math.sqrt(fan_out / fan_in)) torch.nn.init.normal_(module.weight, mean=0.0, std=std) elif isinstance(module, nn.Embedding): # 嵌入层初始化:正态分布 torch.nn.init.normal_(module.weight, mean=0.0, std=1.0)
  • 核心结构:包含词嵌入(wte)、Transformer块列表(h)、输出头(lm_head)。
  • 旋转嵌入预计算_precompute_rotary_embeddings提前计算cos/sin值,避免推理时重复计算,提升效率。
  • 参数初始化:线性层采用自适应标准差初始化,输出头和投影层置零(可能为了稳定训练初期)。

7.7.2 旋转嵌入预计算(_precompute_rotary_embeddings

def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None): if device is None: device = self.transformer.wte.weight.device # 自动获取设备 channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device) # 步长2取通道(旋转对) inv_freq = 1.0 / (base ** (channel_range / head_dim)) # 频率倒数(base=10000为默认) t = torch.arange(seq_len, dtype=torch.float32, device=device) # 时间步 freqs = torch.outer(t, inv_freq) # 外积计算 (seq_len, head_dim/2) cos, sin = freqs.cos(), freqs.sin() # 计算cos和sin cos, sin = cos.bfloat16(), sin.bfloat16() # 转bf16节省内存 # 增加batch和头维度,便于广播:(1, seq_len, 1, head_dim/2*2) cos, sin = cos[None, :, None, :], sin[None, :, None, :] return cos, sin
  • 功能:预计算不同时间步和通道的旋转嵌入参数(cos/sin),用于apply_rotary_emb函数。

7.7.3 前向传播(forward

def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'): B, T = idx.size() # idx: (B, T) 输入token序列 # 获取当前序列的旋转嵌入(考虑KV缓存的偏移) assert T <= self.cos.size(1), f"序列长度超过旋转嵌入缓存:{T} > {self.cos.size(1)}" T0 = 0 if kv_cache is None else kv_cache.get_pos() # 缓存偏移量(推理时从历史位置开始) cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # 截取对应长度的cos/sin # Transformer前向 x = self.transformer.wte(idx) # 词嵌入:(B, T) → (B, T, n_embd) x = norm(x) # 嵌入后归一化 for block in self.transformer.h: x = block(x, cos_sin, kv_cache) # 经过所有Transformer块 x = norm(x) # 最终归一化 # 计算输出(训练时返回损失,推理时返回logits) softcap = 15 # logits截断系数(稳定训练) if targets is not None: # 训练模式:计算交叉熵损失 logits = self.lm_head(x) # (B, T, vocab_size) logits = softcap * torch.tanh(logits / softcap) # 截断logits范围(避免过大值) logits = logits.float() # 用fp32计算损失(精度更高) loss = F.cross_entropy( logits.view(-1, logits.size(-1)), # 展平为(B*T, vocab_size) targets.view(-1), # 展平为(B*T,) ignore_index=-1, # 忽略-1标签 reduction=loss_reduction ) return loss else: # 推理模式:返回logits logits = self.lm_head(x) logits = softcap * torch.tanh(logits / softcap) return logits
  • 功能:实现模型前向传播,支持训练(计算损失)和推理(返回logits)。
  • 关键细节
    • 旋转嵌入动态截取:根据KV缓存偏移量获取当前序列对应的cos/sin。
    • logits截断(softcap):通过tanh限制logits范围,避免梯度爆炸。

7.7.4 优化器设置(setup_optimizers

def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0): model_dim = self.config.n_embd ddp, rank, local_rank, world_size = get_dist_info() # 获取分布式信息 # 参数分组:矩阵参数(Transformer块)、嵌入参数、输出头参数 matrix_params = list(self.transformer.h.parameters()) embedding_params = list(self.transformer.wte.parameters()) lm_head_params = list(self.lm_head.parameters()) # 学习率按模型维度缩放(∝1/√dmodel,适配不同嵌入维度) dmodel_lr_scale = (model_dim / 768) ** -0.5 if rank == 0: print(f"学习率缩放系数:∝1/√({model_dim}/768) = {dmodel_lr_scale:.6f}") # AdamW优化器(用于嵌入和输出头) adam_groups = [ dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale), dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale), ] adamw_kwargs = dict(betas=(0.8, 0.95), eps=1e-10, weight_decay=weight_decay) AdamWFactory = DistAdamW if ddp else partial(torch.optim.AdamW, fused=True) adamw_optimizer = AdamWFactory(adam_groups, **adamw_kwargs) # Muon优化器(用于Transformer块的线性层,可能更适合大模型) muon_kwargs = dict(lr=matrix_lr, momentum=0.95) MuonFactory = DistMuon if ddp else Muon muon_optimizer = MuonFactory(matrix_params,** muon_kwargs) # 保存初始学习率 for opt in [adamw_optimizer, muon_optimizer]: for group in opt.param_groups: group["initial_lr"] = group["lr"] return [adamw_optimizer, muon_optimizer]
  • 功能:为不同参数组设置优化器(AdamW用于嵌入和输出头,Muon用于Transformer块),并按模型维度动态调整学习率。
  • 设计思路:不同参数(嵌入、输出头、Transformer矩阵)的优化特性不同,分组优化可提升训练效果。

7.7.5 生成函数(generate

@torch.inference_mode() def generate(self, tokens, max_tokens, temperature=1.0, top_k=None, seed=42): """自回归生成:从输入tokens开始,生成最多max_tokens个token""" assert isinstance(tokens, list) device = self.get_device() rng = None if temperature > 0: rng = torch.Generator(device=device) rng.manual_seed(seed) # 随机数生成器(固定种子确保可复现) ids = torch.tensor([tokens], dtype=torch.long, device=device) # 转为tensor并添加batch维度 for _ in range(max_tokens): logits = self.forward(ids) # (B, T, vocab_size) logits = logits[:, -1, :] # 取最后一个token的logits:(B, vocab_size) # Top-K过滤 if top_k is not None: v, _ = torch.topk(logits, min(top_k, logits.size(-1))) logits[logits < v[:, [-1]]] = -float('Inf') # 过滤掉top_k外的token # 采样或贪心选择 if temperature > 0: logits = logits / temperature # 温度调整(温度越高随机性越强) probs = F.softmax(logits, dim=-1) next_ids = torch.multinomial(probs, num_samples=1, generator=rng) # 多项式采样 else: next_ids = torch.argmax(logits, dim=-1, keepdim=True) # 贪心选择 ids = torch.cat((ids, next_ids), dim=1) # 拼接新token yield next_ids.item() # 返回生成的token
  • 功能:实现自回归文本生成,支持温度采样和Top-K过滤。
  • 关键细节
    • 推理模式(@torch.inference_mode()):禁用梯度计算,提升效率。
    • 动态扩展输入序列:每次生成一个新token并拼接到输入中,循环直至达到max_tokens

7.8 'gpt.py'代码总结

该代码实现了一个简化但高效的GPT模型,核心特点包括:

  • 旋转位置嵌入替代传统位置嵌入,增强相对位置感知。
  • 支持MQA/GQA,平衡注意力性能与计算效率。
  • 无偏置线性层、无参数RMSNorm,简化模型并提升稳定性。
  • KV缓存优化推理效率,分组优化器适配不同参数特性。
  • 提供完整的训练(损失计算)和推理(文本生成)接口。

八、nanoGPT模型架构重要知识详解与延展

代码实现的nanoGPT模型,本质是优化后的Transformer解码器堆叠结构,我们围绕“效率、稳定性、可扩展性”三大核心目标,结合现代大模型训练的关键技术(如RoPE、MQA、RMSNorm),讲解其中的一些关键知识点。

8.1 基础架构:Transformer解码器的“简化与优化”

GPT类模型的核心是Transformer解码器的重复堆叠,标准Transformer解码器包含“多头因果自注意力”“前馈网络(MLP)”“残差连接”“归一化”四大组件。代码中的Block类正是这一结构的实现,但做了关键优化,先明确基础框架再拆解细节:

8.1.1 标准Transformer解码器回顾

标准Transformer解码器(如GPT-2)的单块结构为:

输入 → LayerNorm → 多头因果自注意力 → 残差连接 → LayerNorm → MLP → 残差连接 → 输出

Pre-Norm(归一化在前) 设计,区别于早期Transformer的Post-Norm(归一化在后)。

8.1.2 代码中的Block优化:极致简化与效率

代码中的Block完全遵循Pre-Norm,但做了两处关键简化:

class Block(nn.Module): def forward(self, x, cos_sin, kv_cache): x = x + self.attn(norm(x), cos_sin, kv_cache) # Norm→Attention→残差 x = x + self.mlp(norm(x)) # Norm→MLP→残差 return x
8.1.3 核心优化点:
  • 无偏置线性层:所有Linear层(如c_qc_fc)均设bias=False
    原理:Transformer中残差连接和归一化已能稳定训练,偏置项对性能提升有限,却会增加1/3的参数量(如nn.Linear(768,768)无偏置时参数768²=589824,有偏置时多768个)。
    延展:现代大模型(如LLaMA、Falcon)均移除线性层偏置,成为行业默认实践。

  • 无参数RMSNorm:用F.rms_norm替代带可学习参数(γ、β)的LayerNorm。
    原理:RMSNorm计算为 x / sqrt(mean(x²) + ε),仅依赖输入统计信息,无额外参数;LayerNorm则是 γ*(x-mean)/std + β,多2个可学习参数/层。
    延展:RMSNorm比LayerNorm快30%+(减少参数读取和计算),且在大模型中性能与LayerNorm持平,已成为高效模型的首选(如LLaMA 2、Mistral)。

8.2 位置编码:旋转嵌入(RoPE)的“相对位置革命”

传统GPT用可学习位置嵌入(如GPT-2)或正弦位置嵌入(如原始Transformer),代码则用旋转位置嵌入(Rotary Position Embedding, RoPE),彻底改变了位置信息的注入方式。

8.2.1 RoPE的核心原理:基于复数旋转的相对位置编码

RoPE的本质是通过三角函数旋转,让注意力分数仅依赖token的相对位置,而非绝对位置。

8.2.1.1 数学推导(对应代码apply_rotary_emb):
  1. 维度拆分:将Q/K的头维度head_dim拆分为两半(d = head_dim//2),每对维度(2i, 2i+1)视为复数的“实部+虚部”;
  2. 旋转计算:对第t个token的Q向量q_t,注入位置信息等价于乘以复数旋转因子 e^(iθ_t)θ_t为位置t的角度):
    q_t' = q_t × e^(iθ_t) = [q_t1*cosθ_t - q_t2*sinθ_t, q_t1*sinθ_t + q_t2*cosθ_t]
    代码中对应:
    y1 = x1 * cos + x2 * sin # 实部 y2 = x1 * (-sin) + x2 * cos # 虚部
  3. 相对位置体现:当计算token ij的注意力分数时,旋转因子的差值为 e^(i(θ_i - θ_j)) = e^(iθ_{i-j}),仅依赖相对位置i-j,完美捕捉相对依赖。
8.2.1.2 代码中的RoPE优化:预计算缓存

代码在__init__中预计算10倍序列长度的cos/sin矩阵(self.rotary_seq_len = config.sequence_len * 10),避免推理时实时计算:

cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim) self.register_buffer("cos", cos, persistent=False)
  • 优势:推理时仅需截取对应长度的cos/sin,减少30%+的计算耗时;
  • 延展:工业界常用“动态扩展缓存”(如当序列超过预计算长度时,动态补算后续cos/sin),代码中用“10倍超量预计算”简化实现,适合中小规模场景。

8.2.2 RoPE vs 其他位置编码:为何成为主流?

位置编码类型 优点 缺点 代表模型
可学习位置嵌入 适配任务能力强 长度外推差(超训练长度失效) GPT-2、GPT-3
正弦位置嵌入 长度外推好 相对位置捕捉弱 原始Transformer
RoPE 相对位置精准+外推好 需拆分维度(head_dim需为偶数) LLaMA、Qwen、Mistral
  • 延展:RoPE的“长度外推性”是关键优势——训练时用1024序列,推理时可处理2048甚至4096序列,而可学习位置嵌入超过训练长度后性能骤降(需用“位置插值”等技巧补救)。

8.3 注意力机制:MQA/GQA——平衡精度与推理效率

标准Transformer用多头注意力(MHA),每个Q头对应独立的K/V头,参数和计算量巨大。代码通过n_headn_kv_head分离,支持多查询注意力(MQA)分组查询注意力(GQA),大幅优化推理效率。

8.3.1 三种注意力机制的核心差异

假设模型n_head=12head_dim=64n_embd=768):

注意力类型 KV头数量(n_kv_head) KV参数量(每层) 推理速度 精度
MHA(标准) 12(与Q头相等) 12×64=768
MQA(多查询) 1(所有Q头共享) 1×64=64
GQA(分组查询) 4(3个Q头共享1个KV头) 4×64=256

8.3.2 代码中的实现逻辑(对应CausalSelfAttention):

  1. 参数判断:通过self.n_head != self.n_kv_head启用GQA(enable_gqa=True);
  2. 底层调用:依赖PyTorch 2.0+的F.scaled_dot_product_attention,原生支持MQA/GQA(通过enable_gqa参数),避免手动重复KV头(早期需用repeat_kv函数,代码中已简化);
  3. 因果掩码:分场景处理掩码(训练用全因果掩码,推理单token无掩码,推理多token用“前缀+因果”混合掩码),确保自回归特性。

8.3.3 为何MQA/GQA成为推理优化的核心?

  • 推理效率瓶颈:自回归生成时,注意力计算的耗时主要来自KV的存储和读取——MHA需存储12组KV,MQA仅需1组,内存占用降低12倍,带宽消耗减少,推理吞吐量提升3-5倍;
  • 精度权衡:MQA比MHA精度略降(约1-2%),但GQA通过分组(如4组KV头)可恢复99%以上的MHA精度,同时保持2-3倍的推理速度提升;
  • 延展:GPT-4、Mistral 7B用GQA,Claude用MQA,工业界已形成“大模型用GQA,轻量模型用MQA”的共识。

8.4 前馈网络(MLP):ReLU²——简化非线性激活

标准Transformer的MLP用GELU激活(如GPT-3、BERT),代码则用ReLU²激活F.relu(x).square()),在简化计算的同时保留非线性表达能力。

8.4.1 激活函数对比:效率与性能的权衡

激活函数 计算复杂度 梯度特性 代表模型
GELU 高(需erf) 平滑(可导性好) GPT-3、LLaMA 1
ReLU 0点不可导(死神经元) 早期CNN
ReLU² 低(ReLU+平方) 0点可导(更平滑) 代码模型、部分轻量LLM

8.4.2 ReLU²的优势(对应代码MLP):

  1. 计算快:ReLU和平方均为基础运算,无需调用复杂的erf函数(GELU的核心),GPU上可提速20%+;
  2. 梯度更平滑:ReLU在x=0处导数为0(不可导),ReLU²在x=0处导数为0(可导),避免“死神经元”问题;
  3. 表达能力足够:对中小模型(如代码中n_embd=768,n_layer=12),ReLU²的非线性能力足以匹配GELU,且过拟合风险更低。

延展:

  • 现代大模型(如LLaMA 2、Qwen)仍用GELU,因大模型需要更强的非线性表达;轻量模型(如n_embd<512)更适合ReLU²或Swish(简化版),平衡效率与性能。

8.5 其他关键设计:权重解耦、参数初始化与KV缓存

8.5.1 权重解耦(Untied Weights):词嵌入与输出头分离

传统GPT(如GPT-2)共享词嵌入层(wte)和输出头(lm_head)的权重(lm_head.weight = wte.weight.T),代码则完全解耦:

self.transformer = nn.ModuleDict({"wte": nn.Embedding(...)}) self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

解耦的原因:

  • 优化目标不同:词嵌入层的目标是“将token映射到语义空间”(优化语义相似度),输出头的目标是“将语义空间映射到token概率”(优化分类损失),共享权重会限制两者的独立优化;
  • 精度提升:在大词汇表(如代码中vocab_size=50304)场景下,解耦可让lm_head针对性优化高频token的预测精度,实验表明可提升1-3%的困惑度(Perplexity)。

延展【权重贡献优势】:

  • 权重共享的唯一优势是减少参数量(如vocab_size=50304,n_embd=768,共享可减少50304×768≈3800万参数),但现代硬件已能承载这部分参数,解耦成为主流。

8.5.2 参数初始化:自适应标准差,避免梯度爆炸

代码中线性层的初始化参考论文《A Spectral Condition for Feature Learning》【2310.17813】,采用自适应标准差:

std = 1.0 / math.sqrt(fan_in) * min(1.0, math.sqrt(fan_out / fan_in)) torch.nn.init.normal_(module.weight, mean=0.0, std=std)

原理:

  • 传统初始化(Xavier/He)未考虑“输入输出维度比”:
    • fan_out > fan_in(如MLP的升维层:c_fc从768→3072,fan_out=3072,fan_in=768),sqrt(fan_out/fan_in)=2,标准差放大2倍,避免信号衰减;
    • fan_out ≤ fan_in(如降维层:c_proj从3072→768),min(1, ...)取1,标准差为1/√fan_in,避免信号放大导致梯度爆炸。

延展:

  • Transformer中存在大量“升维-降维”操作(MLP的768→3072→768),自适应初始化比传统方法更能稳定深层训练(如100层以上的模型)。

8.5.3 KV缓存:推理效率的“关键工程优化”

代码在推理时用kv_cache存储历史的K和V,避免每次生成新token时重新计算所有历史的QK注意力:

if kv_cache is not None: k, v = kv_cache.insert_kv(self.layer_idx, k, v) # 插入当前KV到缓存

原理(以生成100个token为例):

  • 无KV缓存:每次生成第t个token,需计算前t个token的K和V,总计算量为1+2+...+100=5050次;
  • 有KV缓存:仅需计算第t个token的K和V,历史KV复用,总计算量为100次,效率提升50倍。

延展:KV缓存的内存挑战与优化

  • 内存占用:每个Transformer层需存储B×n_kv_head×T×head_dim的K和V(B=批次,T=生成长度),生成1024个token时,单批次768维模型的KV缓存约为1×6×1024×64≈4MB/层,12层约48MB,可控;
  • 优化方向:长序列生成(如4096+token)时,用“滑动窗口KV缓存”(仅保留最近的512个token的KV)或“KV量化”(将FP16的KV转为INT8/INT4),减少内存占用。

九、架构设计总结与行业延展

9.1 架构设计总结

代码中的nanoGPT,是现代高效LLM架构的“缩影”,其设计思路与行业主流模型(如LLaMA、Mistral)高度一致,核心可总结为:

  1. 简化组件:无偏置、无参数归一化、ReLU²激活,降低复杂度;
  2. 高效注意力:MQA/GQA+KV缓存,平衡精度与推理速度;
  3. 强位置感知:RoPE替代传统位置嵌入,提升外推性;
  4. 灵活优化:权重解耦、自适应初始化,适配不同场景。

9.2 行业延展:从“简化版”到“工业级LLM”的进化方向

  1. 模型并行:代码为单卡模型,工业级LLM(如GPT-4)用张量并行(TP)拆分注意力头,数据并行(DP)拆分批次,流水线并行(PP)拆分Transformer层;
  2. 量化训练/推理:代码用BF16训练,工业界用FP8量化(如NVIDIA的FP8 Transformer Engine),进一步降低内存和计算量;
  3. 注意力改进:加入“FlashAttention”(优化注意力计算的内存访问)或“多头注意力稀疏化”(如仅计算Top-50%的注意力分数),提升效率;
  4. 训练框架优化:代码用PyTorch原生函数,工业界用Megatron-LM、DeepSpeed等框架,支持万亿参数模型的分布式训练。

我们可以通过理解nanoChat该简化模型的架构,可快速掌握现代LLM的核心设计思想,为后续学习复杂模型(如Mistral、Qwen)打下基础。

补充内容:大语言模型架构训练核心原理全解析

大语言模型(LLM)的架构与训练本质是“基于Transformer架构的深层神经网络优化过程”,涉及模型结构设计、注意力机制、归一化与残差连接、优化器、学习率调度、工程优化六大核心模块。本部分补充内容将在前文的基础上,进一步将逐一拆解每个模块的算法原理,从“动机→数学推导→直观解释→实现细节”层层递进,让大家能更好理解其来龙去脉。

一、Transformer架构:从“序列建模痛点”到“全局依赖捕捉”

1.1 为什么需要Transformer?

传统序列建模(RNN/LSTM/GRU)存在两大痛点:

  1. 串行计算:需逐token处理,无法并行,训练效率低;
  2. 长距离依赖衰减:通过“门控机制”缓解但无法根治,长序列中梯度易消失。

Transformer的核心创新是用“自注意力机制”替代RNN的串行依赖建模,实现:

  • 并行计算(所有token同时处理);
  • 全局依赖捕捉(任意两个token的依赖可直接建模);
  • 可扩展性(堆叠深层网络,适配大模型)。

注:所以从某种角度来说,Transformer并不一定是最优的架构,只是我们现在找到的从效率角度来说“最好的架构”,包括为什么在大模型中舍弃编码器(Encoder),只保留解码器堆叠,也是从同参数量下的效果经济性角度考虑的。

1.2 Transformer整体架构(数学定义)

Transformer由编码器(Encoder)解码器(Decoder) 组成,LLM(如GPT系列)仅使用解码器堆叠(自回归生成),其结构定义为:

\text{TransformerDecoder}(X) = \text{Stack}\left[\text{DecoderBlock}_1, \text{DecoderBlock}_2, ..., \text{DecoderBlock}_L\right](X)

其中:

  • X \in \mathbb{R}^{B \times T \times D}:输入张量(B=批次大小,T=序列长度,D=嵌入维度);
  • L:解码器块数量(大模型L可达100+,如GPT-3 L=96);
  • 每个\text{DecoderBlock}包含:因果自注意力层前馈网络(MLP)层归一化残差连接四大组件。

1.3 解码器块数学推导(Pre-Norm设计)

现代LLM均采用Pre-Norm(归一化在前) 设计(区别于原始Transformer的Post-Norm),单个解码器块的计算流程为:

\begin{align*} X_1 &= \text{LayerNorm}(X) \quad \text{// 归一化} \\ X_{\text{attn}} &= \text{CausalSelfAttention}(X_1) \quad \text{// 因果自注意力} \\ X_2 &= X + X_{\text{attn}} \quad \text{// 残差连接} \\ X_3 &= \text{LayerNorm}(X_2) \quad \text{// 归一化} \\ X_{\text{mlp}} &= \text{MLP}(X_3) \quad \text{// 前馈网络} \\ X_{\text{out}} &= X_2 + X_{\text{mlp}} \quad \text{// 残差连接} \end{align*}

关键设计的动机:

  • Pre-Norm:归一化提前,避免深层网络中梯度爆炸(Post-Norm易在深层出现数值不稳定);
  • 残差连接:X + X_{\text{attn}} 直接传递梯度,缓解深层网络梯度消失。

1.4 实现细节(对应LLM代码)

class DecoderBlock(nn.Module): def __init__(self, d_model, n_head, d_ff, dropout=0.1): super().__init__() self.norm1 = LayerNorm(d_model) # Pre-Norm self.attn = CausalSelfAttention(d_model, n_head) self.norm2 = LayerNorm(d_model) self.mlp = MLP(d_model, d_ff) self.dropout = nn.Dropout(dropout) def forward(self, x, kv_cache=None): # 自注意力支路 x1 = self.norm1(x) attn_out = self.attn(x1, kv_cache) x = x + self.dropout(attn_out) # 残差+dropout # MLP支路 x2 = self.norm2(x) mlp_out = self.mlp(x2) x = x + self.dropout(mlp_out) # 残差+dropout return x

二、自注意力机制(Self-Attention):全局依赖的数学建模

自注意力是Transformer的核心,其本质是“通过token间的相似度计算权重,对全局信息加权求和”,实现全局依赖捕捉。

2.1 核心动机

对序列中每个token x_i,希望生成一个新的表示 y_i,该表示融合自身信息序列中其他token的相关信息

y_i = \sum_{j=1}^T \alpha_{i,j} \cdot x_j

其中 \alpha_{i,j} 是token i 对token j 的注意力权重(\sum_j \alpha_{i,j}=1)。

2.2 Scaled Dot-Product Attention(缩放点积注意力)

步骤1:QKV投影(将输入映射为查询、键、值)

为了计算注意力权重,需将输入 X 投影为三个向量:

  • 查询(Query)Q:当前token“要找什么信息”;
  • 键(Key)K:其他token“提供什么信息”;
  • 值(Value)V:其他token“具体的信息内容”。

投影公式:

Q = X \cdot W_Q, \quad K = X \cdot W_K, \quad V = X \cdot W_V

其中:

  • W_Q, W_K, W_V \in \mathbb{R}^{D \times D_k}:可学习投影矩阵(D_k 为QKV的维度,LLM中 D_k = D/n_{\text{head}},即每个注意力头的维度);
  • Q, K, V \in \mathbb{R}^{B \times T \times D_k}

步骤2:注意力分数计算(相似度度量)

用“点积”衡量 Q_i(第i个token的查询)与 K_j(第j个token的键)的相似度:

\text{score}(i,j) = Q_i \cdot K_j^T
  • 点积的优势:计算高效(可通过矩阵乘法并行);
  • 问题:当 D_k 较大时,点积结果会过大,导致softmax后梯度消失(softmax在大数值下梯度趋近于0)。

步骤3:缩放(缓解梯度消失)

对分数除以 \sqrt{D_k},使分数的方差归一化到1附近:

\text{scaled\_score}(i,j) = \frac{Q_i \cdot K_j^T}{\sqrt{D_k}}
  • 数学证明:假设 Q_iK_j 是独立同分布的随机变量(均值0,方差1),则 Q_i \cdot K_j^T 的均值为0,方差为 D_k;除以 \sqrt{D_k} 后方差为1,避免数值溢出。

步骤4:因果掩码(Causal Mask,LLM核心)

LLM是自回归模型,生成第i个token时,只能依赖前i-1个token,需屏蔽后验信息(j > i 的token):

\text{masked\_score}(i,j) = \begin{cases} \text{scaled\_score}(i,j) & j \leq i \\ -\infty & j > i \end{cases}
  • 掩码矩阵 M \in \mathbb{R}^{T \times T}:下三角为0,上三角为 -\infty,与score相加实现屏蔽。

步骤5:Softmax归一化(得到注意力权重)

将掩码后的分数归一化为概率分布(权重和为1):

\alpha_{i,j} = \text{softmax}\left(\text{masked\_score}(i,j)\right) = \frac{e^{\text{masked\_score}(i,j)}}{\sum_{k=1}^T e^{\text{masked\_score}(i,k)}}

步骤6:加权求和(得到注意力输出)

用注意力权重对 V 加权求和,得到每个token的最终表示:

\text{Attention}(Q,K,V) = \alpha \cdot V

其中 \alpha \in \mathbb{R}^{B \times T \times T} 是注意力权重矩阵。

完整公式(矩阵形式,并行计算)

\text{ScaledDotProductAttention}(Q,K,V,M) = \text{softmax}\left( \frac{QK^T}{\sqrt{D_k}} + M \right) V

2.3 多头注意力(Multi-Head Attention)

动机:单一注意力头只能捕捉一种类型的依赖(如语法依赖),多头可并行捕捉多种依赖(语法、语义、逻辑等)。

步骤1:多头拆分与投影

W_Q, W_K, W_V 拆分为 n_{\text{head}} 个独立的投影矩阵:

W_Q = \left[W_{Q1}, W_{Q2}, ..., W_{Qn_{\text{head}}}\right]^T, \quad W_K, W_V \text{ 同理}

每个头的QKV投影:

Q_h = X \cdot W_{Qh}, \quad K_h = X \cdot W_{Kh}, \quad V_h = X \cdot W_{Vh} \quad (h=1..n_{\text{head}})

其中 W_{Qh} \in \mathbb{R}^{D \times D_k}D_k = D/n_{\text{head}}(确保总维度不变:n_{\text{head}} \times D_k = D)。

步骤2:多头并行计算

每个头独立计算缩放点积注意力:

\text{Head}_h = \text{ScaledDotProductAttention}(Q_h, K_h, V_h, M)

步骤3:多头拼接与输出投影

将所有头的输出拼接,再通过一个线性层投影回 D 维度:

\text{MultiHead}(Q,K,V,M) = \left[\text{Head}_1, \text{Head}_2, ..., \text{Head}_{n_{\text{head}}}\right] \cdot W_O

其中 W_O \in \mathbb{R}^{D \times D} 是输出投影矩阵。

数学优势:

多头注意力的输出可以表示为:

\text{MultiHead}(X) = \sum_{h=1}^{n_{\text{head}}} \text{Attention}(X W_{Qh}, X W_{Kh}, X W_{Vh}) W_{Oh}

本质是“多个不同的注意力机制的集成”,提升模型的表达能力。

2.4 实现细节(PyTorch)

class CausalSelfAttention(nn.Module): def __init__(self, d_model, n_head): super().__init__() self.n_head = n_head self.d_k = d_model // n_head # 每个头的维度 # QKV投影层(无偏置) self.w_q = nn.Linear(d_model, d_model, bias=False) self.w_k = nn.Linear(d_model, d_model, bias=False) self.w_v = nn.Linear(d_model, d_model, bias=False) # 输出投影层 self.w_o = nn.Linear(d_model, d_model, bias=False) def forward(self, x): B, T, D = x.shape # 1. QKV投影:(B,T,D) → (B,T,D) q = self.w_q(x) k = self.w_k(x) v = self.w_v(x) # 2. 多头拆分:(B,T,D) → (B,n_head,T,d_k) q = q.view(B, T, self.n_head, self.d_k).transpose(1, 2) k = k.view(B, T, self.n_head, self.d_k).transpose(1, 2) v = v.view(B, T, self.n_head, self.d_k).transpose(1, 2) # 3. 因果掩码:(T,T),上三角为True(屏蔽) mask = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool() # 4. 缩放点积注意力(PyTorch原生实现,高效) attn_out = F.scaled_dot_product_attention( q, k, v, attn_mask=mask, is_causal=True ) # 5. 多头拼接:(B,n_head,T,d_k) → (B,T,D) attn_out = attn_out.transpose(1, 2).contiguous().view(B, T, D) # 6. 输出投影 out = self.w_o(attn_out) return out

三、位置编码(Positional Encoding):给序列“注入顺序信息”

3.1 核心问题

自注意力机制是无序的QK^T 计算不依赖token的位置),但语言序列的顺序至关重要(如“我吃苹果”≠“苹果吃我”),因此需要手动注入位置信息,这就是为什么设计位置编码。

3.2 设计目标

位置编码 PE \in \mathbb{R}^{T \times D} 需满足:

  1. 对每个位置 tPE(t) 是唯一的(区分不同位置);
  2. 位置差的编码是固定的(PE(t+k) - PE(t)t 无关,捕捉相对位置);
  3. 可扩展到训练时未见过的长序列(长度外推性)。

3.3 正弦余弦位置编码(原始Transformer)

数学公式

对每个位置 t(0-based)和维度 i(0-based),位置编码为:

PE(t, 2i) = \sin\left( \frac{t}{10000^{2i/D}} \right)
PE(t, 2i+1) = \cos\left( \frac{t}{10000^{2i/D}} \right)
  • 偶数维度用正弦,奇数维度用余弦;
  • 10000 是超参数,控制位置编码的周期(维度越大,周期越长)。

为什么能捕捉相对位置?

假设两个位置 tt+k,其编码满足三角函数的和角公式:

\sin(t+k) = \sin t \cos k + \cos t \sin k
\cos(t+k) = \cos t \cos k - \sin t \sin k

PE(t+k) 可表示为 PE(t) 的线性组合,模型可通过学习权重捕捉相对位置 k

缺点

  • 长序列外推性差(训练时用 T=1024,推理时 T=2048 会出现未见过的编码),为了解决或者说缓解此类问题,ALiBi、YaRN等方法被提出。

3.4 旋转位置编码(RoPE,LLM主流)

RoPE是当前LLM的首选位置编码(如GPT-4、LLaMA、Qwen),核心是“通过旋转QK向量注入位置信息”,完美满足相对位置捕捉和长序列外推。

数学推导

  1. 复数表示:将Q/K的每个头维度 D_k 拆分为 D_k/2 对(实部+虚部),表示为复数:
    Q(t) = [Q(t,0) + iQ(t,1), Q(t,2) + iQ(t,3), ..., Q(t,D_k-2) + iQ(t,D_k-1)]
  2. 旋转注入:对复数Q/K乘以旋转因子 e^{-i\theta_t}\theta_t = t / 10000^{2i/D_k}),注入位置信息:
    Q_{\text{rot}}(t) = Q(t) \cdot e^{-i\theta_t}
    K_{\text{rot}}(t) = K(t) \cdot e^{-i\theta_t}
  3. 注意力分数的相对位置体现
    Q_{\text{rot}}(t_i) \cdot K_{\text{rot}}(t_j)^* = Q(t_i) \cdot K(t_j)^* \cdot e^{-i(\theta_{t_i} - \theta_{t_j})} = Q(t_i) \cdot K(t_j)^* \cdot e^{-i\theta_{t_i - t_j}}
    其中 * 表示共轭复数,注意力分数仅依赖相对位置 t_i - t_j

实数域实现(对应代码apply_rotary_emb

旋转因子 e^{-i\theta_t} = \cos\theta_t - i\sin\theta_t,对Q的实部和虚部进行旋转:

Q_{\text{rot}}(t, 2i) = Q(t,2i) \cos\theta_t - Q(t,2i+1) \sin\theta_t
Q_{\text{rot}}(t, 2i+1) = Q(t,2i) \sin\theta_t + Q(t,2i+1) \cos\theta_t

这正是代码中的实现逻辑:

def apply_rotary_emb(x, cos, sin): d = x.shape[3] // 2 x1, x2 = x[..., :d], x[..., d:] # 实部、虚部 y1 = x1 * cos + x2 * sin # 旋转后实部 y2 = -x1 * sin + x2 * cos # 旋转后虚部 return torch.cat([y1, y2], dim=-1)

优势

  • 完美捕捉相对位置;
  • 长序列外推性强(旋转因子仅依赖相对位置,与绝对长度无关);
  • 计算高效(可预计算cos/sin矩阵)。

3.5 实现细节

def precompute_rope(d_model, n_head, max_seq_len=2048): d_k = d_model // n_head # 计算频率:theta_i = 10000^(-2i/d_k) theta = 1.0 / (10000 ** (torch.arange(0, d_k, 2) / d_k)) # 计算位置t的theta_t = t * theta_i t = torch.arange(max_seq_len) freqs = torch.outer(t, theta) # (max_seq_len, d_k//2) # 生成cos和sin矩阵(扩展维度适配QK) cos = freqs.cos().unsqueeze(0).unsqueeze(1) # (1,1,max_seq_len,d_k//2) sin = freqs.sin().unsqueeze(0).unsqueeze(1) # 重复维度,匹配d_k(每个频率对应实部+虚部) cos = cos.repeat(1, 1, 1, 2) # (1,1,max_seq_len,d_k) sin = sin.repeat(1, 1, 1, 2) return cos, sin

四、层归一化(Layer Norm)与残差连接(Residual Connection):稳定深层训练的核心

4.1 层归一化(Layer Norm):缓解内部协变量偏移

核心问题:内部协变量偏移(Internal Covariate Shift)

深层网络训练时,每一层的输入分布会随着前一层参数的更新而变化,导致:

  1. 需用较小的学习率(避免参数震荡);
  2. 梯度消失(激活函数在极端分布下梯度趋近于0)。

层归一化的目标是“将每一层的输入归一化到固定分布(均值0,方差1)”,缓解分布漂移。

数学公式(带可学习参数的Layer Norm)

对输入张量 X \in \mathbb{R}^{B \times T \times D},在特征维度(D) 上归一化:

\mu = \frac{1}{D} \sum_{i=1}^D X_{:, :, i} \quad \text{// 特征维度均值}
\sigma^2 = \frac{1}{D} \sum_{i=1}^D (X_{:, :, i} - \mu)^2 \quad \text{// 特征维度方差}
\hat{X} = \frac{X - \mu}{\sqrt{\sigma^2 + \epsilon}} \quad \text{// 归一化(ε避免除零)}
\text{LayerNorm}(X) = \gamma \cdot \hat{X} + \beta \quad \text{// 缩放+偏移(可学习参数)}

其中 \gamma \in \mathbb{R}^D(缩放因子)、\beta \in \mathbb{R}^D(偏移因子)是可学习参数,允许模型恢复原始分布的表达能力(若 \gamma=\sqrt{\sigma^2+\epsilon}\beta=\mu,则归一化失效)。

无参数RMS Norm(LLM简化版)

LLM中常省略 \gamma\beta,用更简单的RMS Norm(Root Mean Square Norm):

\text{RMSNorm}(X) = \frac{X}{\sqrt{\frac{1}{D} \sum_{i=1}^D X_{:, :, i}^2 + \epsilon}}
  • 优势:减少2D个参数,计算更快;
  • 原理:残差连接和注意力机制已能稳定训练,可学习参数的增益有限。

与Batch Norm的区别(为什么Transformer用Layer Norm?)

特性 Layer Norm Batch Norm
归一化维度 特征维度(D) 批次维度(B)
批次依赖性 无(单样本即可计算) 有(依赖批次内样本)
序列长度适应性 好(T可任意变化) 差(T固定时效果好)
大模型场景适用性 强(并行训练+长序列) 弱(批次波动+长序列失效)

4.2 残差连接(Residual Connection):缓解梯度消失【深度学习中的信息传输高速路】

核心动机

深层网络中,梯度通过链式法则传递时会被不断相乘,若梯度绝对值小于1,深层梯度会趋近于0(梯度消失)。残差连接通过“直接传递原始输入”,让梯度可直接回传。

数学公式

对网络层 F(如注意力层、MLP),残差连接定义为:

\text{Residual}(X) = X + F(X)
  • X 是原始输入,F(X) 是层的输出(需与 X 维度一致)。

梯度传递分析(链式法则)

假设损失为 L,对参数 \theta 的梯度为:

\frac{\partial L}{\partial \theta} = \frac{\partial L}{\partial \text{Residual}(X)} \cdot \left( \frac{\partial X}{\partial \theta} + \frac{\partial F(X)}{\partial \theta} \right)
  • 当深层网络中 \frac{\partial F(X)}{\partial \theta} 趋近于0时,\frac{\partial X}{\partial \theta} 仍能传递梯度,避免梯度消失;
  • 要求 F(X)X 维度一致(否则需用1×1卷积调整维度)。

实现细节(对应DecoderBlock)

# 残差连接+归一化(Pre-Norm) x = x + self.dropout(self.attn(self.norm1(x))) x = x + self.dropout(self.mlp(self.norm2(x)))

五、优化器:从SGD到AdamW,大模型的梯度更新策略

优化器的核心是“根据梯度调整参数,最小化损失函数”,LLM的损失函数为交叉熵损失(预测下一个token的概率):

L = -\frac{1}{BT} \sum_{b=1}^B \sum_{t=1}^T \log P(y_{b,t} | X_{b,1:t})

其中 y_{b,t} 是第b个样本第t个token的真实标签,P 是模型预测的概率分布。

5.1 随机梯度下降(SGD):优化器基础

数学原理

参数 \theta 的更新规则:

\theta_{t+1} = \theta_t - \eta \cdot \nabla L(\theta_t)
  • \eta 是学习率(步长),\nabla L(\theta_t) 是损失对 \theta_t 的梯度。

缺点

  • 学习率固定,无法适配不同参数的梯度特性(如稀疏梯度、密集梯度);
  • 梯度震荡(高维空间中梯度方向易变化);
  • 收敛慢(需手动调整学习率)。

5.2 Adam:自适应学习率优化器

Adam(Adaptive Moment Estimation)融合了“动量(Momentum)”和“自适应学习率(RMSProp)”,是LLM训练的基础优化器。

核心思想

  1. 一阶矩估计(动量):平滑梯度方向,减少震荡;
  2. 二阶矩估计:根据梯度的方差调整学习率(梯度方差大则学习率小,反之则大);
  3. 偏差修正:初期一阶矩和二阶矩估计有偏差,需修正。

数学公式(更新规则)

给定梯度 \nabla L_t = \nabla L(\theta_t),迭代更新:

  1. 一阶矩(动量):

    m_t = \beta_1 \cdot m_{t-1} + (1 - \beta_1) \cdot \nabla L_t
    • \beta_1 是动量系数(默认0.9),m_t 是梯度的指数移动平均(EMA)。
  2. 二阶矩(自适应学习率):

    v_t = \beta_2 \cdot v_{t-1} + (1 - \beta_2) \cdot (\nabla L_t)^2
    • \beta_2 是二阶矩系数(默认0.999),v_t 是梯度平方的EMA。
  3. 偏差修正(初期修正):

    \hat{m}_t = \frac{m_t}{1 - \beta_1^t}
    \hat{v}_t = \frac{v_t}{1 - \beta_2^t}
    • t 是迭代步数,初期 \beta_1^t\beta_2^t 接近1,修正后避免低估梯度。
  4. 参数更新:

    \theta_{t+1} = \theta_t - \eta \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}
    • \epsilon 是数值稳定性参数(默认1e-8),避免除零。

优势

  • 自适应学习率,无需手动调整;
  • 动量平滑梯度,收敛更快;
  • 适合稀疏梯度场景(LLM中大量参数梯度为0)。

5.3 AdamW:修复权重衰减的Adam变体

Adam的核心缺陷是“权重衰减(Weight Decay)”的实现错误,AdamW通过“将权重衰减与梯度更新分离”修复了这一问题,成为LLM的标准优化器。

权重衰减的作用

权重衰减本质是L2正则化,通过惩罚大权重,减少模型过拟合:

L_{\text{reg}} = L + \frac{\lambda}{2} \sum_{\theta} \theta^2

其中 \lambda 是权重衰减系数(默认0.01)。

Adam与AdamW的权重衰减区别

  • Adam的错误实现:将权重衰减融入梯度更新(等价于对梯度加惩罚):

    \theta_{t+1} = \theta_t - \eta \cdot \left( \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} + \lambda \cdot \theta_t \right)
    • 问题:自适应学习率会缩放权重衰减项,导致惩罚力度不稳定。
  • AdamW的正确实现:权重衰减与梯度更新分离(直接惩罚权重):

    \theta_{t+1} = (1 - \eta \cdot \lambda) \cdot \theta_t - \eta \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}
    • 优势:权重衰减力度不受自适应学习率影响,惩罚更稳定。

实现细节(PyTorch)

# AdamW优化器配置(LLM常用参数) optimizer = torch.optim.AdamW( model.parameters(), lr=2e-5, # 学习率 betas=(0.9, 0.95),# 一阶矩、二阶矩系数(LLM中β2常用0.95) weight_decay=0.01,# 权重衰减系数 eps=1e-8 # 数值稳定性参数 )

六、学习率调度策略:动态调整步长,优化收敛

学习率是优化器的核心超参数:

  • 学习率过大:参数震荡,无法收敛;
  • 学习率过小:收敛过慢,陷入局部最优。

LLM常用“Warmup + Cosine Annealing”组合调度策略,平衡“初期稳定”和“后期精细化调整”。

6.1 Warmup(热身调度)

核心动机

训练初期,模型参数随机初始化,梯度方差大,若用较大学习率会导致参数震荡。Warmup阶段逐步增大学习率,让模型适应梯度分布。

数学公式(线性Warmup)

设Warmup步数为 T_w,初始学习率为 \eta_0,目标学习率为 \eta_{\text{max}}

\eta(t) = \eta_0 + \frac{\eta_{\text{max}} - \eta_0}{T_w} \cdot t \quad (t \leq T_w)
  • t 是当前步数,t 从1到 T_w 时,学习率从 \eta_0 线性增长到 \eta_{\text{max}}
  • 常用 T_w = 10^4T_w = 5\% 总步数(大模型常用更长的Warmup)。

实现细节

def warmup_lr_scheduler(optimizer, warmup_steps): def lr_lambda(step): if step < warmup_steps: return step / warmup_steps # 线性增长 else: return 1.0 # 后续用其他调度 return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

6.2 Cosine Annealing(余弦退火调度)

核心动机

Warmup后,用余弦函数逐步降低学习率,让模型在训练后期精细化调整参数,避免陷入局部最优。

数学公式

设总步数为 T_{\text{total}},Warmup后步数为 t' = t - T_wt' \geq 0),退火周期为 T_c = T_{\text{total}} - T_w

\eta(t) = \eta_{\text{min}} + \frac{\eta_{\text{max}} - \eta_{\text{min}}}{2} \cdot \left( 1 + \cos\left( \frac{\pi \cdot t'}{T_c} \right) \right)
  • \eta_{\text{min}} 是最小学习率(默认 \eta_{\text{max}} \times 1e-4);
  • 学习率随步数呈余弦曲线衰减,后期衰减速度变慢,便于参数微调。

组合策略(Warmup + Cosine Annealing)

LLM的标准学习率调度:

\eta(t) = \begin{cases} \frac{t}{T_w} \cdot \eta_{\text{max}} & t \leq T_w \\ \eta_{\text{min}} + \frac{\eta_{\text{max}} - \eta_{\text{min}}}{2} \cdot \left( 1 + \cos\left( \frac{\pi (t-T_w)}{T_{\text{total}}-T_w} \right) \right) & t > T_w \end{cases}

实现细节(PyTorch)

from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR from torch.optim.lr_scheduler import SequentialLR # 1. Warmup调度 warmup_scheduler = LambdaLR(optimizer, lr_lambda=lambda step: step/warmup_steps if step < warmup_steps else 1.0) # 2. Cosine Annealing调度 cosine_scheduler = CosineAnnealingLR(optimizer, T_max=total_steps - warmup_steps, eta_min=1e-7) # 3. 组合调度 scheduler = SequentialLR(optimizer, schedulers=[warmup_scheduler, cosine_scheduler], milestones=[warmup_steps])

七、工程优化技巧:梯度累积与混合精度训练

大模型训练面临两大硬件瓶颈:

  1. 内存限制:模型参数(如100B参数模型约占400GB内存)和中间激活值无法放入单卡;
  2. 计算效率:纯FP32训练速度慢,硬件利用率低。

梯度累积和混合精度训练是解决这两个问题的核心工程技巧。

7.1 梯度累积(Gradient Accumulation):模拟大批次训练

核心动机

单卡内存有限,无法容纳大批次(如 B=64),梯度累积通过“多次小批次训练后累积梯度,再更新参数”,模拟大批次的效果。

数学原理

假设大批次大小为 B_{\text{total}} = B_{\text{mini}} \times KK 为累积步数),则:

  • 大批次梯度:\nabla L_{\text{total}} = \frac{1}{B_{\text{total}}} \sum_{k=1}^K \sum_{b=1}^{B_{\text{mini}}} \nabla L_{k,b}
  • 梯度累积:每次小批次计算梯度后不更新参数,累积 K 次后取平均,再更新参数:
    \nabla L_{\text{accum}} = \frac{1}{K} \sum_{k=1}^K \nabla L_k
    其中 \nabla L_k 是第k个小批次的梯度(\sum_{b=1}^{B_{\text{mini}}} \nabla L_{k,b})。

实现细节

accumulation_steps = 8 # 累积8个小批次,模拟大批次 model.train() for step, (inputs, targets) in enumerate(dataloader): inputs, targets = inputs.cuda(), targets.cuda() # 前向传播计算损失 loss = model(inputs, targets) # 损失缩放(平均到每个小批次) loss = loss / accumulation_steps # 反向传播计算梯度(累积梯度) loss.backward() # 累积K步后更新参数 if (step + 1) % accumulation_steps == 0: optimizer.step() # 参数更新 optimizer.zero_grad() # 梯度清零 scheduler.step() # 学习率调度

注意事项

  • 梯度清零时机:更新参数后清零,而非每次小批次后;
  • 学习率调整:模拟大批次时,学习率需按 B_{\text{total}} 缩放(大批次对应更大学习率);
  • BatchNorm适配:若使用BatchNorm,需设置 track_running_stats=False(避免小批次统计偏差)。

7.2 混合精度训练(Mixed Precision Training):平衡速度与精度

核心动机

  • FP32(单精度):精度高,但内存占用大、计算慢;
  • FP16(半精度):内存占用仅为FP32的1/2,计算速度快,但精度低(易出现梯度下溢/溢出)。

混合精度训练的核心是“用FP16存储模型参数和计算中间激活值,用FP32存储梯度和优化器状态”,平衡速度与精度。

关键技术:梯度缩放(Gradient Scaling)

FP16的动态范围小(最小正值约 6e-8),梯度值较小时会被舍入为0(梯度下溢)。解决方案是“训练时放大损失,反向传播后缩小梯度”:

数学原理

  1. 设缩放因子为 s(默认128),放大损失:L_{\text{scaled}} = s \cdot L
  2. 反向传播计算缩放梯度:\nabla L_{\text{scaled}} = s \cdot \nabla L(避免下溢);
  3. 参数更新前缩小梯度:\nabla L = \frac{\nabla L_{\text{scaled}}}{s}(恢复真实梯度)。

实现细节(PyTorch AMP)

PyTorch的torch.cuda.amp模块简化了混合精度训练的实现:

from torch.cuda.amp import autocast, GradScaler scaler = GradScaler() # 梯度缩放器 model.train() for step, (inputs, targets) in enumerate(dataloader): inputs, targets = inputs.cuda(), targets.cuda() optimizer.zero_grad() # 开启autocast,自动使用FP16计算 with autocast(): loss = model(inputs, targets) # 反向传播+梯度缩放 scaler.scale(loss).backward() # 参数更新(自动缩放梯度) scaler.step(optimizer) # 更新缩放因子 scaler.update() scheduler.step()

优势

  • 内存节省:模型参数和激活值占用内存减少50-75%(FP16+梯度累积可训练更大模型);
  • 速度提升:GPU对FP16的计算吞吐量是FP32的2-8倍(依赖GPU架构,如A100支持TF32加速);
  • 精度损失小:FP32存储梯度和优化器状态,仅计算过程用FP16,精度接近纯FP32。

注意事项

  • 模型初始化:确保参数初始化在FP16的动态范围内(避免初始值溢出);
  • 数值稳定性检查:若出现 inf/nan,需增大缩放因子 s
  • 不支持FP16的层:如BatchNorm、Softmax,需手动指定用FP32计算(AMP会自动处理)。

发布者: 作者: 转发
评论区 (0)
U