# Attention 机制完全指南：从原理到 Transformer 的演进

## 技术背景

Attention 机制是深度学习领域最重要的创新之一，它彻底改变了自然语言处理的格局。从 Seq2Seq 模型到 Transformer，Attention 机制已经成为现代大语言模型的核心组件。

## Attention 机制基础

### 什么是 Attention

Attention 机制允许模型在处理输入序列时，动态地关注不同部分的信息：

### Self-Attention

Self-Attention 是 Attention 机制的一种特殊形式，Query、Key、Value 都来自同一个序列：

### Multi-Head Attention

Multi-Head Attention 允许模型同时关注不同的表示子空间：
Attention 机制是深度学习领域最重要的创新之一，它彻底改变了自然语言处理的格局。从 Seq2Seq 模型到 Transformer，Attention 机制已经成为现代大语言模型的核心组件。
Attention 机制允许模型在处理输入序列时，根据当前需要动态地为不同位置分配权重，从而重点关注最相关部分的信息：
import torch
import torch.nn as nn
import torch.nn.functional as F


class SimpleAttention(nn.Module):
    """Additive (Bahdanau-style) attention over encoder outputs.

    For every encoder position ``i`` the score combines that position's
    encoder state ``h_i`` with the current decoder state ``s``::

        score_i = v^T tanh(W_enc h_i + W_dec s)

    The weights are a softmax of the scores over the sequence, and the
    context vector is the weighted sum of the encoder states. The earlier
    version never used ``decoder_hidden``, so its weights were identical for
    every decoder step. It also applied ``tanh`` after the scoring layer,
    which squashed all scores into [-1, 1].

    Note: the parameter set differs from the earlier version (two extra
    projections), so old state_dicts will not load into this module.

    Args:
        hidden_size: feature size of both encoder outputs and decoder state.
    """

    def __init__(self, hidden_size):
        super().__init__()
        # Project encoder states and the decoder state into a shared space.
        self.encoder_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.decoder_proj = nn.Linear(hidden_size, hidden_size)
        # Scoring vector v. A bias here would be cancelled by the softmax.
        self.attention = nn.Linear(hidden_size, 1, bias=False)

    def forward(self, encoder_outputs, decoder_hidden):
        """Compute the context vector for one decoder step.

        Args:
            encoder_outputs: (batch_size, seq_len, hidden_size)
            decoder_hidden: (batch_size, hidden_size)

        Returns:
            context_vector: (batch_size, hidden_size)
            attention_scores: (batch_size, seq_len, 1), softmax-normalised
                over seq_len.
        """
        # unsqueeze(1) broadcasts the decoder state across all positions.
        combined = torch.tanh(
            self.encoder_proj(encoder_outputs)
            + self.decoder_proj(decoder_hidden).unsqueeze(1)
        )
        energy = self.attention(combined)  # (batch_size, seq_len, 1)
        attention_scores = F.softmax(energy, dim=1)

        # Weighted sum over the sequence dimension.
        context_vector = torch.sum(encoder_outputs * attention_scores, dim=1)
        return context_vector, attention_scores
Self-Attention 是 Attention 机制的一种特殊形式：Query、Key、Value 都来自同一个序列（分别经过各自的线性投影），因此序列中的每个位置都可以直接关注序列中的其他位置：
class SelfAttention(nn.Module):
    """Multi-head scaled dot-product attention with per-head projections.

    ``values``, ``keys`` and ``query`` are each split into ``heads`` chunks of
    ``head_dim`` features. Each chunk is projected by a ``head_dim x head_dim``
    linear layer whose weights are shared across heads. Each head is attended
    independently with scores scaled by ``1 / sqrt(head_dim)``. The heads are
    then concatenated and mixed by ``fc_out``.

    Args:
        embed_size: input/output feature size; must be divisible by ``heads``.
        heads: number of attention heads.
    """

    def __init__(self, embed_size, heads):
        super().__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (self.heads * self.head_dim == embed_size), \
            "Embed size needs to be divisible by heads"

        # Per-head linear projections (same weights for every head).
        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        # Output projection applied after concatenating the heads.
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query, mask):
        """Attend from ``query`` to ``keys``/``values``.

        Args:
            values: (N, value_len, embed_size)
            keys: (N, key_len, embed_size); key_len must equal value_len.
            query: (N, query_len, embed_size)
            mask: None, or a tensor broadcastable to
                (N, heads, query_len, key_len). Positions where
                ``mask == 0`` get a score of -1e20 before the softmax.

        Returns:
            (N, query_len, embed_size)
        """
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Split the embedding into heads: (N, len, heads, head_dim).
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)

        # Apply the projections. They were previously created but never used.
        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        # QK^T per head: (N, heads, query_len, key_len).
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
        # Scale by sqrt(d_k), where d_k is the per-head width (not embed_size).
        energy = energy / (self.head_dim ** (1 / 2))

        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        # Normalise over the key dimension.
        attention = torch.softmax(energy, dim=3)

        # Weighted sum of values, back to (N, query_len, heads, head_dim).
        out = torch.einsum("nhql,nlhd->nqhd", [attention, values])
        # Concatenate heads, then mix them.
        out = out.reshape(N, query_len, self.heads * self.head_dim)
        return self.fc_out(out)
