Transformer架构与LLM解码策略深度解析


文档摘要

Transformer架构与LLM解码策略深度解析 Transformer架构是现代大语言模型的基础,理解其工作原理对于掌握LLM至关重要。本文将深入探讨Transformer架构设计和高效的解码策略。 Transformer架构核心 自注意力机制 自注意力(Self-Attention)是Transformer的核心组件,允许模型在处理序列时关注不同位置的信息。 位置编码 由于注意力机制本身不包含序列位置信息,Transformer需要显式添加位置编码。 前馈神经网络 Transformer块 解码策略 贪婪解码 束搜索(Beam Search) 采样解码 高级优化技术 KV Cache KV Cache缓存了历史token的Key和Value,避免重复计算。 Flash Attention ……

Transformer架构与LLM解码策略深度解析

Transformer架构是现代大语言模型的基础,理解其工作原理对于掌握LLM至关重要。本文将深入探讨Transformer架构设计和高效的解码策略。

Transformer架构核心

自注意力机制

自注意力(Self-Attention)是Transformer的核心组件,允许模型在处理序列时关注不同位置的信息。

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Args:
        d_model: width of the query/key/value inputs and of the output.
        num_heads: number of attention heads; must divide ``d_model`` evenly.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        if d_model % num_heads != 0:
            # Otherwise d_k = d_model // num_heads silently drops features and the
            # .view() calls in forward() fail later with a confusing shape error.
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """Compute ``softmax(Q K^T / sqrt(d_k)) V``.

        Q, K, V have shape (batch_size, num_heads, seq_len, d_k). ``mask`` must
        broadcast against the (..., q_len, k_len) score matrix; positions where
        ``mask == 0`` receive no attention.

        Returns:
            ``(output, attention_weights)``.
        """
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            # Use the dtype's most negative finite value: a literal like -1e9
            # overflows float16, which makes masked_fill raise in half precision.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
        attention_weights = F.softmax(scores, dim=-1)
        output = torch.matmul(attention_weights, V)
        return output, attention_weights

    def forward(self, query, key, value, mask=None):
        """Project inputs, attend per head, merge heads and project the result.

        Args:
            query, key, value: tensors of shape (batch_size, seq_len, d_model);
                key and value share a sequence length.
            mask: optional mask forwarded to scaled_dot_product_attention().

        Returns:
            ``(output, attention_weights)`` with output of shape
            (batch_size, q_len, d_model) and weights of shape
            (batch_size, num_heads, q_len, k_len).
        """
        batch_size = query.size(0)

        # Linear projections, then split d_model into (num_heads, d_k).
        Q = self.W_q(query).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(key).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(value).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        attn_output, attn_weights = self.scaled_dot_product_attention(Q, K, V, mask)

        # Merge heads back into d_model.
        attn_output = attn_output.transpose(1, 2).contiguous().view(
            batch_size, -1, self.d_model
        )

        output = self.W_o(attn_output)
        return output, attn_weights


# Usage example
d_model = 512
num_heads = 8
mha = MultiHeadAttention(d_model, num_heads)

# Input: batch_size=32, seq_len=10, d_model=512
input_tensor = torch.randn(32, 10, d_model)
output, weights = mha(input_tensor, input_tensor, input_tensor)
print(f"Output shape: {output.shape}")  # (32, 10, 512)

位置编码

由于注意力机制本身不包含序列位置信息,Transformer需要显式添加位置编码:

class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding.

    Precomputes a (1, max_len, d_model) table; forward() adds its first
    seq_len rows to the input. Even feature indices hold sin, odd ones cos.

    Args:
        d_model: feature width (odd values are supported).
        max_len: longest sequence length covered by the table.
    """

    def __init__(self, d_model, max_len=5000):
        import math  # the file does not import math at module level

        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # 1 / 10000^(2i / d_model) for every even feature index 2i.
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # With odd d_model there is one fewer odd column than even column, so
        # only the first d_model // 2 frequencies feed the cos columns.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0)  # (1, max_len, d_model): broadcasts over the batch
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` plus positional encodings; x is (batch_size, seq_len, d_model)."""
        return x + self.pe[:, :x.size(1), :]


