nano-GPT 模型架构介绍 本章节将围绕nano-GPT模型的代码实现,从核心特性、配置参数、关键组件、运行流程到优化推理,系统化的拆解模型架构,从原理+实现的角度,详细介绍 nano-GPT 模型架构。 如果对于相关知识没有太多基础,想要从基础整体了解,可以先看补充内容:大语言模型架构训练核心原理全解析。正文内容更偏向于其核心要点内容的说明。 一、 模型概述 nanoChat模型是轻量简化版 GPT,基于 Transformer 解码器架构,在保留核心 autoregressive 能力的同时,通过多项工程优化降低复杂度并提升推理效率,核心特性如下: 采用旋转位置编码(Rotary Embeddings),移除传统可学习位置编码,减少参数并增强长序列泛化性;
本章节将围绕nano-GPT模型的代码实现,从核心特性、配置参数、关键组件、运行流程到优化推理,系统化的拆解模型架构,从原理+实现的角度,详细介绍 nano-GPT 模型架构。
如果对于相关知识没有太多基础,想要从基础整体了解,可以先看补充内容:大语言模型架构训练核心原理全解析。正文内容更偏向于其核心要点内容的说明。
nanoChat模型是轻量简化版 GPT,基于 Transformer 解码器架构,在保留核心 autoregressive 能力的同时,通过多项工程优化降低复杂度并提升推理效率,核心特性如下:
采用旋转位置编码(Rotary Embeddings),移除传统可学习位置编码,减少参数并增强长序列泛化性;
引入QK 归一化(QK Norm),对注意力层的查询(Q)和键(K)单独做归一化,稳定注意力计算;
解耦词嵌入与输出层权重(Untied Weights),词嵌入层(wte)与语言模型头(lm_head)使用独立权重,提升训练灵活性;
MLP 层采用ReLU² 激活函数(ReLU 后平方),替代传统 GELU,简化计算同时保留非线性表达能力;
使用无参数 RMS 归一化,所有归一化操作(norm 函数)无 learnable params,降低过拟合风险;
全链路无偏置线性层,所有 Linear 层(如 c_q、c_k、c_proj)均不设 bias,减少参数总量;
支持多查询注意力(MQA),通过重复键值头(KV Head)匹配查询头(Q Head)数量,大幅提升推理效率。
模型的所有可配置参数通过GPTConfig数据类定义,决定模型规模、结构与运行边界,参数说明如下表:
| 参数名 | 默认值 | 作用说明 |
|---|---|---|
| sequence_len | 1024 | 训练阶段支持的最大序列长度,推理时可通过 KV 缓存处理更长序列(需小于 rotary_seq_len) |
| vocab_size | 50304 | 模型支持的词汇表大小,对应词嵌入层(wte)的输入维度 |
| n_layer | 12 | Transformer 解码器层数,即 Block 模块的堆叠数量 |
| n_head | 6 | 注意力层的查询头(Q Head)数量,决定注意力的并行度 |
| n_kv_head | 6 | 注意力层的键值头(KV Head)数量,MQA 模式下需满足n_head % n_kv_head == 0 |
| n_embd | 768 | 词嵌入维度与 Transformer 层的隐藏状态维度,需满足n_embd % n_head == 0 |
模型由「基础工具函数」和「核心模块」构成,各组件职责明确且复用性强,以下按依赖关系拆解:
为核心模块提供通用计算能力,无状态且纯函数式设计:
功能:对输入张量的最后一维(特征维度)做 RMS 归一化,公式为 x / sqrt(mean(x²) + ε);
特点:无 learnable params,仅依赖输入数据本身,避免参数冗余。
功能:将预计算的 cos/sin 矩阵应用到 Q/K 张量,实现相对位置编码;
步骤:先将 Q/K 的最后一维(head_dim)拆分为两半,通过三角函数计算旋转后的值,再拼接回原维度。
功能:在 MQA 模式下,将 KV Head 的输出重复n_head//n_kv_head次,匹配 Q Head 的数量;
实现:通过张量扩展(expand)+ 重塑(reshape)完成,避免冗余计算。
模型的注意力核心,支持 MQA 与 KV 缓存,确保 autoregressive 特性(当前 token 仅关注历史 token):
输入输出:输入为隐藏状态(x)、cos/sin 矩阵(cos_sin)、KV 缓存(kv_cache),输出为注意力层的特征映射;
核心流程:
投影:通过 c_q、c_k、c_v 将输入 x 分别投影为 Q、K、V 张量,并调整维度为(B, T, H, head_dim);
位置编码:对 Q、K 应用旋转编码,再分别做 QK 归一化;
KV 缓存:若启用缓存,将当前 K、V 插入缓存,获取历史 + 当前的完整 KV 张量;
MQA 适配:重复 KV Head 以匹配 Q Head 数量;
注意力计算:根据是否启用缓存 / 序列长度,选择因果掩码(is_causal=True)或自定义掩码,调用F.scaled_dot_product_attention完成计算;
输出投影:将注意力结果重塑后,通过 c_proj 投影回隐藏状态维度(n_embd)。
注意力层之后的非线性变换模块,负责增强模型的特征表达能力:
结构:两层线性层 + ReLU² 激活,维度变换为 n_embd → 4*n_embd → n_embd;
激活函数:先通过F.relu(x),再对结果平方(square()),替代传统 GELU,计算更简洁;
无偏置:c_fc 和 c_proj 均不设 bias,减少参数并降低训练波动。
模型的基本堆叠单元,由「注意力层 + MLP 层」构成,采用残差连接:
流程:输入 x 先经过 norm 归一化,送入注意力层后与原 x 残差相加;再对结果做 norm 归一化,送入 MLP 层后再次残差相加;
特点:归一化操作在进入子层前(Pre-Norm),避免梯度消失,提升训练稳定性。
词嵌入层(wte):nn.Embedding(vocab_size, n_embd),将 token ID 转换为高维向量,默认使用 bfloat16 精度,节省内存;
语言模型头(lm_head):nn.Linear(n_embd, vocab_size),将 Transformer 输出的隐藏状态映射为词汇表概率分布,与 wte 权重解耦,且初始化时权重设为 0。
模块构建:在__init__中初始化词嵌入层(wte)、Transformer 块列表(h)、lm_head,并预计算旋转编码(cos/sin),存储为非持久化 buffer(不写入 checkpoint);
权重初始化:
线性层:采用正态分布初始化,标准差为 1/sqrt(fan_in) * min(1, sqrt(fan_out/fan_in))(参考论文 2310.17813);
嵌入层:正态分布初始化(mean=0.0, std=1.0);
特殊初始化:lm_head、注意力层 c_proj、MLP 层 c_proj 的权重初始化为 0,减少初始训练波动。
输入处理:接收 token ID(idx)和标签(targets),检查序列长度是否超过旋转编码缓存;
旋转编码截取:根据 KV 缓存位置(T0),截取当前序列所需的 cos/sin 片段;
词嵌入与归一化:idx 通过 wte 转换为嵌入向量,再经过 norm 归一化;
Transformer 堆叠计算:将嵌入向量依次送入每个 Block,每个 Block 接收 cos_sin 和 kv_cache,输出更新后的隐藏状态;
输出与损失计算:隐藏状态经 norm 后,通过 lm_head 生成 logits,应用 softcap(15*tanh (logits/15))稳定值范围,最终计算交叉熵损失(targets 存在时)。
采用朴素 autoregressive 生成,支持温度采样(temperature)和 Top-K 采样:
输入准备:将初始 token 列表转换为张量,添加 batch 维度;
循环生成:
每次前向传播仅取最后一个 token 的 logits;
若启用 Top-K,过滤掉概率低于 Top-K 最小值的 logits;
温度采样:logits 除以温度后 softmax, multinomial 采样下一个 token;贪心采样(temperature=0):取 logits 最大值;
拼接当前 token 与历史 token,重复至生成 max_tokens 个 token;
采用分组优化策略,针对不同类型参数设置不同学习率:
AdamW 优化器:用于 lm_head(unembedding_lr)和 wte(embedding_lr),学习率按 1/sqrt(n_embd/768) 缩放(适配不同隐藏维度);
Muon 优化器:用于 Transformer 块中的线性层(matrix_lr),侧重矩阵参数的稳定更新;
权重衰减(weight_decay)可配置,默认 0,减少过拟合。
词嵌入层(wte)和旋转编码(cos/sin)默认使用 bfloat16 精度,减少内存占用和计算耗时;
logits 计算时转换为 float32,确保损失计算精度。
此部分内容基于前文模型架构框架,针对四大核心组件(旋转编码、RMS 归一化、因果自注意力层、MLP 层),我们详细拆解其数学原理并给出对于的代码实现。
旋转编码是模型的位置编码方案,通过三角函数对查询(Q)和键(K)进行旋转,注入相对位置信息,核心分为 “预计算 cos/sin 矩阵” 和 “应用旋转编码” 两步。
旋转编码的核心是为每个位置t(序列中的第t个 token)和每个通道i(头维度head_dim的第i个维度)计算旋转频率,再生成对应的cos和sin值。
步骤 1:计算逆频率(inv_freq)
逆频率决定了不同通道的旋转周期,公式为:
\text{inv\_freq}[i] = \frac{1}{\text{base}^{(2i/\text{head\_dim})}}
其中:
i:通道索引(取值范围[0, 1, ..., head_dim//2 - 1],因旋转按 “两个通道一组” 进行);
base:基础周期参数(代码中默认10000);
head_dim:单个注意力头的维度(由n_embd // n_head计算得出)。
步骤 2:计算频率矩阵(freqs)
频率矩阵是位置与通道的外积,公式为:
\text{freqs}[t, i] = t \times \text{inv\_freq}[i]
其中:
t:位置索引(取值范围[0, 1, ..., seq_len - 1],seq_len为预计算的序列长度,代码中为rotary_seq_len);
矩阵形状:(seq_len, head_dim//2)(因每个通道组对应一个频率)。
步骤 3:生成 cos/sin 矩阵
对频率矩阵分别应用余弦和正弦函数,得到最终用于旋转的矩阵:
\cos_{mat}[t, i] = \cos(\text{freqs}[t, i])
\sin_{mat}[t, i] = \sin(\text{freqs}[t, i])
最终扩展维度以适配 Q/K 张量形状:
\cos = \cos_{mat} \rightarrow \text{reshape}(1, \text{seq\_len}, 1, \text{head\_dim//2}) \rightarrow \text{repeat}(1, 1, 1, 2)
\sin = \sin_{mat} \rightarrow \text{reshape}(1, \text{seq\_len}, 1, \text{head\_dim//2}) \rightarrow \text{repeat}(1, 1, 1, 2)
(注:repeat(1,1,1,2)是为了让每个通道组的cos/sin值覆盖两个通道,匹配head_dim维度)
对 Q/K 张量的每个位置 - 通道对进行旋转,公式基于复数旋转原理(实部和虚部分别对应两个通道):
设 Q 的某两个通道为q1(第2i维)和q2(第2i+1维),旋转后为q1'和q2':
q1' = q1 \times \cos_{mat}[t, i] + q2 \times \sin_{mat}[t, i]
q2' = -q1 \times \sin_{mat}[t, i] + q2 \times \cos_{mat}[t, i]
K 的旋转公式与 Q 完全一致,仅输入张量替换为 K。
_precompute_rotary_embeddings函数)def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None): # 1. 计算逆频率(对应公式:inv_freq[i] = 1/base^(2i/head_dim)) channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device) # i的范围:0,2,...head_dim-2(共head_dim//2个) inv_freq = 1.0 / (base ** (channel_range / head_dim)) # 按公式计算逆频率,形状:(head_dim//2,) # 2. 计算频率矩阵(对应公式:freqs[t,i] = t * inv_freq[i]) t = torch.arange(seq_len, dtype=torch.float32, device=device) # t的范围:0,1,...seq_len-1 freqs = torch.outer(t, inv_freq) # 外积计算,形状:(seq_len, head_dim//2) # 3. 生成cos/sin矩阵并扩展维度(对应公式:cos_mat = cos(freqs),sin_mat = sin(freqs)) cos, sin = freqs.cos(), freqs.sin() # 形状均为:(seq_len, head_dim//2) cos, sin = cos.bfloat16(), sin.bfloat16() # 转为bfloat16节省内存 # 扩展维度:(seq_len, head_dim//2) → (1, seq_len, 1, head_dim//2) → 最终匹配Q/K的(1, T, 1, head_dim) cos, sin = cos[None, :, None, :], sin[None, :, None, :] # 增加batch维和head维 return cos, sin
apply_rotary_emb函数)def apply_rotary_emb(x, cos, sin): assert x.ndim == 4 # Q/K张量形状:(B, T, H, head_dim),B=批次,T=序列长,H=头数 d = x.shape[3] // 2 # 每个旋转组的通道数(即head_dim//2) # 拆分Q/K为两个通道组(对应公式中q1和q2的分离) x1, x2 = x[..., :d], x[..., d:] # x1: (B,T,H,d),x2: (B,T,H,d) # 应用旋转公式(对应q1'和q2'的计算) y1 = x1 * cos + x2 * sin # (B,T,H,d),对应q1' = q1*cos + q2*sin y2 = x1 * (-sin) + x2 * cos # (B,T,H,d),对应q2' = -q1*sin + q2*cos # 拼接回原维度(恢复head_dim) out = torch.cat([y1, y2], 3) # (B,T,H,head_dim) out = out.to(x.dtype) # 确保输出与输入 dtype 一致 return out
RMS 归一化是模型的归一化方案,无学习参数,通过对张量的 “均方根” 进行归一化,稳定训练过程。
RMS 归一化针对张量的最后一维(特征维度,如head_dim或n_embd)计算,公式为:
\text{norm}(x) = \frac{x}{\sqrt{\text{mean}(x^2) + \epsilon}}
其中:
x:输入张量(如 Q/K 张量、隐藏状态张量);
mean(x^2):对 x 的最后一维计算元素平方的均值;
ε:数值稳定性参数(PyTorch 的F.rms_norm中默认1e-6,代码中未显式指定,使用默认值);
核心特点:无γ(缩放因子)和β(偏移因子),完全基于输入数据的统计信息,减少参数冗余。
对应norm函数,直接调用 PyTorch 原生函数实现公式:
def norm(x): # 纯函数式RMS归一化(对应公式:x / sqrt(mean(x²) + ε)) # x.ndim可灵活适配:如Q/K张量(4维)、隐藏状态张量(3维) # (x.size(-1),):指定对最后一维(特征维度)进行归一化,匹配公式中的“特征维度均值” return F.rms_norm(x, (x.size(-1),))
因果自注意力是模型的核心组件,负责捕捉序列内的依赖关系,支持多查询注意力(MQA)和 KV 缓存,核心流程包括 “投影→旋转→MQA 适配→注意力计算→输出投影”。
通过线性层将输入隐藏状态x投影为 Q、K、V 张量,公式为:
Q = x \times W_q
K = x \times W_k
V = x \times W_v
其中:
W_q:查询投影权重矩阵,形状(n_embd, n_head × head_dim);
W_k:键投影权重矩阵,形状(n_embd, n_kv_head × head_dim);
W_v:值投影权重矩阵,形状(n_embd, n_kv_head × head_dim);
投影后维度调整:将Q从(B, T, n_head×head_dim)reshape 为(B, T, n_head, head_dim),K/V 同理(K/V 为(B, T, n_kv_head, head_dim))。
MQA 模式下,KV Head 数量n_kv_head少于 Q Head 数量n_head,需将 KV 张量重复r次(r = n_head // n_kv_head),公式为:
K_{\text{MQA}} = K \rightarrow \text{repeat}(1, 1, r, 1)
V_{\text{MQA}} = V \rightarrow \text{repeat}(1, 1, r, 1)
其中:
重复维度为 “头数维度”(第 2 维,对应代码中的dim=1);
目的:让每个 Q Head 都能对应到相同的 KV Head,减少计算量。
注意力分数计算基于缩放点积注意力,同时加入因果掩码(确保当前 token 仅关注历史 token),公式为:
注意力分数:
\text{score} = \frac{Q \times K_{\text{MQA}}^T}{\sqrt{\text{head\_dim}}}
$K_{\text{MQA}}^T$表示对 K 的 “序列长度维度” 转置,即(B, H, T, head_dim) → (B, H, head_dim, T))
因果掩码:
对score矩阵应用下三角掩码(下三角及对角线为 1,上三角为 -∞),公式为:
\text{score}_{\text{masked}} = \text{score} + M
其中M为掩码矩阵,M[t1, t2] = 0(若t1 ≥ t2,即当前 tokent1可关注历史 tokent2),否则M[t1, t2] = -∞。
注意力权重与输出:
\text{attn\_weight} = \text{softmax}(\text{score}_{\text{masked}})
\text{attn\_out} = \text{attn\_weight} \times V_{\text{MQA}}
最终投影:
将多个头的输出拼接后,通过线性层投影回隐藏状态维度:
\text{out} = \text{concat}(\text{attn\_out}) \times W_{\text{proj}}
其中W_proj为输出投影权重矩阵,形状(n_head×head_dim, n_embd)。
对应CausalSelfAttention类的forward方法,按 “投影→旋转→MQA 适配→注意力计算→输出投影” 分步映射:
def forward(self, x, cos_sin, kv_cache): B, T, C = x.size() # B=批次,T=序列长,C=n_embd head_dim = self.head_dim # 单个头维度:n_embd // n_head # 1. 投影(对应公式Q=x×W_q,K=x×W_k,V=x×W_v) q = self.c_q(x).view(B, T, self.n_head, head_dim) # (B,T,n_head,head_dim):W_q对应self.c_q权重 k = self.c_k(x).view(B, T, self.n_kv_head, head_dim) # (B,T,n_kv_head,head_dim):W_k对应self.c_k权重 v = self.c_v(x).view(B, T, self.n_kv_head, head_dim) # (B,T,n_kv_head,head_dim):W_v对应self.c_v权重 # 2. 旋转编码与QK归一化(对应旋转公式+RMS公式) cos, sin = cos_sin q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) # 应用旋转编码 q, k = norm(q), norm(k) # QK归一化(对应RMS公式) # 调整维度:(B,T,H,D) → (B,H,T,D)(为注意力计算做准备,头维度提前) q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) # 3. KV缓存(补充历史KV,不影响数学公式,仅工程优化) if kv_cache is not None: k, v = kv_cache.insert_kv(self.layer_idx, k, v) Tq = q.size(2) # 当前查询序列长 Tk = k.size(2) # 历史+当前的键值序列长 # 4. MQA适配(对应KV重复公式:K_MQA = K.repeat(1,r,1,1)) nrep = self.n_head // self.n_kv_head # r = n_head / n_kv_head k, v = repeat_kv(k, nrep), repeat_kv(v, nrep) # K/V重复后形状:(B,n_head,Tk,head_dim) # 5. 注意力计算(对应score、mask、softmax、attn_out公式) if kv_cache is None or Tq == Tk: # 训练/全序列推理:因果掩码(is_causal=True),对应score_masked公式 y = F.scaled_dot_product_attention(q, k, v, is_causal=True) elif Tq == 1: # 单token推理:无掩码(仅关注所有历史KV) y = F.scaled_dot_product_attention(q, k, v, is_causal=False) else: # 多token推理:自定义因果掩码(前缀全关注+块内因果) attn_mask = torch.zeros((Tq, Tk), dtype=torch.bool, device=q.device) prefix_len = Tk - Tq attn_mask[:, :prefix_len] = True # 前缀(历史)全关注 attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), device=q.device)) # 块内因果 y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) # 应用自定义掩码 # 6. 输出投影(对应out=concat(attn_out)×W_proj公式) y = y.transpose(1, 2).contiguous().view(B, T, -1) # 拼接头输出:(B,H,T,D)→(B,T,n_head×D) y = self.c_proj(y) # 线性投影:W_proj对应self.c_proj权重,输出形状(B,T,n_embd) return y
MLP 层负责对注意力层的输出进行非线性变换,增强模型的特征表达能力,核心是 “线性扩展→ReLU²激活→线性压缩” 三步。
MLP 层的维度变换为n_embd → 4×n_embd → n_embd,公式如下:
线性扩展:将输入隐藏状态映射到更高维度(4 倍n_embd):
x_1 = x \times W_{\text{fc}}
其中W_fc为扩展权重矩阵,形状(n_embd, 4×n_embd)。
ReLU²激活:先应用 ReLU 函数,再对结果平方,公式为:
x_2 = \text{ReLU}(x_1)^2 = \max(x_1, 0)^2
(注:ReLU² 替代传统 GELU,计算更简洁,同时保留非线性表达能力)
线性压缩:将高维特征映射回n_embd维度:
\text{out} = x_2 \times W_{\text{proj}}
其中W_proj为压缩权重矩阵,形状(4×n_embd, n_embd)。
对应MLP类的forward方法,直接映射上述三步公式:
def forward(self, x): # 1. 线性扩展(对应公式x1 = x×W_fc) x = self.c_fc(x) # W_fc对应self.c_fc权重,输入(B,T,n_embd)→输出(B,T,4×n_embd) # 2. ReLU²激活(对应公式x2 = max(x1,0)^2) x = F.relu(x).square() # F.relu(x)是max(x,0),.square()是平方操作 # 3. 线性压缩(对应公式out = x2×W_proj) x = self.c_proj(x) # W_proj对应self.c_proj权重,输入(B,T,4×n_embd)→输出(B,T,n_embd) return x
接下来我们对于nanoChat的gpt.py代码进行分块详细解释【按代码逻辑结构展开】:
""" GPT model (rewrite, a lot simpler) Notable features: - rotary embeddings (and no positional embeddings) # 旋转位置嵌入(无传统位置嵌入) - QK norm # Q和K的归一化 - untied weights for token embedding and lm_head # 词嵌入与输出头权重不共享 - relu^2 activation in MLP # MLP使用relu平方激活 - norm after token embedding # 词嵌入后做归一化 - no learnable params in rmsnorm # RMSNorm无学习参数 - no bias in linear layers # 线性层无偏置 - Multi-Query Attention (MQA) support for more efficient inference # 支持MQA提升推理效率 """ import math from functools import partial from dataclasses import dataclass import torch import torch.nn as nn import torch.nn.functional as F from nanochat.common import get_dist_info, print0 from nanochat.muon import Muon, DistMuon from nanochat.adamw import DistAdamW
get_dist_info)、优化器(Muon、DistAdamW)等。GPTConfig)@dataclass class GPTConfig: sequence_len: int = 1024 # 最大序列长度 vocab_size: int = 50304 # 词汇表大小 n_layer: int = 12 # Transformer层数 n_head: int = 6 # 查询头数量(Q头) n_kv_head: int = 6 # 键值头数量(K/V头,用于MQA/GQA) n_embd: int = 768 # 嵌入维度
dataclass定义模型超参数,集中管理配置。n_head与n_kv_head分离,支持Multi-Query Attention(MQA,n_kv_head=1)或Grouped-Query Attention(GQA,1 < n_kv_head < n_head),平衡性能与效率。norm)def norm(x): # Purely functional rmsnorm with no learnable params return F.rms_norm(x, (x.size(-1),))
apply_rotary_emb)def apply_rotary_emb(x, cos, sin): assert x.ndim == 4 # multihead attention: (B, T, H, D) d = x.shape[3] // 2 # 旋转维度的一半 x1, x2 = x[..., :d], x[..., d:] # 分割最后一维为两部分 y1 = x1 * cos + x2 * sin # 旋转第一部分 y2 = x1 * (-sin) + x2 * cos # 旋转第二部分 out = torch.cat([y1, y2], 3) # 拼接回原维度 out = out.to(x.dtype) # 保证输出与输入 dtype 一致 return out
CausalSelfAttention)class CausalSelfAttention(nn.Module): def __init__(self, config, layer_idx): super().__init__() self.layer_idx = layer_idx # 层索引(用于KV缓存) self.n_head = config.n_head self.n_kv_head = config.n_kv_head self.n_embd = config.n_embd self.head_dim = self.n_embd // self.n_head # 每个头的维度 assert self.n_embd % self.n_head == 0 # 嵌入维度需被头数整除 assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0 # KV头需为Q头的约数(GQA/MQA) # 线性投影层(无偏置) self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False) # Q投影 self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False) # K投影 self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False) # V投影 self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False) # 注意力输出投影 def forward(self, x, cos_sin, kv_cache): B, T, C = x.size() # B: batch size, T: seq len, C: embedding dim # 投影得到Q、K、V q = self.c_q(x).view(B, T, self.n_head, self.head_dim) # (B, T, n_head, head_dim) k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim) # (B, T, n_kv_head, head_dim) v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim) # (B, T, n_kv_head, head_dim) # 应用旋转嵌入到Q和K cos, sin = cos_sin q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) q, k = norm(q), norm(k) # QK归一化(增强稳定性) # 转置维度:(B, T, H, D) → (B, H, T, D)(头维度提前,便于注意力计算) q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) # 处理KV缓存(推理时复用历史K/V,减少计算) if kv_cache is not None: k, v = kv_cache.insert_kv(self.layer_idx, k, v) # 将当前K/V插入缓存,获取完整K/V Tq = q.size(2) # 当前查询序列长度 Tk = k.size(2) # 总键值序列长度(缓存+当前) # 注意力计算(分场景处理因果掩码) enable_gqa = self.n_head != self.n_kv_head # 是否启用GQA(Q头与KV头数量不同) if kv_cache is None or Tq == Tk: # 训练模式(无缓存)或推理时序列长度匹配:直接用因果掩码 y = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa) elif Tq == 1: # 推理时单token生成:查询需关注所有历史键值(无因果掩码) y = F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa) else: # 推理时多token生成:前缀部分可全关注,当前块内用因果掩码 attn_mask = torch.zeros((Tq, Tk), dtype=torch.bool, device=q.device) prefix_len = Tk - Tq # 历史前缀长度 if prefix_len > 0: attn_mask[:, :prefix_len] = True # 允许关注前缀 attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool, device=q.device)) # 块内因果掩码 y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa) # 拼接头并投影回残差流 y = y.transpose(1, 2).contiguous().view(B, T, -1) # (B, H, T, D) → (B, T, H*D) = (B, T, C) y = self.c_proj(y) return y
MLP)class MLP(nn.Module): def __init__(self, config): super().__init__() self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False) # 升维 self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False) # 降维回残差流 def forward(self, x): x = self.c_fc(x) x = F.relu(x).square() # relu^2激活(比单纯relu更平滑) x = self.c_proj(x) return x
relu(x)^2(平方ReLU),相比标准ReLU更平滑,可能提升训练稳定性。Block)class Block(nn.Module): def __init__(self, config, layer_idx): super().__init__() self.attn = CausalSelfAttention(config, layer_idx) # 注意力子模块 self.mlp = MLP(config) # MLP子模块 def forward(self, x, cos_sin, kv_cache): # 残差连接:输入 + 注意力输出(注意力前先归一化) x = x + self.attn(norm(x), cos_sin, kv_cache) # 残差连接:输入 + MLP输出(MLP前先归一化) x = x + self.mlp(norm(x)) return x
GPT)class GPT(nn.Module): def __init__(self, config): super().__init__() self.config = config self.transformer = nn.ModuleDict({ "wte": nn.Embedding(config.vocab_size, config.n_embd), # 词嵌入表 "h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]), # Transformer块列表 }) self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) # 输出头(预测下一个token) # 预计算旋转嵌入(缓存足够长的序列,避免动态计算) self.rotary_seq_len = config.sequence_len * 10 # 超量预计算(10倍序列长度) head_dim = config.n_embd // config.n_head cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim) self.register_buffer("cos", cos, persistent=False) # 不保存到checkpoint self.register_buffer("sin", sin, persistent=False) def init_weights(self): self.apply(self._init_weights) # 特殊初始化:输出头和投影层置零 torch.nn.init.zeros_(self.lm_head.weight) for block in self.transformer.h: torch.nn.init.zeros_(block.mlp.c_proj.weight) torch.nn.init.zeros_(block.attn.c_proj.weight) # 重新初始化旋转嵌入(确保设备正确) head_dim = self.config.n_embd // self.config.n_head cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim) self.cos, self.sin = cos, sin # 词嵌入转bf16(节省内存) if self.transformer.wte.weight.device.type == "cuda": self.transformer.wte.to(dtype=torch.bfloat16) def _init_weights(self, module): if isinstance(module, nn.Linear): # 线性层初始化:根据输入维度调整标准差(参考论文https://arxiv.org/pdf/2310.17813) fan_out = module.weight.size(0) fan_in = module.weight.size(1) std = 1.0 / math.sqrt(fan_in) * min(1.0, math.sqrt(fan_out / fan_in)) torch.nn.init.normal_(module.weight, mean=0.0, std=std) elif isinstance(module, nn.Embedding): # 嵌入层初始化:正态分布 torch.nn.init.normal_(module.weight, mean=0.0, std=1.0)
wte)、Transformer块列表(h)、输出头(lm_head)。_precompute_rotary_embeddings提前计算cos/sin值,避免推理时重复计算,提升效率。_precompute_rotary_embeddings)def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None): if device is None: device = self.transformer.wte.weight.device # 自动获取设备 channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device) # 步长2取通道(旋转对) inv_freq = 1.0 / (base ** (channel_range / head_dim)) # 频率倒数(base=10000为默认) t = torch.arange(seq_len, dtype=torch.float32, device=device) # 时间步 freqs = torch.outer(t, inv_freq) # 外积计算 (seq_len, head_dim/2) cos, sin = freqs.cos(), freqs.sin() # 计算cos和sin cos, sin = cos.bfloat16(), sin.bfloat16() # 转bf16节省内存 # 增加batch和头维度,便于广播:(1, seq_len, 1, head_dim/2*2) cos, sin = cos[None, :, None, :], sin[None, :, None, :] return cos, sin
apply_rotary_emb函数。forward)def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'): B, T = idx.size() # idx: (B, T) 输入token序列 # 获取当前序列的旋转嵌入(考虑KV缓存的偏移) assert T <= self.cos.size(1), f"序列长度超过旋转嵌入缓存:{T} > {self.cos.size(1)}" T0 = 0 if kv_cache is None else kv_cache.get_pos() # 缓存偏移量(推理时从历史位置开始) cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # 截取对应长度的cos/sin # Transformer前向 x = self.transformer.wte(idx) # 词嵌入:(B, T) → (B, T, n_embd) x = norm(x) # 嵌入后归一化 for block in self.transformer.h: x = block(x, cos_sin, kv_cache) # 经过所有Transformer块 x = norm(x) # 最终归一化 # 计算输出(训练时返回损失,推理时返回logits) softcap = 15 # logits截断系数(稳定训练) if targets is not None: # 训练模式:计算交叉熵损失 logits = self.lm_head(x) # (B, T, vocab_size) logits = softcap * torch.tanh(logits / softcap) # 截断logits范围(避免过大值) logits = logits.float() # 用fp32计算损失(精度更高) loss = F.cross_entropy( logits.view(-1, logits.size(-1)), # 展平为(B*T, vocab_size) targets.view(-1), # 展平为(B*T,) ignore_index=-1, # 忽略-1标签 reduction=loss_reduction ) return loss else: # 推理模式:返回logits logits = self.lm_head(x) logits = softcap * torch.tanh(logits / softcap) return logits
softcap):通过tanh限制logits范围,避免梯度爆炸。setup_optimizers)def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0): model_dim = self.config.n_embd ddp, rank, local_rank, world_size = get_dist_info() # 获取分布式信息 # 参数分组:矩阵参数(Transformer块)、嵌入参数、输出头参数 matrix_params = list(self.transformer.h.parameters()) embedding_params = list(self.transformer.wte.parameters()) lm_head_params = list(self.lm_head.parameters()) # 学习率按模型维度缩放(∝1/√dmodel,适配不同嵌入维度) dmodel_lr_scale = (model_dim / 768) ** -0.5 if rank == 0: print(f"学习率缩放系数:∝1/√({model_dim}/768) = {dmodel_lr_scale:.6f}") # AdamW优化器(用于嵌入和输出头) adam_groups = [ dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale), dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale), ] adamw_kwargs = dict(betas=(0.8, 0.95), eps=1e-10, weight_decay=weight_decay) AdamWFactory = DistAdamW if ddp else partial(torch.optim.AdamW, fused=True) adamw_optimizer = AdamWFactory(adam_groups, **adamw_kwargs) # Muon优化器(用于Transformer块的线性层,可能更适合大模型) muon_kwargs = dict(lr=matrix_lr, momentum=0.95) MuonFactory = DistMuon if ddp else Muon muon_optimizer = MuonFactory(matrix_params,** muon_kwargs) # 保存初始学习率 for opt in [adamw_optimizer, muon_optimizer]: for group in opt.param_groups: group["initial_lr"] = group["lr"] return [adamw_optimizer, muon_optimizer]
generate)@torch.inference_mode() def generate(self, tokens, max_tokens, temperature=1.0, top_k=None, seed=42): """自回归生成:从输入tokens开始,生成最多max_tokens个token""" assert isinstance(tokens, list) device = self.get_device() rng = None if temperature > 0: rng = torch.Generator(device=device) rng.manual_seed(seed) # 随机数生成器(固定种子确保可复现) ids = torch.tensor([tokens], dtype=torch.long, device=device) # 转为tensor并添加batch维度 for _ in range(max_tokens): logits = self.forward(ids) # (B, T, vocab_size) logits = logits[:, -1, :] # 取最后一个token的logits:(B, vocab_size) # Top-K过滤 if top_k is not None: v, _ = torch.topk(logits, min(top_k, logits.size(-1))) logits[logits < v[:, [-1]]] = -float('Inf') # 过滤掉top_k外的token # 采样或贪心选择 if temperature > 0: logits = logits / temperature # 温度调整(温度越高随机性越强) probs = F.softmax(logits, dim=-1) next_ids = torch.multinomial(probs, num_samples=1, generator=rng) # 多项式采样 else: next_ids = torch.argmax(logits, dim=-1, keepdim=True) # 贪心选择 ids = torch.cat((ids, next_ids), dim=1) # 拼接新token yield next_ids.item() # 返回生成的token
@torch.inference_mode()):禁用梯度计算,提升效率。max_tokens。该代码实现了一个简化但高效的GPT模型,核心特点包括:
代码实现的nanoGPT模型,本质是优化后的Transformer解码器堆叠结构,我们围绕“效率、稳定性、可扩展性”三大核心目标,结合现代大模型训练的关键技术(如RoPE、MQA、RMSNorm),讲解其中的一些关键知识点。
GPT类模型的核心是Transformer解码器的重复堆叠,标准Transformer解码器包含“多头因果自注意力”“前馈网络(MLP)”“残差连接”“归一化”四大组件。代码中的Block类正是这一结构的实现,但做了关键优化,先明确基础框架再拆解细节:
标准Transformer解码器(如GPT-2)的单块结构为:
输入 → LayerNorm → 多头因果自注意力 → 残差连接 → LayerNorm → MLP → 残差连接 → 输出
即Pre-Norm(归一化在前) 设计,区别于早期Transformer的Post-Norm(归一化在后)。
代码中的Block完全遵循Pre-Norm,但做了两处关键简化:
class Block(nn.Module): def forward(self, x, cos_sin, kv_cache): x = x + self.attn(norm(x), cos_sin, kv_cache) # Norm→Attention→残差 x = x + self.mlp(norm(x)) # Norm→MLP→残差 return x
无偏置线性层:所有Linear层(如c_q、c_fc)均设bias=False。
原理:Transformer中残差连接和归一化已能稳定训练,偏置项对性能提升有限,却会增加1/3的参数量(如nn.Linear(768,768)无偏置时参数768²=589824,有偏置时多768个)。
延展:现代大模型(如LLaMA、Falcon)均移除线性层偏置,成为行业默认实践。
无参数RMSNorm:用F.rms_norm替代带可学习参数(γ、β)的LayerNorm。
原理:RMSNorm计算为 x / sqrt(mean(x²) + ε),仅依赖输入统计信息,无额外参数;LayerNorm则是 γ*(x-mean)/std + β,多2个可学习参数/层。
延展:RMSNorm比LayerNorm快30%+(减少参数读取和计算),且在大模型中性能与LayerNorm持平,已成为高效模型的首选(如LLaMA 2、Mistral)。
传统GPT用可学习位置嵌入(如GPT-2)或正弦位置嵌入(如原始Transformer),代码则用旋转位置嵌入(Rotary Position Embedding, RoPE),彻底改变了位置信息的注入方式。
RoPE的本质是通过三角函数旋转,让注意力分数仅依赖token的相对位置,而非绝对位置。
apply_rotary_emb):head_dim拆分为两半(d = head_dim//2),每对维度(2i, 2i+1)视为复数的“实部+虚部”;t个token的Q向量q_t,注入位置信息等价于乘以复数旋转因子 e^(iθ_t)(θ_t为位置t的角度):
代码中对应:q_t' = q_t × e^(iθ_t) = [q_t1*cosθ_t - q_t2*sinθ_t, q_t1*sinθ_t + q_t2*cosθ_t]
y1 = x1 * cos + x2 * sin # 实部 y2 = x1 * (-sin) + x2 * cos # 虚部
i与j的注意力分数时,旋转因子的差值为 e^(i(θ_i - θ_j)) = e^(iθ_{i-j}),仅依赖相对位置i-j,完美捕捉相对依赖。代码在__init__中预计算10倍序列长度的cos/sin矩阵(self.rotary_seq_len = config.sequence_len * 10),避免推理时实时计算:
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim) self.register_buffer("cos", cos, persistent=False)
| 位置编码类型 | 优点 | 缺点 | 代表模型 |
|---|---|---|---|
| 可学习位置嵌入 | 适配任务能力强 | 长度外推差(超训练长度失效) | GPT-2、GPT-3 |
| 正弦位置嵌入 | 长度外推好 | 相对位置捕捉弱 | 原始Transformer |
| RoPE | 相对位置精准+外推好 | 需拆分维度(head_dim需为偶数) | LLaMA、Qwen、Mistral |
标准Transformer用多头注意力(MHA),每个Q头对应独立的K/V头,参数和计算量巨大。代码通过n_head与n_kv_head分离,支持多查询注意力(MQA) 和分组查询注意力(GQA),大幅优化推理效率。
假设模型n_head=12,head_dim=64(n_embd=768):
| 注意力类型 | KV头数量(n_kv_head) | KV参数量(每层) | 推理速度 | 精度 |
|---|---|---|---|---|
| MHA(标准) | 12(与Q头相等) | 12×64=768 | 慢 | 高 |
| MQA(多查询) | 1(所有Q头共享) | 1×64=64 | 快 | 中 |
| GQA(分组查询) | 4(3个Q头共享1个KV头) | 4×64=256 | 中 | 高 |
CausalSelfAttention):self.n_head != self.n_kv_head启用GQA(enable_gqa=True);F.scaled_dot_product_attention,原生支持MQA/GQA(通过enable_gqa参数),避免手动重复KV头(早期需用repeat_kv函数,代码中已简化);标准Transformer的MLP用GELU激活(如GPT-3、BERT),代码则用ReLU²激活(F.relu(x).square()),在简化计算的同时保留非线性表达能力。
| 激活函数 | 计算复杂度 | 梯度特性 | 代表模型 |
|---|---|---|---|
| GELU | 高(需erf) | 平滑(可导性好) | GPT-3、LLaMA 1 |
| ReLU | 低 | 0点不可导(死神经元) | 早期CNN |
| ReLU² | 低(ReLU+平方) | 0点可导(更平滑) | 代码模型、部分轻量LLM |
MLP):erf函数(GELU的核心),GPU上可提速20%+;延展:
传统GPT(如GPT-2)共享词嵌入层(wte)和输出头(lm_head)的权重(lm_head.weight = wte.weight.T),代码则完全解耦:
self.transformer = nn.ModuleDict({"wte": nn.Embedding(...)}) self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
解耦的原因:
lm_head针对性优化高频token的预测精度,实验表明可提升1-3%的困惑度(Perplexity)。延展【权重贡献优势】:
代码中线性层的初始化参考论文《A Spectral Condition for Feature Learning》【2310.17813】,采用自适应标准差:
std = 1.0 / math.sqrt(fan_in) * min(1.0, math.sqrt(fan_out / fan_in)) torch.nn.init.normal_(module.weight, mean=0.0, std=std)
原理:
fan_out > fan_in(如MLP的升维层:c_fc从768→3072,fan_out=3072,fan_in=768),sqrt(fan_out/fan_in)=2,标准差放大2倍,避免信号衰减;fan_out ≤ fan_in(如降维层:c_proj从3072→768),min(1, ...)取1,标准差为1/√fan_in,避免信号放大导致梯度爆炸。延展:
代码在推理时用kv_cache存储历史的K和V,避免每次生成新token时重新计算所有历史的QK注意力:
if kv_cache is not None: k, v = kv_cache.insert_kv(self.layer_idx, k, v) # 插入当前KV到缓存
原理(以生成100个token为例):
1+2+...+100=5050次;100次,效率提升50倍。延展:KV缓存的内存挑战与优化
B×n_kv_head×T×head_dim的K和V(B=批次,T=生成长度),生成1024个token时,单批次768维模型的KV缓存约为1×6×1024×64≈4MB/层,12层约48MB,可控;代码中的nanoGPT,是现代高效LLM架构的“缩影”,其设计思路与行业主流模型(如LLaMA、Mistral)高度一致,核心可总结为:
我们可以通过理解nanoChat该简化模型的架构,可快速掌握现代LLM的核心设计思想,为后续学习复杂模型(如Mistral、Qwen)打下基础。
大语言模型(LLM)的架构与训练本质是“基于Transformer架构的深层神经网络优化过程”,涉及模型结构设计、注意力机制、归一化与残差连接、优化器、学习率调度、工程优化六大核心模块。本部分补充内容将在前文的基础上,进一步将逐一拆解每个模块的算法原理,从“动机→数学推导→直观解释→实现细节”层层递进,让大家能更好理解其来龙去脉。
传统序列建模(RNN/LSTM/GRU)存在两大痛点:
Transformer的核心创新是用“自注意力机制”替代RNN的串行依赖建模,实现:
注:所以从某种角度来说,Transformer并不一定是最优的架构,只是我们现在找到的从效率角度来说“最好的架构”,包括为什么在大模型中舍弃编码器(Encoder),只保留解码器堆叠,也是从同参数量下的效果经济性角度考虑的。
Transformer由编码器(Encoder) 和解码器(Decoder) 组成,LLM(如GPT系列)仅使用解码器堆叠(自回归生成),其结构定义为:
其中:
现代LLM均采用Pre-Norm(归一化在前) 设计(区别于原始Transformer的Post-Norm),单个解码器块的计算流程为:
class DecoderBlock(nn.Module): def __init__(self, d_model, n_head, d_ff, dropout=0.1): super().__init__() self.norm1 = LayerNorm(d_model) # Pre-Norm self.attn = CausalSelfAttention(d_model, n_head) self.norm2 = LayerNorm(d_model) self.mlp = MLP(d_model, d_ff) self.dropout = nn.Dropout(dropout) def forward(self, x, kv_cache=None): # 自注意力支路 x1 = self.norm1(x) attn_out = self.attn(x1, kv_cache) x = x + self.dropout(attn_out) # 残差+dropout # MLP支路 x2 = self.norm2(x) mlp_out = self.mlp(x2) x = x + self.dropout(mlp_out) # 残差+dropout return x
自注意力是Transformer的核心,其本质是“通过token间的相似度计算权重,对全局信息加权求和”,实现全局依赖捕捉。
对序列中每个token x_i,希望生成一个新的表示 y_i,该表示融合自身信息和序列中其他token的相关信息:
其中 \alpha_{i,j} 是token i 对token j 的注意力权重(\sum_j \alpha_{i,j}=1)。
为了计算注意力权重,需将输入 X 投影为三个向量:
投影公式:
其中:
用“点积”衡量 Q_i(第i个token的查询)与 K_j(第j个token的键)的相似度:
对分数除以 \sqrt{D_k},使分数的方差归一化到1附近:
LLM是自回归模型,生成第i个token时,只能依赖前i-1个token,需屏蔽后验信息(j > i 的token):
将掩码后的分数归一化为概率分布(权重和为1):
用注意力权重对 V 加权求和,得到每个token的最终表示:
其中 \alpha \in \mathbb{R}^{B \times T \times T} 是注意力权重矩阵。
将 W_Q, W_K, W_V 拆分为 n_{\text{head}} 个独立的投影矩阵:
每个头的QKV投影:
其中 W_{Qh} \in \mathbb{R}^{D \times D_k},D_k = D/n_{\text{head}}(确保总维度不变:n_{\text{head}} \times D_k = D)。
每个头独立计算缩放点积注意力:
将所有头的输出拼接,再通过一个线性层投影回 D 维度:
其中 W_O \in \mathbb{R}^{D \times D} 是输出投影矩阵。
多头注意力的输出可以表示为:
本质是“多个不同的注意力机制的集成”,提升模型的表达能力。
class CausalSelfAttention(nn.Module): def __init__(self, d_model, n_head): super().__init__() self.n_head = n_head self.d_k = d_model // n_head # 每个头的维度 # QKV投影层(无偏置) self.w_q = nn.Linear(d_model, d_model, bias=False) self.w_k = nn.Linear(d_model, d_model, bias=False) self.w_v = nn.Linear(d_model, d_model, bias=False) # 输出投影层 self.w_o = nn.Linear(d_model, d_model, bias=False) def forward(self, x): B, T, D = x.shape # 1. QKV投影:(B,T,D) → (B,T,D) q = self.w_q(x) k = self.w_k(x) v = self.w_v(x) # 2. 多头拆分:(B,T,D) → (B,n_head,T,d_k) q = q.view(B, T, self.n_head, self.d_k).transpose(1, 2) k = k.view(B, T, self.n_head, self.d_k).transpose(1, 2) v = v.view(B, T, self.n_head, self.d_k).transpose(1, 2) # 3. 因果掩码:(T,T),上三角为True(屏蔽) mask = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool() # 4. 缩放点积注意力(PyTorch原生实现,高效) attn_out = F.scaled_dot_product_attention( q, k, v, attn_mask=mask, is_causal=True ) # 5. 多头拼接:(B,n_head,T,d_k) → (B,T,D) attn_out = attn_out.transpose(1, 2).contiguous().view(B, T, D) # 6. 输出投影 out = self.w_o(attn_out) return out
自注意力机制是无序的(QK^T 计算不依赖token的位置),但语言序列的顺序至关重要(如“我吃苹果”≠“苹果吃我”),因此需要手动注入位置信息,这就是为什么设计位置编码。
位置编码 PE \in \mathbb{R}^{T \times D} 需满足:
对每个位置 t(0-based)和维度 i(0-based),位置编码为:
假设两个位置 t 和 t+k,其编码满足三角函数的和角公式:
即 PE(t+k) 可表示为 PE(t) 的线性组合,模型可通过学习权重捕捉相对位置 k。
RoPE是当前LLM的首选位置编码(如GPT-4、LLaMA、Qwen),核心是“通过旋转QK向量注入位置信息”,完美满足相对位置捕捉和长序列外推。
apply_rotary_emb)旋转因子 e^{-i\theta_t} = \cos\theta_t - i\sin\theta_t,对Q的实部和虚部进行旋转:
这正是代码中的实现逻辑:
def apply_rotary_emb(x, cos, sin): d = x.shape[3] // 2 x1, x2 = x[..., :d], x[..., d:] # 实部、虚部 y1 = x1 * cos + x2 * sin # 旋转后实部 y2 = -x1 * sin + x2 * cos # 旋转后虚部 return torch.cat([y1, y2], dim=-1)
def precompute_rope(d_model, n_head, max_seq_len=2048): d_k = d_model // n_head # 计算频率:theta_i = 10000^(-2i/d_k) theta = 1.0 / (10000 ** (torch.arange(0, d_k, 2) / d_k)) # 计算位置t的theta_t = t * theta_i t = torch.arange(max_seq_len) freqs = torch.outer(t, theta) # (max_seq_len, d_k//2) # 生成cos和sin矩阵(扩展维度适配QK) cos = freqs.cos().unsqueeze(0).unsqueeze(1) # (1,1,max_seq_len,d_k//2) sin = freqs.sin().unsqueeze(0).unsqueeze(1) # 重复维度,匹配d_k(每个频率对应实部+虚部) cos = cos.repeat(1, 1, 1, 2) # (1,1,max_seq_len,d_k) sin = sin.repeat(1, 1, 1, 2) return cos, sin
深层网络训练时,每一层的输入分布会随着前一层参数的更新而变化,导致:
层归一化的目标是“将每一层的输入归一化到固定分布(均值0,方差1)”,缓解分布漂移。
对输入张量 X \in \mathbb{R}^{B \times T \times D},在特征维度(D) 上归一化:
其中 \gamma \in \mathbb{R}^D(缩放因子)、\beta \in \mathbb{R}^D(偏移因子)是可学习参数,允许模型恢复原始分布的表达能力(若 \gamma=\sqrt{\sigma^2+\epsilon}、\beta=\mu,则归一化失效)。
LLM中常省略 \gamma 和 \beta,用更简单的RMS Norm(Root Mean Square Norm):
| 特性 | Layer Norm | Batch Norm |
|---|---|---|
| 归一化维度 | 特征维度(D) | 批次维度(B) |
| 批次依赖性 | 无(单样本即可计算) | 有(依赖批次内样本) |
| 序列长度适应性 | 好(T可任意变化) | 差(T固定时效果好) |
| 大模型场景适用性 | 强(并行训练+长序列) | 弱(批次波动+长序列失效) |
深层网络中,梯度通过链式法则传递时会被不断相乘,若梯度绝对值小于1,深层梯度会趋近于0(梯度消失)。残差连接通过“直接传递原始输入”,让梯度可直接回传。
对网络层 F(如注意力层、MLP),残差连接定义为:
假设损失为 L,对参数 \theta 的梯度为:
# 残差连接+归一化(Pre-Norm) x = x + self.dropout(self.attn(self.norm1(x))) x = x + self.dropout(self.mlp(self.norm2(x)))
优化器的核心是“根据梯度调整参数,最小化损失函数”,LLM的损失函数为交叉熵损失(预测下一个token的概率):
其中 y_{b,t} 是第b个样本第t个token的真实标签,P 是模型预测的概率分布。
参数 \theta 的更新规则:
Adam(Adaptive Moment Estimation)融合了“动量(Momentum)”和“自适应学习率(RMSProp)”,是LLM训练的基础优化器。
给定梯度 \nabla L_t = \nabla L(\theta_t),迭代更新:
一阶矩(动量):
二阶矩(自适应学习率):
偏差修正(初期修正):
参数更新:
Adam的核心缺陷是“权重衰减(Weight Decay)”的实现错误,AdamW通过“将权重衰减与梯度更新分离”修复了这一问题,成为LLM的标准优化器。
权重衰减本质是L2正则化,通过惩罚大权重,减少模型过拟合:
其中 \lambda 是权重衰减系数(默认0.01)。
Adam的错误实现:将权重衰减融入梯度更新(等价于对梯度加惩罚):
AdamW的正确实现:权重衰减与梯度更新分离(直接惩罚权重):
# AdamW优化器配置(LLM常用参数) optimizer = torch.optim.AdamW( model.parameters(), lr=2e-5, # 学习率 betas=(0.9, 0.95),# 一阶矩、二阶矩系数(LLM中β2常用0.95) weight_decay=0.01,# 权重衰减系数 eps=1e-8 # 数值稳定性参数 )
学习率是优化器的核心超参数:
LLM常用“Warmup + Cosine Annealing”组合调度策略,平衡“初期稳定”和“后期精细化调整”。
训练初期,模型参数随机初始化,梯度方差大,若用较大学习率会导致参数震荡。Warmup阶段逐步增大学习率,让模型适应梯度分布。
设Warmup步数为 T_w,初始学习率为 \eta_0,目标学习率为 \eta_{\text{max}}:
def warmup_lr_scheduler(optimizer, warmup_steps): def lr_lambda(step): if step < warmup_steps: return step / warmup_steps # 线性增长 else: return 1.0 # 后续用其他调度 return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
Warmup后,用余弦函数逐步降低学习率,让模型在训练后期精细化调整参数,避免陷入局部最优。
设总步数为 T_{\text{total}},Warmup后步数为 t' = t - T_w(t' \geq 0),退火周期为 T_c = T_{\text{total}} - T_w:
LLM的标准学习率调度:
from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR from torch.optim.lr_scheduler import SequentialLR # 1. Warmup调度 warmup_scheduler = LambdaLR(optimizer, lr_lambda=lambda step: step/warmup_steps if step < warmup_steps else 1.0) # 2. Cosine Annealing调度 cosine_scheduler = CosineAnnealingLR(optimizer, T_max=total_steps - warmup_steps, eta_min=1e-7) # 3. 组合调度 scheduler = SequentialLR(optimizer, schedulers=[warmup_scheduler, cosine_scheduler], milestones=[warmup_steps])
大模型训练面临两大硬件瓶颈:
梯度累积和混合精度训练是解决这两个问题的核心工程技巧。
单卡内存有限,无法容纳大批次(如 B=64),梯度累积通过“多次小批次训练后累积梯度,再更新参数”,模拟大批次的效果。
假设大批次大小为 B_{\text{total}} = B_{\text{mini}} \times K(K 为累积步数),则:
accumulation_steps = 8 # 累积8个小批次,模拟大批次 model.train() for step, (inputs, targets) in enumerate(dataloader): inputs, targets = inputs.cuda(), targets.cuda() # 前向传播计算损失 loss = model(inputs, targets) # 损失缩放(平均到每个小批次) loss = loss / accumulation_steps # 反向传播计算梯度(累积梯度) loss.backward() # 累积K步后更新参数 if (step + 1) % accumulation_steps == 0: optimizer.step() # 参数更新 optimizer.zero_grad() # 梯度清零 scheduler.step() # 学习率调度
track_running_stats=False(避免小批次统计偏差)。混合精度训练的核心是“用FP16存储模型参数和计算中间激活值,用FP32存储梯度和优化器状态”,平衡速度与精度。
FP16的动态范围小(最小正值约 6e-8),梯度值较小时会被舍入为0(梯度下溢)。解决方案是“训练时放大损失,反向传播后缩小梯度”:
PyTorch的torch.cuda.amp模块简化了混合精度训练的实现:
from torch.cuda.amp import autocast, GradScaler scaler = GradScaler() # 梯度缩放器 model.train() for step, (inputs, targets) in enumerate(dataloader): inputs, targets = inputs.cuda(), targets.cuda() optimizer.zero_grad() # 开启autocast,自动使用FP16计算 with autocast(): loss = model(inputs, targets) # 反向传播+梯度缩放 scaler.scale(loss).backward() # 参数更新(自动缩放梯度) scaler.step(optimizer) # 更新缩放因子 scaler.update() scheduler.step()
inf/nan,需增大缩放因子 s;