4.7 对抗生成网络 (Generative Adversarial Networks - GANs) 第四章:PyTorch 高级主题:4.7 对抗生成网络 (GANs) 详解与实践 生成对抗网络 (Generative Adversarial Networks, GANs) 是近年来深度学习领域最激动人心的创新之一。由 Ian Goodfellow 等人在 2014 年提出,GANs 提供了一种全新的框架,用于训练生成模型,尤其擅长生成逼真的图像、音频以及其他复杂数据。与传统的生成模型不同,GANs 采用对抗训练的方式,通过两个神经网络——生成器 (Generator) 和判别器 (Discriminator) 的博弈,不断提升生成模型的性能。 4.7.
生成对抗网络 (Generative Adversarial Networks, GANs) 是近年来深度学习领域最激动人心的创新之一。由 Ian Goodfellow 等人在 2014 年提出,GANs 提供了一种全新的框架,用于训练生成模型,尤其擅长生成逼真的图像、音频以及其他复杂数据。与传统的生成模型不同,GANs 采用对抗训练的方式,通过两个神经网络——生成器 (Generator) 和判别器 (Discriminator) 的博弈,不断提升生成模型的性能。
GANs 的核心思想来源于博弈论中的零和博弈。我们可以将生成器和判别器视为两个玩家,它们在一个共同的环境中进行对抗。
生成器 (Generator, G) 的目标是从一个随机噪声 (通常是高斯分布) 中学习生成与真实数据相似的数据样本。例如,如果我们的目标是生成人脸图像,生成器的任务就是根据输入的随机噪声,生成一张尽可能逼真的人脸图像。
判别器 (Discriminator, D) 的目标是区分输入数据是真实的(来自真实数据集)还是由生成器生成的假数据。判别器本质上是一个二分类器,它接收一个数据样本作为输入,并输出一个概率值,表示该样本是真实数据的概率。
GANs 的训练过程可以被描述为一个极小极大博弈 (Minimax Game):
判别器 D 的目标:最大化区分真实数据和生成数据能力,即尽可能准确地将真实数据判断为真,将生成数据判断为假。这可以用公式表示为最大化以下目标函数:
max_D E_{x~p_{data}(x)} [log D(x)] + E_{z~p_{z}(z)} [log(1 - D(G(z)))]
其中,x 代表真实数据,z 代表随机噪声,G(z) 代表生成器生成的假数据,D(x) 代表判别器判断 x 为真实数据的概率,p_{data}(x) 是真实数据分布,p_{z}(z) 是噪声分布。公式的第一项鼓励判别器将真实数据 x 判断为真 (即 D(x) 接近 1),第二项鼓励判别器将生成数据 G(z) 判断为假 (即 D(G(z)) 接近 0,从而 1 - D(G(z)) 接近 1)。
生成器 G 的目标:最小化判别器区分真实数据和生成数据的能力,即尽可能生成逼真的数据,使得判别器无法区分真假。这可以用公式表示为最小化以下目标函数:
min_G E_{z~p_{z}(z)} [log(1 - D(G(z)))]
生成器的目标是欺骗判别器,让判别器将生成数据 G(z) 误判为真 (即 D(G(z)) 接近 1,从而 1 - D(G(z)) 接近 0)。
通过不断地迭代训练,生成器和判别器在对抗中不断进步。理想情况下,最终生成器能够生成与真实数据分布完全一致的数据,判别器无法区分真假,此时 D(G(z)) 接近 0.5,表示判别器随机猜测真假,GANs 达到纳什均衡。
Graph TD 图示 GANs 的基本结构:
GANs 的训练是一个迭代过程,通常包含以下步骤:
初始化: 初始化生成器 G 和判别器 D 的网络参数。
迭代训练: 在每个训练迭代中,重复以下步骤:
训练判别器 D:
从真实数据集中采样一批真实数据 x。
从噪声分布中采样一批噪声 z。
使用生成器 G 生成一批假数据 G(z)。
将真实数据 x 和假数据 G(z) 输入判别器 D,分别得到 D(x) 和 D(G(z))。
计算判别器损失 Loss_D,根据上述判别器目标函数。
使用梯度下降法更新判别器 D 的参数,最大化 Loss_D。
训练生成器 G:
从噪声分布中采样一批噪声 z。
使用生成器 G 生成一批假数据 G(z)。
将假数据 G(z) 输入判别器 D,得到 D(G(z))。
计算生成器损失 Loss_G,根据上述生成器目标函数。
使用梯度下降法更新生成器 G 的参数,最小化 Loss_G。
重复步骤 2,直到达到训练停止条件(例如,达到预定的迭代次数,或损失函数收敛)。
关键点:
交替训练: 通常采用交替训练的方式,即在一个迭代中先训练判别器,再训练生成器。这样做有助于稳定训练过程,避免一方过强导致另一方无法学习。
损失函数: 选择合适的损失函数至关重要。原始 GANs 使用的是基于交叉熵的损失函数,但后续的研究也提出了许多其他的损失函数,例如 Wasserstein GAN (WGAN) 使用的 Earth Mover's Distance。
优化器: 常用的优化器包括 Adam 和 SGD。对于 GANs,通常使用 Adam 优化器,因为它对超参数不敏感,且收敛速度较快。
网络结构: 生成器和判别器的网络结构需要根据具体任务进行设计。对于图像生成任务,通常使用卷积神经网络 (CNN) 作为生成器和判别器的基本结构。生成器通常采用反卷积 (Transposed Convolution) 或上采样 (Upsampling) 操作,将低维噪声映射到高维图像空间。判别器则通常采用卷积和池化操作,提取图像特征并进行分类。
接下来,我们将使用 PyTorch 实现一个简单的 GAN,用于生成 MNIST 手写数字图像。
1. 环境准备和数据加载
首先,导入必要的 PyTorch 库,并加载 MNIST 数据集。
import torch import torch.nn as nn import torch.optim as optim from torchvision import datasets, transforms from torchvision.utils import save_image # 超参数设置 batch_size = 64 latent_dim = 100 # 噪声向量的维度 learning_rate = 0.0002 num_epochs = 50 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # MNIST 数据集加载和预处理 transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,)) # 归一化到 [-1, 1] ]) train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform) train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
2. 生成器 (Generator) 网络定义
我们定义一个简单的生成器网络,使用线性层和 ReLU 激活函数。生成器的输入是一个维度为 latent_dim 的噪声向量,输出是一个维度为 784 的向量,对应于 28x28 的 MNIST 图像。
class Generator(nn.Module): def __init__(self, latent_dim): super(Generator, self).__init__() self.model = nn.Sequential( nn.Linear(latent_dim, 256), nn.ReLU(), nn.Linear(256, 512), nn.ReLU(), nn.Linear(512, 784), nn.Tanh() # 输出范围 [-1, 1],与数据归一化范围一致 ) def forward(self, z): img = self.model(z) img = img.view(img.size(0), 1, 28, 28) # reshape 成图像格式 return img
3. 判别器 (Discriminator) 网络定义
我们定义一个简单的判别器网络,同样使用线性层和 LeakyReLU 激活函数。判别器的输入是一个维度为 784 的向量(MNIST 图像),输出是一个标量,表示图像为真实数据的概率。
class Discriminator(nn.Module): def __init__(self): super(Discriminator, self).__init__() self.model = nn.Sequential( nn.Linear(784, 512), nn.LeakyReLU(0.2), # LeakyReLU 避免梯度消失 nn.Linear(512, 256), nn.LeakyReLU(0.2), nn.Linear(256, 1), nn.Sigmoid() # 输出概率值 [0, 1] ) def forward(self, img): img_flat = img.view(img.size(0), -1) # flatten 成向量 validity = self.model(img_flat) return validity
4. 损失函数和优化器定义
我们使用二元交叉熵损失函数 (Binary Cross Entropy Loss) 和 Adam 优化器。
# 初始化生成器和判别器 generator = Generator(latent_dim).to(device) discriminator = Discriminator().to(device) # 损失函数 loss_func = nn.BCELoss() # 优化器 optimizer_G = optim.Adam(generator.parameters(), lr=learning_rate) optimizer_D = optim.Adam(discriminator.parameters(), lr=learning_rate)
5. 训练循环
定义训练循环,按照 GANs 的训练步骤进行迭代。
def train_gan(generator, discriminator, train_loader, optimizer_G, optimizer_D, loss_func, num_epochs, latent_dim, device): for epoch in range(num_epochs): for i, (real_imgs, _) in enumerate(train_loader): real_imgs = real_imgs.to(device) batch_size = real_imgs.size(0) # 真实标签和虚假标签 real_labels = torch.ones(batch_size, 1).to(device) fake_labels = torch.zeros(batch_size, 1).to(device) # ----- 训练判别器 ----- optimizer_D.zero_grad() # 真实数据损失 output_real = discriminator(real_imgs) loss_D_real = loss_func(output_real, real_labels) # 生成数据损失 z = torch.randn(batch_size, latent_dim).to(device) # 随机噪声 fake_imgs = generator(z) output_fake = discriminator(fake_imgs.detach()) # detach() 阻止梯度回传到生成器 loss_D_fake = loss_func(output_fake, fake_labels) # 判别器总损失 loss_D = loss_D_real + loss_D_fake loss_D.backward() optimizer_D.step() # ----- 训练生成器 ----- optimizer_G.zero_grad() # 生成器损失 (目标是欺骗判别器) z = torch.randn(batch_size, latent_dim).to(device) fake_imgs = generator(z) output_fake_G = discriminator(fake_imgs) # 再次计算判别器对生成数据的输出,但不 detach() loss_G = loss_func(output_fake_G, real_labels) # 生成器希望判别器将假数据判断为真 loss_G.backward() optimizer_G.step() # 打印训练信息 if (i+1) % 200 == 0: print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], D_loss: {loss_D.item():.4f}, G_loss: {loss_G.item():.4f}') # 每个 epoch 结束时保存生成的图像 if (epoch+1) % 5 == 0: z_sample = torch.randn(64, latent_dim).to(device) fake_imgs_sample = generator(z_sample) save_image(fake_imgs_sample, f'./generated_https://www.aiknowledge.cn/images/PyTorch/epoch_{epoch+1}.png', nrow=8, normalize=True) print(f"Saved generated images at epoch {epoch+1}") # 创建保存生成图像的文件夹 import os os.makedirs('./generated_images', exist_ok=True) # 开始训练 train_gan(generator, discriminator, train_loader, optimizer_G, optimizer_D, loss_func, num_epochs, latent_dim, device) print("Training finished!")
6. 结果可视化
训练完成后,可以在 ./generated_images 文件夹中查看生成的 MNIST 图像。随着训练的进行,生成的图像质量会逐渐提高,从模糊的噪声图像逐渐变为可辨认的数字图像。
代码详解:
数据预处理: MNIST 数据集被归一化到 [-1, 1] 范围,生成器的输出也使用 Tanh 激活函数保证输出范围一致。
网络结构: 生成器和判别器都使用了简单的多层感知机 (MLP) 结构。对于更复杂的图像生成任务,可以考虑使用卷积神经网络 (CNN) 结构,例如 DCGAN (Deep Convolutional GANs)。
损失函数: 使用二元交叉熵损失函数,目标标签 real_labels 为全 1,fake_labels 为全 0。
detach() 操作: 在训练判别器时,使用 fake_imgs.detach() 阻止梯度回传到生成器。这是因为在训练判别器时,我们只希望更新判别器的参数,而不希望影响生成器的参数。
保存生成图像: 每隔几个 epoch 保存生成的图像,可以直观地观察生成器性能的提升过程。
尽管 GANs 在生成模型领域取得了巨大的成功,但其训练过程仍然面临一些挑战:
训练不稳定: GANs 的训练过程可能非常不稳定,容易出现模式崩溃 (Mode Collapse) 和梯度消失等问题。模式崩溃指的是生成器只能生成少数几种样本,而无法覆盖真实数据分布的所有模式。
难以收敛: GANs 的损失函数并非真正意义上的损失函数,它更多地是反映了生成器和判别器之间的博弈状态,难以判断训练是否收敛。
评估困难: 评价生成模型的性能通常比较困难,尤其对于 GANs,生成的样本质量主观性较强,缺乏客观的评价指标。
为了解决这些挑战,研究者们提出了许多 GANs 的改进版本和训练技巧,例如:
Wasserstein GAN (WGAN): 使用 Earth Mover's Distance 作为损失函数,缓解了原始 GANs 的梯度消失问题,提高了训练稳定性。
Deep Convolutional GANs (DCGAN): 将卷积神经网络引入 GANs,并提出了许多训练技巧,例如使用 Batch Normalization、ReLU 和 LeakyReLU 激活函数、去除全连接层等,提高了图像生成质量和训练稳定性。
Conditional GANs (CGAN): 在 GANs 的基础上引入条件信息,例如类别标签,使得生成器可以生成特定类别的样本,增强了生成模型的可控性。
Spectral Normalization GAN (SN-GAN): 使用谱归一化技术约束判别器的 Lipschitz 常数,进一步提高了训练稳定性和生成样本质量。