# Rotary positional embedding (RoPE) - a more modern positional scheme.
class RotaryPositionalEmbedding(nn.Module):
    """Produce RoPE rotation angles.

    forward() returns angles of shape (seq_len, d_model) laid out as
    ``[freqs, freqs]``, matching rotate_half(); pass ``emb.cos()`` and
    ``emb.sin()`` to apply_rotary_pos_emb(). Assumes an even d_model.

    Args:
        d_model: rotated feature width.
        max_len: accepted for API symmetry; angles are computed on demand and
            this value is not used.
    """

    def __init__(self, d_model, max_len=2048):
        super().__init__()
        self.d_model = d_model
        # One inverse frequency per feature pair.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, d_model, 2).float() / d_model))
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, x, seq_len):
        """Return the (seq_len, d_model) angle table on ``x``'s device."""
        t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
        freqs = torch.einsum('i,j->ij', t, self.inv_freq)  # (seq_len, d_model // 2)
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb


def rotate_half(x):
    """Map the halves ``(x1, x2)`` of the last dimension to ``(-x2, x1)``."""
    half = x.shape[-1] // 2
    x1 = x[..., :half]
    x2 = x[..., half:]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin):
    """Rotate queries and keys by their positional angles.

    ``cos``/``sin`` come from RotaryPositionalEmbedding's output and must
    broadcast against ``q`` and ``k``.
    """
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

前馈神经网络

class FeedForward(nn.Module):
    """Position-wise feed-forward sub-layer: Linear -> GELU -> Dropout -> Linear.

    Args:
        d_model: input and output width.
        d_ff: hidden (expanded) width.
        dropout: dropout probability applied to the hidden activations.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        # Submodules are registered in this order on purpose: it fixes the
        # parameter-initialisation RNG order and the state_dict key order.
        self.linear1 = nn.Linear(d_model, d_ff)   # expand d_model -> d_ff
        self.linear2 = nn.Linear(d_ff, d_model)   # project back d_ff -> d_model
        self.dropout = nn.Dropout(p=dropout)
        self.activation = nn.GELU()

    def forward(self, x):
        """Apply the two-layer MLP to the last dimension of ``x``."""
        hidden = self.linear1(x)
        hidden = self.activation(hidden)
        hidden = self.dropout(hidden)
        return self.linear2(hidden)

Transformer块

class TransformerBlock(nn.Module):
    """Post-norm Transformer block: self-attention sub-layer, then feed-forward.

    Each sub-layer's output passes through dropout, is added to the sub-layer
    input (residual connection), and the sum is layer-normalised.

    Args:
        d_model: model width.
        num_heads: number of attention heads.
        d_ff: hidden width of the feed-forward sub-layer.
        dropout: dropout probability for residual branches and the FFN.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        # Registration order kept as-is: it determines parameter init order
        # and state_dict key order.
        self.attention = MultiHeadAttention(d_model, num_heads)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)
        self.dropout = nn.Dropout(dropout)

    def _add_and_norm(self, residual, update, norm):
        """Dropout ``update``, add it to ``residual``, then apply ``norm``."""
        return norm(residual + self.dropout(update))

    def forward(self, x, mask=None):
        """Run attention and feed-forward sub-layers over ``x``; ``mask`` goes to attention."""
        attended, _weights = self.attention(x, x, x, mask)
        hidden = self._add_and_norm(x, attended, self.norm1)
        return self._add_and_norm(hidden, self.feed_forward(hidden), self.norm2)

解码策略

1. 贪婪解码

def greedy_decode(model, input_ids, max_length, eos_token_id):
    """Greedy decoding: append the highest-scoring token at every step.

    Args:
        model: callable returning an object whose ``.logits`` has shape
            (batch_size, seq_len, vocab_size).
        input_ids: prompt token ids, shape (batch_size, prompt_len).
        max_length: maximum number of new tokens to generate.
        eos_token_id: end-of-sequence id. Once a row emits it, that row is
            padded with it and no longer changes; decoding stops as soon as
            every row has finished. ``None`` disables early stopping.

    Returns:
        Token ids of shape (batch_size, prompt_len + generated_steps).
    """
    stop_on_eos = eos_token_id is not None
    # Per-row "already emitted EOS" flags, shaped like next_token (batch, 1).
    finished = torch.zeros(
        input_ids.size(0), 1, dtype=torch.bool, device=input_ids.device
    )
    with torch.no_grad():  # inference only: no autograd graph over the loop
        for _ in range(max_length):
            logits = model(input_ids).logits[:, -1, :]  # last-position logits
            next_token = torch.argmax(logits, dim=-1, keepdim=True)
            if stop_on_eos:
                # Finished rows keep emitting EOS instead of arbitrary tokens.
                next_token = next_token.masked_fill(finished, eos_token_id)
                finished |= next_token == eos_token_id
            input_ids = torch.cat([input_ids, next_token], dim=-1)
            # Rows may finish at different steps; stop once all of them have.
            if stop_on_eos and finished.all():
                break
    return input_ids

