1.2 自注意力机制详解 — Transformers核心原理 本节导读:深入理解自注意力机制的数学原理、实现细节和优化技巧,掌握这一Transformers核心组件的工作机制和应用场景。 学习目标 掌握自注意力机制的数学表达和计算过程 理解Query、Key、Value三个向量的作用和区别 了解多头注意力的设计思想和实现方法 掌握注意力掩码的应用场景和实现方式 理解自注意力在序列处理中的优势 核心概念 自注意力机制概述 自注意力机制是Transformers架构的核心创新,它允许序列中的每个元素直接与其他所有元素进行交互,从而捕捉序列内部的复杂关系。与传统的循环神经网络不同,自注意力能够并行处理整个序列,有效解决了长距离依赖问题。
本节导读:深入理解自注意力机制的数学原理、实现细节和优化技巧,掌握这一Transformers核心组件的工作机制和应用场景。
自注意力机制是Transformers架构的核心创新,它允许序列中的每个元素直接与其他所有元素进行交互,从而捕捉序列内部的复杂关系。与传统的循环神经网络不同,自注意力能够并行处理整个序列,有效解决了长距离依赖问题。
自注意力机制基于点积相似度计算,主要包括三个步骤:
数学表达式:
Attention(Q, K, V) = softmax(QK^T / √dk)V
其中:
# 核心依赖 torch==2.1.0 transformers==4.35.2 numpy==1.24.3 matplotlib==3.7.2 seaborn==0.12.2 scikit-learn==1.3.0
import torch import torch.nn as nn import torch.nn.functional as F import numpy as np import matplotlib.pyplot as plt import seaborn as sns from typing import Optional, Tuple import math import time # 设置随机种子保证可重现性 torch.manual_seed(42) np.random.seed(42) print(f"PyTorch版本: {torch.__version__}") print(f"CUDA可用: {torch.cuda.is_available()}") if torch.cuda.is_available(): print(f"CUDA设备: {torch.cuda.get_device_name(0)}") device = torch.device('cuda') else: device = torch.device('cpu')
class SelfAttention(nn.Module): """基础自注意力实现""" def __init__(self, d_model: int, dropout: float = 0.1): super(SelfAttention, self).__init__() self.d_model = d_model # 线性变换层 self.query = nn.Linear(d_model, d_model) self.key = nn.Linear(d_model, d_model) self.value = nn.Linear(d_model, d_model) self.out = nn.Linear(d_model, d_model) # 缩放因子 self.scale = math.sqrt(d_model) # Dropout层 self.dropout = nn.Dropout(dropout) def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]: """ 前向传播 Args: x: 输入张量 [batch_size, seq_len, d_model] mask: 可选的掩码 [batch_size, seq_len, seq_len] Returns: output: 输出张量 [batch_size, seq_len, d_model] attention_weights: 注意力权重 [batch_size, seq_len, seq_len] """ batch_size, seq_len, d_model = x.shape # 计算Q、K、V Q = self.query(x) # [batch_size, seq_len, d_model] K = self.key(x) # [batch_size, seq_len, d_model] V = self.value(x) # [batch_size, seq_len, d_model] # 计算注意力分数 scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale # 应用掩码 if mask is not None: scores = scores.masked_fill(mask == 0, -1e9) # 计算注意力权重 attention_weights = F.softmax(scores, dim=-1) attention_weights = self.dropout(attention_weights) # 计算输出 output = torch.matmul(attention_weights, V) output = self.out(output) return output, attention_weights def create_causal_mask(seq_len: int) -> torch.Tensor: """创建因果掩码(用于解码器)""" mask = torch.triu(torch.ones(seq_len, seq_len) * float('-inf'), diagonal=1) return mask def create_padding_mask(sequences: torch.Tensor) -> torch.Tensor: """创建填充掩码""" return (sequences != 0).unsqueeze(1).unsqueeze(2) # 测试基础自注意力 def test_basic_attention(): d_model = 512 seq_len = 10 batch_size = 32 # 创建随机输入 x = torch.randn(batch_size, seq_len, d_model).to(device) # 创建自注意力层 attention = SelfAttention(d_model).to(device) # 前向传播 output, attention_weights = attention(x) print(f"输入形状: {x.shape}") print(f"输出形状: {output.shape}") print(f"注意力权重形状: {attention_weights.shape}") # 验证注意力权重是否正确归一化 assert torch.allclose(attention_weights.sum(dim=-1), torch.ones_like(attention_weights.sum(dim=-1))) return output, attention_weights # 运行测试 test_basic_attention()
class MultiHeadAttention(nn.Module): """多头注意力实现""" def __init__(self, d_model: int, num_heads: int = 8, dropout: float = 0.1): super(MultiHeadAttention, self).__init__() assert d_model % num_heads == 0, "d_model必须能被num_heads整除" self.d_model = d_model self.num_heads = num_heads self.head_dim = d_model // num_heads # 线性变换层 self.query = nn.Linear(d_model, d_model) self.key = nn.Linear(d_model, d_model) self.value = nn.Linear(d_model, d_model) self.out = nn.Linear(d_model, d_model) # 缩放因子 self.scale = math.sqrt(self.head_dim) # Dropout层 self.dropout = nn.Dropout(dropout) def split_heads(self, x: torch.Tensor) -> torch.Tensor: """分割输入到多个注意力头""" batch_size, seq_len, _ = x.shape x = x.view(batch_size, seq_len, self.num_heads, self.head_dim) return x.transpose(1, 2) # [batch_size, num_heads, seq_len, head_dim] def combine_heads(self, x: torch.Tensor) -> torch.Tensor: """合并多个注意力头的输出""" batch_size, _, seq_len, _ = x.shape x = x.transpose(1, 2) # [batch_size, seq_len, num_heads, head_dim] return x.contiguous().view(batch_size, seq_len, self.d_model) def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]: batch_size = x.shape[0] # 计算Q、K、V Q = self.query(x) # [batch_size, seq_len, d_model] K = self.key(x) # [batch_size, seq_len, d_model] V = self.value(x) # [batch_size, seq_len, d_model] # 分割到多个注意力头 Q = self.split_heads(Q) # [batch_size, num_heads, seq_len, head_dim] K = self.split_heads(K) # [batch_size, num_heads, seq_len, head_dim] V = self.split_heads(V) # [batch_size, num_heads, seq_len, head_dim] # 计算注意力分数 scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale # 应用掩码 if mask is not None: # 对于多头注意力,掩码需要扩展 if mask.dim() == 3: # [batch_size, seq_len, seq_len] mask = mask.unsqueeze(1) # [batch_size, 1, seq_len, seq_len] scores = scores.masked_fill(mask == 0, -1e9) # 计算注意力权重 attention_weights = F.softmax(scores, dim=-1) attention_weights = self.dropout(attention_weights) # 计算上下文向量 context = torch.matmul(attention_weights, V) # 合并注意力头 context = self.combine_heads(context) # [batch_size, seq_len, d_model] # 输出线性变换 output = self.out(context) return output, attention_weights def test_multihead_attention(): d_model = 512 num_heads = 8 seq_len = 16 batch_size = 8 # 创建随机输入 x = torch.randn(batch_size, seq_len, d_model).to(device) # 创建多头注意力层 mha = MultiHeadAttention(d_model, num_heads).to(device) # 前向传播 output, attention_weights = mha(x) print(f"多头注意力测试:") print(f"输入形状: {x.shape}") print(f"输出形状: {output.shape}") print(f"注意力权重形状: {attention_weights.shape}") print(f"注意力头数: {num_heads}") print(f"每个注意力头维度: {d_model // num_heads}") # 验证输出维度 assert output.shape == (batch_size, seq_len, d_model) assert attention_weights.shape == (batch_size, num_heads, seq_len, seq_len) return output, attention_weights # 运行测试 test_multihead_attention()
def visualize_attention(attention_weights: torch.Tensor, tokens: list = None, title: str = "注意力权重可视化"): """可视化注意力权重""" # 如果没有提供tokens,生成默认的token序列 if tokens is None: seq_len = attention_weights.shape[2] tokens = [f"T{i}" for i in range(seq_len)] # 取第一个样本和第一个注意力头 if attention_weights.dim() == 4: # 多头注意力 attention_weights = attention_weights[0, 0] # [seq_len, seq_len] else: # 单头注意力 attention_weights = attention_weights[0] # [seq_len, seq_len] plt.figure(figsize=(10, 8)) # 使用热力图可视化 sns.heatmap(attention_weights.cpu().numpy(), xticklabels=tokens, yticklabels=tokens, cmap='YlOrRd', annot=True, fmt='.2f', cbar=True) plt.title(title) plt.xlabel('Key Tokens') plt.ylabel('Query Tokens') plt.xticks(rotation=45) plt.yticks(rotation=0) plt.tight_layout() plt.show() # 创建示例数据用于可视化 def create_sample_data(): d_model = 64 seq_len = 6 batch_size = 1 num_heads = 4 # 创建模拟输入 x = torch.randn(batch_size, seq_len, d_model).to(device) # 创建多头注意力 mha = MultiHeadAttention(d_model, num_heads).to(device) # 前向传播 output, attention_weights = mha(x) # 模拟token名称 tokens = ["CLS", "你好", "世界", "!", "吗", "SEP"] return attention_weights, tokens # 运行可视化 attention_weights, tokens = create_sample_data() visualize_attention(attention_weights, tokens, "多头注意力权重示例")
class MaskedMultiHeadAttention(nn.Module): """支持多种掩码的多头注意力""" def __init__(self, d_model: int, num_heads: int = 8, dropout: float = 0.1): super(MaskedMultiHeadAttention, self).__init__() self.mha = MultiHeadAttention(d_model, num_heads, dropout) def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]: """ 前向传播 Args: query: 查询张量 [batch_size, query_len, d_model] key: 键张量 [batch_size, key_len, d_model] value: 值张量 [batch_size, key_len, d_model] mask: 掩码 [batch_size, query_len, key_len] 或 [batch_size, seq_len, seq_len] """ # 标准化输入维度 if query is key is value: # 自注意力 return self.mha(query, mask) else: # 交叉注意力 batch_size, query_len, d_model = query.shape batch_size, key_len, d_model = key.shape # 扩展输入维度以匹配mha的期望格式 Q = query K = key V = value # 计算注意力分数 scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_model) # 应用掩码 if mask is not None: scores = scores.masked_fill(mask == 0, -1e9) # 计算注意力权重 attention_weights = F.softmax(scores, dim=-1) # 计算输出 output = torch.matmul(attention_weights, V) return output, attention_weights def create_combined_mask(decoder_input: torch.Tensor) -> torch.Tensor: """创建解码器的组合掩码(因果掩码 + 填充掩码)""" batch_size, seq_len = decoder_input.shape # 填充掩码 padding_mask = create_padding_mask(decoder_input) # 因果掩码 causal_mask = torch.triu(torch.ones(seq_len, seq_len) * float('-inf'), diagonal=1) causal_mask = causal_mask.unsqueeze(0).expand(batch_size, -1, -1) # 组合掩码 combined_mask = padding_mask & (causal_mask == 0) return combined_mask # 测试不同类型的掩码 def test_masking(): d_model = 512 num_heads = 8 batch_size = 2 # 创建输入序列 seq_len = 8 x = torch.randint(0, 100, (batch_size, seq_len, d_model)).to(device) # 创建多头注意力 mha = MaskedMultiHeadAttention(d_model, num_heads).to(device) # 测试无掩码 output, attention_weights = mha(x, x, x) print(f"无掩码输出形状: {output.shape}") # 测试填充掩码 padding_mask = create_padding_mask(x) output_masked, attention_weights_masked = mha(x, x, x, padding_mask) print(f"填充掩码输出形状: {output_masked.shape}") # 比较注意力权重 print(f"平均注意力权重差: {(attention_weights - attention_weights_masked).abs().mean().item():.6f}") return output, attention_weights # 运行掩码测试 test_masking()
A:多头注意力允许模型从不同的表示子空间中学习不同的模式。每个注意力头可以关注不同的语义特征和位置关系,类似于人类从不同角度理解文本。实验表明,多头注意力能够捕获更丰富的语义信息,显著提升模型性能。
A:除以√dk(缩放点积注意力)是为了防止点积值过大导致softmax函数进入梯度饱和区域。当dk很大时,Q和K的点积可能会变得非常大,使得softmax输出变得过于尖锐,影响梯度流动。缩放操作有助于稳定训练过程。
A:因果掩码确保解码器在生成第t个token时只能关注到位置1到t-1的token,而不能关注到位置t及以后的token。这防止了模型"偷看"未来信息,符合自回归生成的特性,使模型能够真正地学习序列的生成过程。
A:使用sin和cos函数的组合允许模型学习相对位置信息。通过不同频率的三角函数组合,模型可以学习到位置之间的相对距离。更重要的是,这种实现允许通过相对位置偏移来编码位置,而不需要为每个可能的位置编码存储单独的向量。
A:标准的Transformer实现中,残差连接在注意力计算和前馈网络计算之后,层归一化在残差连接之后。这种顺序(残差连接 → 层归一化)被称为"post-LN"实现。也有研究提出"pre-LN"实现(层归一化 → 残差连接),在某些情况下训练更稳定。
本节深入探讨了自注意力机制的数学原理、实现细节和应用场景。通过本节的学习,读者应该掌握了:
下一节将介绍位置编码与层归一化的详细实现,完善对Transformers基础组件的理解。
关键词:自注意力, Query-Key-Value, 多头注意力, 注意力掩码, 因果掩码, 点积注意力, 注意力权重
难度:进阶
预计阅读:90分钟