Multi-Head Attention 将表示拆分为多个"头"，每个头在各自的子空间中独立计算注意力，最后再拼接融合，从而让模型能够同时关注不同表示子空间中的信息：
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Each input is projected by a full-width linear layer and reshaped into
    ``num_heads`` heads of width ``d_k = d_model // num_heads``. Each head is
    attended with scores scaled by ``1 / sqrt(d_k)``. The heads are then
    merged back and passed through ``W_o``.

    Args:
        d_model: model width; must be divisible by ``num_heads``.
        num_heads: number of heads.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # Q, K, V and output projections, created (and registered) in this order.
        self.W_q, self.W_k, self.W_v, self.W_o = (
            nn.Linear(d_model, d_model) for _ in range(4)
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` to ``key``/``value``.

        Args:
            query: (batch, q_len, d_model)
            key, value: (batch, kv_len, d_model)
            mask: optional tensor broadcastable to
                (batch, num_heads, q_len, kv_len). Scores where
                ``mask == 0`` are set to -1e9 before the softmax.

        Returns:
            ``(output, attention_weights)``. ``output`` is
            (batch, q_len, d_model). ``attention_weights`` is
            (batch, num_heads, q_len, kv_len), taken *after* dropout.
        """
        n = query.size(0)

        def to_heads(proj, x):
            # (batch, len, d_model) -> (batch, heads, len, d_k)
            return proj(x).view(n, -1, self.num_heads, self.d_k).transpose(1, 2)

        q = to_heads(self.W_q, query)
        k = to_heads(self.W_k, key)
        v = to_heads(self.W_v, value)

        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        weights = self.dropout(F.softmax(scores, dim=-1))

        # (batch, heads, q_len, d_k) -> (batch, q_len, d_model)
        merged = (weights @ v).transpose(1, 2).contiguous().view(n, -1, self.d_model)
        return self.W_o(merged), weights
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first input.

    pe[pos, 2i]   = sin(pos / 10000^(2i / d_model))
    pe[pos, 2i+1] = cos(pos / 10000^(2i / d_model))

    Args:
        d_model: feature size of the input (odd values are supported).
        max_len: longest sequence length the precomputed table covers.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than sine column.
        # Using all of div_term here would raise a shape-mismatch error.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Stored as a buffer (not a trainable parameter), shape (1, max_len, d_model).
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x + pe`` for ``x`` of shape (batch, seq_len, d_model), seq_len <= max_len."""
        return x + self.pe[:, :x.size(1)]


class FeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Dropout -> Linear."""

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.linear2(self.dropout(F.relu(self.linear1(x))))


class TransformerBlock(nn.Module):
    """One post-norm encoder layer: self-attention and feed-forward sub-layers,
    each wrapped as ``LayerNorm(x + Dropout(sublayer(x)))``."""

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.attention = MultiHeadAttention(d_model, num_heads, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Multi-Head self-attention + residual + LayerNorm
        attn_output, _ = self.attention(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))

        # Feed-forward + residual + LayerNorm
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))
        return x


class Transformer(nn.Module):
    """Stack of ``TransformerBlock`` layers over token embeddings, projected to
    vocabulary logits.

    No causal mask is built internally. For autoregressive use, pass one via
    ``mask``. Unlike the original paper, embeddings are not scaled by
    sqrt(d_model) before the positional encoding is added.
    """

    def __init__(self, vocab_size, d_model=512, num_heads=8, num_layers=6, d_ff=2048, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model)
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        self.dropout = nn.Dropout(dropout)
        self.fc_out = nn.Linear(d_model, vocab_size)

    def forward(self, x, mask=None):
        """Map token ids (batch, seq_len) to logits (batch, seq_len, vocab_size)."""
        # Embedding + Positional Encoding
        x = self.embedding(x)
        x = self.positional_encoding(x)
        x = self.dropout(x)

        # Transformer blocks
        for block in self.transformer_blocks:
            x = block(x, mask)

        # Output projection
        x = self.fc_out(x)
        return x


if __name__ == "__main__":
    # Create the model. Guarded so importing this module does not run a
    # full-size forward pass.
    model = Transformer(
        vocab_size=10000,
        d_model=512,
        num_heads=8,
        num_layers=6,
        d_ff=2048
    )

    # Forward pass
    input_ids = torch.randint(0, 10000, (32, 128))  # batch_size=32, seq_len=128
    output = model(input_ids)
    print(output.shape)  # torch.Size([32, 128, 10000])
class EncoderDecoderTransformer(nn.Module):
    """Encoder-decoder Transformer for sequence-to-sequence tasks.

    Args:
        src_vocab_size: source vocabulary size.
        tgt_vocab_size: target vocabulary size.
        d_model, num_heads, num_layers, d_ff: architecture sizes.
        dropout: dropout probability used in every sub-layer and on the
            embeddings. Defaults to 0.1, the value previously hard-coded.
    """

    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512, num_heads=8,
                 num_layers=6, d_ff=2048, dropout=0.1):
        super().__init__()

        # Encoder
        self.encoder_embedding = nn.Embedding(src_vocab_size, d_model)
        self.encoder_pos_encoding = PositionalEncoding(d_model)
        self.encoder_layers = nn.ModuleList([
            TransformerBlock(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])

        # Decoder
        self.decoder_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.decoder_pos_encoding = PositionalEncoding(d_model)
        self.decoder_layers = nn.ModuleList([
            DecoderBlock(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])

        # Dropout on (embedding + positional encoding), as in the encoder-only
        # Transformer and the original paper. It has no parameters, so
        # state_dict keys are unchanged.
        self.dropout = nn.Dropout(dropout)
        self.fc_out = nn.Linear(d_model, tgt_vocab_size)

    def encode(self, src, src_mask):
        """Encode source ids (batch, src_len) into (batch, src_len, d_model)."""
        src = self.encoder_embedding(src)
        src = self.dropout(self.encoder_pos_encoding(src))
        for layer in self.encoder_layers:
            src = layer(src, src_mask)
        return src

    def decode(self, tgt, encoder_output, tgt_mask, src_tgt_mask):
        """Decode target ids (batch, tgt_len) against ``encoder_output``.

        Returns hidden states (batch, tgt_len, d_model). Apply ``fc_out``
        to these to get logits.
        """
        tgt = self.decoder_embedding(tgt)
        tgt = self.dropout(self.decoder_pos_encoding(tgt))
        for layer in self.decoder_layers:
            tgt = layer(tgt, encoder_output, tgt_mask, src_tgt_mask)
        return tgt

    def forward(self, src, tgt, src_mask, tgt_mask, src_tgt_mask):
        """Return target-vocabulary logits of shape (batch, tgt_len, tgt_vocab_size)."""
        encoder_output = self.encode(src, src_mask)
        decoder_output = self.decode(tgt, encoder_output, tgt_mask, src_tgt_mask)
        output = self.fc_out(decoder_output)
        return output


class DecoderBlock(nn.Module):
    """One post-norm decoder layer: masked self-attention, cross-attention over
    the encoder output, then feed-forward. Each sub-layer is wrapped as
    ``LayerNorm(x + Dropout(sublayer(x)))``.

    Args:
        d_model, num_heads, d_ff: layer sizes.
        dropout: dropout probability (default 0.1, previously hard-coded).
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attention = MultiHeadAttention(d_model, num_heads, dropout)
        self.cross_attention = MultiHeadAttention(d_model, num_heads, dropout)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, encoder_output, tgt_mask, src_tgt_mask):
        # Masked self-attention over the target prefix.
        attn_output, _ = self.self_attention(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(attn_output))

        # Cross-attention: queries from the decoder, keys/values from the encoder.
        attn_output, _ = self.cross_attention(x, encoder_output, encoder_output, src_tgt_mask)
        x = self.norm2(x + self.dropout(attn_output))

        # Feed-forward
        ff_output = self.feed_forward(x)
        x = self.norm3(x + self.dropout(ff_output))
        return x
class TransformerClassifier(nn.Module):
    """Sequence classifier: a Transformer encoder plus a linear head on position 0.

    Args:
        vocab_size: input vocabulary size.
        num_classes: number of output classes.
        d_model, num_heads, num_layers: encoder sizes.
        dropout: dropout probability.
        d_ff: hidden size of each feed-forward sub-layer. It was previously
            referenced without being defined (NameError). It is appended as
            the last keyword argument so existing positional calls still work.
    """

    def __init__(self, vocab_size, num_classes, d_model=512, num_heads=8,
                 num_layers=6, dropout=0.1, d_ff=2048):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model)
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(d_model, num_classes)

    def forward(self, x, mask=None):
        """Map token ids (batch, seq_len) to class logits (batch, num_classes).

        The hidden state at position 0 summarises the sequence. Inputs are
        therefore expected to begin with a [CLS]-style token (TODO confirm
        against the tokenizer in use).
        """
        x = self.embedding(x)
        x = self.positional_encoding(x)
        x = self.dropout(x)

        for block in self.transformer_blocks:
            x = block(x, mask)

        # Use the output at position 0 ([CLS] token) for classification.
        cls_output = x[:, 0, :]
        logits = self.classifier(cls_output)
        return logits


def train_classifier(model, train_loader, num_epochs, lr=1e-4):
    """Train ``model`` with cross-entropy loss and Adam.

    Args:
        model: module mapping a batch of inputs to (batch, num_classes) logits.
        train_loader: iterable of ``(batch_x, batch_y)`` pairs, re-iterated
            each epoch.
        num_epochs: number of passes over ``train_loader``.
        lr: Adam learning rate.

    Returns:
        List of per-batch loss values (Python floats), in training order.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    model.train()
    losses = []
    for _ in range(num_epochs):
        for batch_x, batch_y in train_loader:
            optimizer.zero_grad()
            outputs = model(batch_x)
            loss = criterion(outputs, batch_y)
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
    return losses


if __name__ == "__main__":
    from torch.utils.data import DataLoader, TensorDataset

    # Training example. The random placeholder data only makes the example
    # runnable; replace it with a real DataLoader that yields
    # (token_ids, labels) batches.
    model = TransformerClassifier(vocab_size=10000, num_classes=10)
    dummy_x = torch.randint(0, 10000, (64, 32))
    dummy_y = torch.randint(0, 10, (64,))
    train_loader = DataLoader(TensorDataset(dummy_x, dummy_y), batch_size=16)
    num_epochs = 1
    train_classifier(model, train_loader, num_epochs)
def translate_sentence(model, sentence, src_vocab, tgt_vocab, max_length=50):
    """Greedily translate ``sentence`` with an encoder-decoder model.

    The model must expose ``encode(src, src_mask)``,
    ``decode(tgt, encoder_output, tgt_mask, src_mask)`` and ``fc_out``.
    The vocabularies must expose ``encode``/``decode`` and
    ``pad_token_id``/``bos_token_id``/``eos_token_id``.

    Args:
        model: the encoder-decoder model. It is switched to eval mode.
        sentence: source text passed to ``src_vocab.encode``.
        src_vocab, tgt_vocab: source/target vocabularies.
        max_length: maximum number of tokens to generate.

    Returns:
        Whatever ``tgt_vocab.decode`` returns for the generated ids. The id
        list starts with BOS and, if generation stopped early, ends with EOS.
        Whether those are stripped depends on ``tgt_vocab.decode``.
    """
    model.eval()

    # Build inputs on the model's device (CPU if it has no parameters).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    # Tokenize
    src_tokens = src_vocab.encode(sentence)
    src_tensor = torch.tensor([src_tokens], dtype=torch.long, device=device)
    # (1, 1, 1, src_len): broadcasts over heads and query positions.
    src_mask = (src_tensor != src_vocab.pad_token_id).unsqueeze(1).unsqueeze(2)

    tgt_tokens = [tgt_vocab.bos_token_id]
    with torch.no_grad():
        encoder_output = model.encode(src_tensor, src_mask)

        for _ in range(max_length):
            tgt_tensor = torch.tensor([tgt_tokens], dtype=torch.long, device=device)
            tgt_len = tgt_tensor.size(1)
            pad_mask = (tgt_tensor != tgt_vocab.pad_token_id).unsqueeze(1).unsqueeze(2)
            # The causal mask must be (tgt_len, tgt_len). The previous code
            # applied tril to the (1, 1, 1, tgt_len) padding mask, which kept
            # only column 0, so every position could attend only to BOS.
            causal = torch.tril(
                torch.ones(tgt_len, tgt_len, dtype=torch.bool, device=device)
            )
            tgt_mask = pad_mask & causal  # (1, 1, tgt_len, tgt_len)

            output = model.decode(tgt_tensor, encoder_output, tgt_mask, src_mask)
            predictions = model.fc_out(output)

            # Greedy choice of the next token from the last position.
            next_token = predictions[0, -1, :].argmax().item()
            tgt_tokens.append(next_token)
            if next_token == tgt_vocab.eos_token_id:
                break

    return tgt_vocab.decode(tgt_tokens)
class TransformerQA(nn.Module):
    """Extractive QA head: predicts answer start/end positions.

    Args:
        vocab_size: vocabulary size.
        d_model, num_heads, num_layers: encoder sizes.
        d_ff: feed-forward hidden size. It was previously referenced without
            being defined (NameError). It is now a trailing keyword argument.
        dropout: dropout probability passed to each encoder layer.
    """

    def __init__(self, vocab_size, d_model=512, num_heads=8, num_layers=6,
                 d_ff=2048, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model)
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])

        # Start and end position classifiers (one score per token).
        self.start_classifier = nn.Linear(d_model, 1)
        self.end_classifier = nn.Linear(d_model, 1)

    def forward(self, context, question, mask=None):
        """Score every position of ``context ++ question`` as answer start/end.

        Args:
            context: (batch, ctx_len) token ids.
            question: (batch, q_len) token ids.
            mask: optional attention mask forwarded to every block.

        Returns:
            ``(start_logits, end_logits)``, each (batch, ctx_len + q_len).
            Context comes first, so indices [0, ctx_len) refer to context
            tokens.
        """
        # Concatenate context and question along the sequence dimension.
        x = torch.cat([context, question], dim=1)
        x = self.embedding(x)
        x = self.positional_encoding(x)

        for block in self.transformer_blocks:
            x = block(x, mask)

        # Predict start and end positions.
        start_logits = self.start_classifier(x).squeeze(-1)
        end_logits = self.end_classifier(x).squeeze(-1)
        return start_logits, end_logits
class WarmupScheduler:
    """Inverse-square-root learning-rate schedule with linear warmup ("Noam").

        lr(step) = factor * d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)

    The rate rises linearly for ``warmup_steps`` steps, then decays as
    1/sqrt(step). Call :meth:`step` once per optimizer update. It writes the
    new rate into every param group of ``optimizer``.

    Args:
        optimizer: torch optimizer whose ``param_groups`` are updated.
        d_model: model dimension used to scale the rate.
        warmup_steps: number of warmup steps.
        factor: extra multiplier on the rate. The default 1.0 reproduces the
            original formula exactly.
    """

    def __init__(self, optimizer, d_model, warmup_steps=4000, factor=1.0):
        self.optimizer = optimizer
        self.d_model = d_model
        self.warmup_steps = warmup_steps
        self.factor = factor
        self.step_num = 0

    def rate(self, step=None):
        """Return the learning rate for ``step`` (default: the current ``step_num``).

        Raises:
            ValueError: if ``step`` < 1, where the formula is undefined.
        """
        if step is None:
            step = self.step_num
        if step < 1:
            raise ValueError(f"step must be >= 1, got {step}")
        return self.factor * self.d_model ** (-0.5) * min(
            step ** (-0.5),
            step * self.warmup_steps ** (-1.5)
        )

    def step(self):
        """Advance one step, apply the new rate to the optimizer, and return it."""
        self.step_num += 1
        lr = self.rate(self.step_num)
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        return lr


if __name__ == "__main__":
    # Uses the `model` defined by an earlier example in this script.
    # lr=0 is a placeholder; the scheduler overwrites it on every step().
    optimizer = torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)
    scheduler = WarmupScheduler(optimizer, d_model=512)
class LabelSmoothingLoss(nn.Module):
    """KL-divergence loss against label-smoothed one-hot targets.

    The target distribution puts ``1 - smoothing`` on the true class and
    spreads ``smoothing`` uniformly over the other ``vocab_size - 1`` classes.

    Args:
        vocab_size: number of classes (size of the last dim of ``pred``).
        smoothing: probability mass moved off the true class.
        ignore_index: optional target id (e.g. padding) whose positions are
            excluded from the loss and from the averaging denominator. With
            the default ``None``, every position counts, as before.
    """

    def __init__(self, vocab_size, smoothing=0.1, ignore_index=None):
        super().__init__()
        self.vocab_size = vocab_size
        self.smoothing = smoothing
        self.confidence = 1.0 - smoothing
        self.ignore_index = ignore_index

    def forward(self, pred, target):
        """Return the mean per-position loss.

        Args:
            pred: logits of shape (..., vocab_size). Any number of leading
                dims is accepted; they are flattened. The earlier version
                assumed 2-D input (``scatter_`` on dim 1, ``batchmean`` over
                dim 0).
            target: class ids with shape ``pred.shape[:-1]``.
        """
        log_probs = F.log_softmax(pred, dim=-1).reshape(-1, pred.size(-1))
        target = target.reshape(-1)

        keep = None
        if self.ignore_index is not None:
            keep = target != self.ignore_index
            # Replace ignored ids with a valid index so scatter_ cannot fail;
            # those rows are zeroed out below.
            target = target.masked_fill(~keep, 0)

        # Smooth labels (a constant target, so no gradient is needed).
        with torch.no_grad():
            smooth_target = torch.full_like(
                log_probs, self.smoothing / (self.vocab_size - 1)
            )
            smooth_target.scatter_(1, target.unsqueeze(1), self.confidence)

        if keep is None:
            # 'batchmean' on the flattened (N, V) input = mean over positions.
            return F.kl_div(log_probs, smooth_target, reduction='batchmean')

        per_row = F.kl_div(log_probs, smooth_target, reduction='none').sum(dim=-1)
        per_row = per_row * keep
        # clamp avoids 0/0 when every position is ignored (the loss is then 0).
        return per_row.sum() / keep.sum().clamp(min=1)


criterion = LabelSmoothingLoss(vocab_size=10000, smoothing=0.1)
def train_step(model, batch, optimizer, criterion, pad_token_id=0, max_grad_norm=1.0):
    """Run one teacher-forced training step of an encoder-decoder model.

    The previous version referenced the undefined names ``src_mask`` and
    ``tgt_mask``, omitted ``src_tgt_mask`` from the model call, and sliced the
    batch dimension (``tgt[:-1]``) instead of the sequence dimension.

    Args:
        model: callable as ``model(src, tgt_in, src_mask, tgt_mask, src_tgt_mask)``,
            returning logits (batch, tgt_len - 1, vocab).
        batch: ``(src, tgt)`` batch-first token ids, shapes (batch, src_len)
            and (batch, tgt_len).
        optimizer: torch optimizer for ``model``'s parameters.
        criterion: loss taking (N, vocab) logits and (N,) targets. Padding
            positions in the labels are passed through unchanged, so use a
            criterion that ignores them if that matters.
        pad_token_id: id treated as padding when building masks. The default
            0 is an assumption; set it to the vocabulary's pad id.
        max_grad_norm: gradient-norm clipping threshold.

    Returns:
        The loss as a Python float.
    """
    model.train()
    optimizer.zero_grad()

    src, tgt = batch
    # Teacher forcing: the decoder sees tokens [0, T-1) and predicts [1, T).
    tgt_in = tgt[:, :-1]
    tgt_out = tgt[:, 1:]

    # (batch, 1, 1, src_len): key-side padding mask, broadcast over heads and
    # queries. It serves for both encoder self-attention and cross-attention.
    src_mask = (src != pad_token_id).unsqueeze(1).unsqueeze(2)
    tgt_len = tgt_in.size(1)
    causal = torch.tril(torch.ones(tgt_len, tgt_len, dtype=torch.bool, device=tgt.device))
    # (batch, 1, tgt_len, tgt_len): no attending to padding or future tokens.
    tgt_mask = (tgt_in != pad_token_id).unsqueeze(1).unsqueeze(2) & causal

    output = model(src, tgt_in, src_mask, tgt_mask, src_mask)
    # reshape (not view): tgt_out is a non-contiguous slice.
    loss = criterion(output.reshape(-1, output.size(-1)), tgt_out.reshape(-1))
    loss.backward()

    # Gradient clipping
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
    optimizer.step()
    return loss.item()
Attention 机制和 Transformer 架构的革命性影响：
✅ 并行化：Self-Attention 一次性计算序列中所有位置之间的关系，无需像 RNN 那样逐步处理，因此可以充分利用 GPU 并行
✅ 长程依赖：任意两个位置之间只需一步注意力计算即可交互，能有效捕捉长距离依赖关系
✅ 可扩展性：适合大规模预训练
✅ 通用性：适用于 NLP、CV、多模态等任务
从《Attention Is All You Need》（2017）到 GPT-4、Claude 等大语言模型，Transformer 已经成为现代 AI 的基石。掌握其原理和实现，对于理解和开发大语言模型至关重要。