2. 束搜索(Beam Search)

def beam_search_decode(model, input_ids, beam_width, max_length, eos_token_id):
    """Beam search: keep the ``beam_width`` best hypotheses per batch element.

    A hypothesis's score is the sum of its token log-probabilities (no length
    normalisation). Beams are ranked independently for each batch element.
    A beam that has emitted ``eos_token_id`` is frozen: it can only be
    extended with ``eos_token_id`` at zero extra cost, so all hypotheses keep
    the same length and a finished beam keeps its score.

    Args:
        model: callable returning an object whose ``.logits`` has shape
            (n_sequences, seq_len, vocab_size); it is called with all
            batch_size * beam_width hypotheses at once.
        input_ids: prompt token ids, shape (batch_size, prompt_len).
        beam_width: number of hypotheses kept per batch element.
        max_length: maximum number of new tokens to generate.
        eos_token_id: end-of-sequence id, or ``None`` to never stop early.

    Returns:
        The best hypothesis per batch element, shape
        (batch_size, prompt_len + generated_steps).

    Raises:
        ValueError: if ``beam_width < 1``.
    """
    if beam_width < 1:
        raise ValueError(f"beam_width must be >= 1, got {beam_width}")
    batch_size = input_ids.size(0)
    device = input_ids.device

    # (batch, beam, length): every beam starts from the prompt.
    beams = input_ids.unsqueeze(1).repeat(1, beam_width, 1)
    # Only beam 0 is live initially; otherwise the first step would fill the
    # beam with beam_width copies of the same continuation.
    beam_scores = torch.full((batch_size, beam_width), float("-inf"), device=device)
    beam_scores[:, 0] = 0.0
    finished = torch.zeros(batch_size, beam_width, dtype=torch.bool, device=device)

    with torch.no_grad():
        for _ in range(max_length):
            flat = beams.reshape(batch_size * beam_width, -1)
            logits = model(flat).logits[:, -1, :].float()
            vocab_size = logits.size(-1)
            log_probs = F.log_softmax(logits, dim=-1).view(
                batch_size, beam_width, vocab_size
            )

            if eos_token_id is not None and finished.any():
                # Finished beams may only continue with EOS, at zero cost.
                eos_only = torch.full((vocab_size,), float("-inf"), device=device)
                eos_only[eos_token_id] = 0.0
                log_probs = torch.where(finished.unsqueeze(-1), eos_only, log_probs)

            # Rank all beam_width * vocab_size continuations of each example.
            candidate_scores = (beam_scores.unsqueeze(-1) + log_probs).view(
                batch_size, -1
            )
            beam_scores, flat_idx = torch.topk(candidate_scores, beam_width, dim=-1)
            src_beam = torch.div(flat_idx, vocab_size, rounding_mode="floor")
            next_token = flat_idx % vocab_size

            # Reorder surviving prefixes to match their new slots, then extend.
            gather_idx = src_beam.unsqueeze(-1).expand(-1, -1, beams.size(-1))
            beams = torch.cat(
                [beams.gather(1, gather_idx), next_token.unsqueeze(-1)], dim=-1
            )
            if eos_token_id is not None:
                finished = finished.gather(1, src_beam) | (next_token == eos_token_id)
                if finished.all():
                    break

    # topk returns scores in descending order, so slot 0 holds the best beam.
    return beams[:, 0, :]

3. 采样解码

def _sample_loop(model, input_ids, max_length, eos_token_id, pick_next):
    """Shared autoregressive loop for the sampling decoders below.

    ``pick_next`` maps last-position logits (batch, vocab) to next-token ids
    (batch, 1). When ``eos_token_id`` is given, rows that already emitted it
    are padded with it, and the loop stops once every row has finished.
    """
    stop_on_eos = eos_token_id is not None
    finished = torch.zeros(
        input_ids.size(0), 1, dtype=torch.bool, device=input_ids.device
    )
    with torch.no_grad():  # inference only
        for _ in range(max_length):
            logits = model(input_ids).logits[:, -1, :]
            next_token = pick_next(logits)
            if stop_on_eos:
                next_token = next_token.masked_fill(finished, eos_token_id)
                finished |= next_token == eos_token_id
            input_ids = torch.cat([input_ids, next_token], dim=-1)
            if stop_on_eos and finished.all():
                break
    return input_ids


def temperature_sampling_decode(model, input_ids, max_length, temperature=1.0, eos_token_id=None):
    """Sample from ``softmax(logits / temperature)`` at each step.

    Lower temperature sharpens the distribution; higher flattens it.

    Raises:
        ValueError: if ``temperature <= 0`` (division would give inf/NaN).
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature}")

    def pick(logits):
        probs = F.softmax(logits / temperature, dim=-1)
        return torch.multinomial(probs, num_samples=1)

    return _sample_loop(model, input_ids, max_length, eos_token_id, pick)


def top_k_sampling_decode(model, input_ids, max_length, k=50, eos_token_id=None):
    """Sample only among the ``k`` highest-scoring tokens at each step.

    ``k`` larger than the vocabulary is clamped to the vocabulary size.

    Raises:
        ValueError: if ``k < 1``.
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")

    def pick(logits):
        top_k_logits, top_k_indices = torch.topk(logits, min(k, logits.size(-1)), dim=-1)
        top_k_probs = F.softmax(top_k_logits, dim=-1)
        # Sample a position within the top-k, then map it back to a vocab id.
        next_token_idx = torch.multinomial(top_k_probs, num_samples=1)
        return torch.gather(top_k_indices, -1, next_token_idx)

    return _sample_loop(model, input_ids, max_length, eos_token_id, pick)


def top_p_sampling_decode(model, input_ids, max_length, p=0.9, eos_token_id=None):
    """Nucleus sampling: sample from the smallest token set whose mass reaches ``p``."""

    def pick(logits):
        probs = F.softmax(logits, dim=-1)
        sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
        # Drop tokens once the mass *before* them already exceeds p; shifting
        # right by one keeps the token that crosses the threshold, and the
        # most likely token is always kept.
        sorted_indices_to_remove = cumulative_probs > p
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False
        sorted_probs[sorted_indices_to_remove] = 0.0
        sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
        next_token_idx = torch.multinomial(sorted_probs, num_samples=1)
        return torch.gather(sorted_indices, -1, next_token_idx)

    return _sample_loop(model, input_ids, max_length, eos_token_id, pick)

高级优化技术

1. KV Cache

KV Cache缓存了历史token的Key和Value,避免重复计算:

class KVCache:
    """Store of past attention keys/values, keyed by ``(layer_idx, head_idx)``.

    Each entry is a ``(k, v)`` pair; update() appends new tensors along
    dim=-2 (the sequence axis for the per-head (batch, seq, d_k) tensors
    that forward_with_cache stores).
    """

    def __init__(self):
        self.cache = {}

    def get(self, layer_idx, head_idx):
        """Return the cached ``(k, v)`` pair, or ``None`` if nothing is cached."""
        return self.cache.get((layer_idx, head_idx))

    def update(self, layer_idx, head_idx, k, v):
        """Append ``k``/``v`` to the cached entry along dim=-2, or create it."""
        key = (layer_idx, head_idx)
        cached = self.cache.get(key)
        if cached is None:
            self.cache[key] = (k, v)
        else:
            cached_k, cached_v = cached
            self.cache[key] = (
                torch.cat([cached_k, k], dim=-2),
                torch.cat([cached_v, v], dim=-2),
            )

    def clear(self):
        """Drop every cached entry (e.g. before decoding a new sequence)."""
        self.cache.clear()


# Using the KV cache inside a Transformer: intended to be bound as a method of
# an attention module. It reads self.W_q / W_k / W_v / W_o, self.num_heads,
# self.d_k and self.layer_idx.
def forward_with_cache(self, x, cache):
    """Causal multi-head self-attention over new tokens, reusing cached K/V.

    The new tokens' keys/values are appended to ``cache``; each new token then
    attends to every cached position plus the new positions up to itself.

    Args:
        x: new tokens, shape (batch_size, seq_len, d_model).
        cache: a KVCache shared across decoding steps.

    Returns:
        Attention output for the new tokens, shape (batch_size, seq_len, d_model).
    """
    batch_size, seq_len, _ = x.shape

    def split_heads(t):
        # (batch, seq, d_model) -> (batch, heads, seq, d_k)
        return t.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

    Q = split_heads(self.W_q(x))
    K_new = split_heads(self.W_k(x))
    V_new = split_heads(self.W_v(x))

    keys, values = [], []
    for head_idx in range(self.num_heads):
        # Cache only the NEW keys/values; KVCache.update appends them to history.
        cache.update(self.layer_idx, head_idx, K_new[:, head_idx], V_new[:, head_idx])
        head_k, head_v = cache.get(self.layer_idx, head_idx)
        keys.append(head_k)
        values.append(head_v)
    K = torch.stack(keys, dim=1)    # (batch, heads, past_len + seq_len, d_k)
    V = torch.stack(values, dim=1)

    total_len = K.size(-2)
    past_len = total_len - seq_len
    scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)
    # Causal mask: new token i sits at absolute position past_len + i and may
    # only attend to keys at positions <= that (a no-op when seq_len == 1).
    q_pos = torch.arange(seq_len, device=x.device).unsqueeze(1) + past_len
    k_pos = torch.arange(total_len, device=x.device).unsqueeze(0)
    scores = scores.masked_fill(k_pos > q_pos, torch.finfo(scores.dtype).min)
    weights = F.softmax(scores, dim=-1)

    out = torch.matmul(weights, V).transpose(1, 2).contiguous()
    out = out.view(batch_size, seq_len, self.num_heads * self.d_k)
    return self.W_o(out)

2. Flash Attention

Flash Attention通过IO感知的精确注意力算法显著提升性能:

# Conceptual pure-PyTorch version; real implementations fuse this loop into a
# CUDA kernel. In practice use e.g.:
#     from flash_attn import flash_attn_func
#     output = flash_attn_func(q, k, v)
# NOTE(review): flash_attn_func expects a (batch, seq, heads, head_dim) layout,
# unlike the (..., seq, head_dim) layout used below - transpose before swapping.
class FlashAttention(nn.Module):
    """Exact softmax attention computed block-by-block with an online softmax.

    Keys/values are processed in chunks of ``block_size``; a running row
    maximum and normaliser let each chunk be folded into the output without
    materialising the full (q_len, k_len) score matrix at once. This shows
    the blocking math only; it does not reproduce a fused kernel's speed.

    Args:
        d_model: model width (kept for API compatibility; head size is
            derived as d_model // num_heads).
        num_heads: number of heads.
        block_size: number of key/value positions processed per chunk.

    Raises:
        ValueError: if ``block_size < 1``.
    """

    def __init__(self, d_model, num_heads, block_size=64):
        super().__init__()
        if block_size < 1:
            raise ValueError(f"block_size must be >= 1, got {block_size}")
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.block_size = block_size

    def forward(self, q, k, v):
        """Return ``softmax(q k^T / sqrt(d)) v``.

        q is (..., q_len, d), k is (..., k_len, d), v is (..., k_len, d_v);
        leading dimensions (e.g. batch, heads) must match.
        """
        scale = q.size(-1) ** -0.5
        out = q.new_zeros(*q.shape[:-1], v.size(-1))
        row_max = q.new_full((*q.shape[:-1], 1), float("-inf"))
        row_sum = q.new_zeros(*q.shape[:-1], 1)
        for start in range(0, k.size(-2), self.block_size):
            k_blk = k[..., start:start + self.block_size, :]
            v_blk = v[..., start:start + self.block_size, :]
            scores = torch.matmul(q, k_blk.transpose(-2, -1)) * scale
            new_max = torch.maximum(row_max, scores.amax(dim=-1, keepdim=True))
            probs = torch.exp(scores - new_max)
            # Rescale what was accumulated under the previous max (0 on the first block).
            rescale = torch.exp(row_max - new_max)
            row_sum = row_sum * rescale + probs.sum(dim=-1, keepdim=True)
            out = out * rescale + torch.matmul(probs, v_blk)
            row_max = new_max
        return out / row_sum

总结

Transformer架构通过自注意力机制和位置编码,为序列建模提供了强大的工具。结合高效的解码策略(如束搜索、核采样)和优化技术(如KV Cache、Flash Attention),可以构建高性能的大语言模型。理解这些核心概念,对于LLM的开发和优化至关重要。


发布者: 作者: 转发
评论区 (0)
U