Python中的生成对抗网络(GAN)训练稳定性:WGAN、LSGAN的损失函数与梯度惩罚
大家好,今天我们来深入探讨一下生成对抗网络(GAN)的训练稳定性问题,以及WGAN和LSGAN如何通过修改损失函数和引入梯度惩罚来解决这些问题。GANs虽然在生成逼真数据方面表现出色,但其训练过程以不稳定著称,容易出现模式崩溃、梯度消失等问题。
GAN训练不稳定的原因
GAN的训练本质上是一个minimax博弈,生成器(Generator,G)试图生成逼真的假数据,而判别器(Discriminator,D)试图区分真实数据和假数据。这种对抗性的训练方式容易导致以下问题:
- 模式崩溃(Mode Collapse): 生成器只学会生成特定几种类型的样本,而忽略了真实数据分布的其他部分。
- 梯度消失(Vanishing Gradients): 判别器过于强大,能够轻松区分真实数据和假数据,导致生成器得到的梯度信息非常小,无法有效更新。
- 梯度爆炸(Exploding Gradients): 训练过程中梯度变得非常大,导致训练不稳定。
- 非凸优化问题: GAN的优化目标是非凸的,这意味着存在许多局部最小值,训练容易陷入局部最优解。
这些问题源于原始GAN使用的损失函数(基于JS散度)以及训练过程中的一些内在特性。
WGAN:Wasserstein距离与Earth Mover’s Distance
WGAN (Wasserstein GAN) 提出了使用Wasserstein距离(也称为Earth Mover’s Distance, EMD)来替代原始GAN中的JS散度。Wasserstein距离的定义如下:
W(P_r, P_g) = inf_{gamma sim Pi(P_r, P_g)} E_{(x, y) sim gamma} [||x – y||]
其中P_r是真实数据分布,P_g是生成器生成的数据分布,Pi(P_r, P_g)是所有以P_r和P_g为边缘分布的联合分布gamma的集合。直观地说,Wasserstein距离衡量了将P_r“移动”到P_g所需的最小“工作量”。
为什么Wasserstein距离更好?
- 连续性: 即使P_r和P_g完全不重叠,Wasserstein距离仍然是连续的,这意味着判别器可以提供有意义的梯度信息给生成器。而JS散度在不重叠的情况下是恒定的,无法提供梯度信息。
- 梯度信息: Wasserstein距离提供了更可靠的梯度信息,有助于生成器更好地学习。
WGAN的损失函数
由于直接计算Wasserstein距离很困难,WGAN通过Kantorovitch-Rubinstein对偶性将其转化为一个可以计算的形式:
W(P_r, P_g) = sup_{||f||_L <= 1} E_{x sim P_r}[f(x)] – E_{x sim P_g}[f(x)]
其中f是一个Lipschitz连续函数,||f||_L <= 1 表示f的Lipschitz常数为1。在WGAN中,我们用一个判别器(critic)来近似这个Lipschitz连续函数,并通过训练来最大化上述表达式。
因此,WGAN的损失函数如下:
- 判别器(critic)损失: L_D = E_{x sim P_g}[D(x)] – E_{x sim P_r}[D(x)]
- 生成器损失: L_G = -E_{x sim P_g}[D(x)]
权重裁剪(Weight Clipping)
为了保证判别器D是Lipschitz连续的,WGAN最初使用权重裁剪(Weight Clipping)的方法,将判别器的权重限制在一个范围内,例如[-c, c]。
WGAN的代码示例(PyTorch)
import torch
import torch.nn as nn
import torch.optim as optim
# 定义生成器
class Generator(nn.Module):
def __init__(self, latent_dim, img_size):
super(Generator, self).__init__()
self.model = nn.Sequential(
nn.Linear(latent_dim, 128),
nn.ReLU(),
nn.Linear(128, img_size),
nn.Tanh() # 输出范围 [-1, 1]
)
def forward(self, z):
img = self.model(z)
return img
# 定义判别器
class Discriminator(nn.Module):
def __init__(self, img_size):
super(Discriminator, self).__init__()
self.model = nn.Sequential(
nn.Linear(img_size, 128),
nn.ReLU(),
nn.Linear(128, 1),
)
def forward(self, img):
validity = self.model(img)
return validity
# 超参数
latent_dim = 100
img_size = 784 # 例如 MNIST 数据集
lr = 0.00005
b1 = 0.5
b2 = 0.999
n_epochs = 50
batch_size = 64
clip_value = 0.01 # 权重裁剪参数
n_critic = 5 # 每次更新生成器之前更新判别器的次数
# 初始化生成器和判别器
generator = Generator(latent_dim, img_size)
discriminator = Discriminator(img_size)
# 使用CUDA(如果可用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator.to(device)
discriminator.to(device)
# 优化器
optimizer_G = optim.RMSprop(generator.parameters(), lr=lr)
optimizer_D = optim.RMSprop(discriminator.parameters(), lr=lr)
# 训练循环
for epoch in range(n_epochs):
for i in range(len(train_loader) // batch_size): # Assuming train_loader is defined elsewhere
# ---------------------
# 训练判别器
# ---------------------
for _ in range(n_critic):
real_imgs = next(iter(train_loader))[0].view(batch_size, -1).to(device) # Assuming train_loader returns (images, labels)
optimizer_D.zero_grad()
# 判别器对真实图像的输出
real_validity = discriminator(real_imgs)
# 生成假图像
z = torch.randn(batch_size, latent_dim).to(device)
fake_imgs = generator(z)
# 判别器对假图像的输出
fake_validity = discriminator(fake_imgs)
# 计算判别器损失
d_loss = torch.mean(fake_validity) - torch.mean(real_validity)
d_loss.backward()
optimizer_D.step()
# 权重裁剪
for p in discriminator.parameters():
p.data.clamp_(-clip_value, clip_value)
# -----------------
# 训练生成器
# -----------------
optimizer_G.zero_grad()
# 生成假图像
z = torch.randn(batch_size, latent_dim).to(device)
fake_imgs = generator(z)
# 生成器希望判别器将假图像判定为真
fake_validity = discriminator(fake_imgs)
g_loss = -torch.mean(fake_validity)
g_loss.backward()
optimizer_G.step()
print(
"[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
% (epoch, n_epochs, i, len(train_loader) // batch_size, d_loss.item(), g_loss.item())
)
权重裁剪的缺点
权重裁剪虽然简单,但会带来一些问题:
- 梯度消失: 如果裁剪范围太小,判别器的权重会被限制在很小的范围内,导致梯度消失。
- 梯度爆炸: 如果裁剪范围太大,判别器可能无法满足Lipschitz约束,导致梯度爆炸。
- 网络容量降低: 强制权重位于特定范围会限制判别器的表达能力。
WGAN-GP:梯度惩罚(Gradient Penalty)
WGAN-GP (WGAN with Gradient Penalty) 提出了使用梯度惩罚来替代权重裁剪,以保证判别器满足Lipschitz约束。
Lipschitz约束与梯度
一个函数f是Lipschitz连续的,如果存在一个常数K,使得对于任意的x和y,满足:
|f(x) – f(y)| <= K * ||x – y||
对于可微函数,Lipschitz约束等价于:
||∇f(x)|| <= K
这意味着函数f的梯度的范数必须小于等于一个常数K。在WGAN-GP中,我们希望判别器的梯度范数接近于1,以保证其满足1-Lipschitz约束。
梯度惩罚的计算
WGAN-GP在判别器的损失函数中增加了一个梯度惩罚项,鼓励判别器的梯度范数接近于1。梯度惩罚的计算方式如下:
- 随机插值: 在真实数据和生成数据之间进行随机插值:x_hat = epsilon x + (1 – epsilon) G(z),其中epsilon是一个[0, 1]之间的随机数。
- 计算梯度: 计算判别器在插值点x_hat处的梯度:gradients = ∇D(x_hat)。
- 计算梯度范数: 计算梯度的L2范数:grad_norm = ||gradients||_2。
- 计算梯度惩罚: 计算梯度惩罚项:gradient_penalty = lambda * (grad_norm – 1)^2,其中lambda是一个超参数,控制梯度惩罚的强度。
WGAN-GP的损失函数
- 判别器损失: L_D = E_{x sim P_g}[D(x)] – E_{x sim P_r}[D(x)] + gradient_penalty
- 生成器损失: L_G = -E_{x sim P_g}[D(x)]
WGAN-GP的代码示例(PyTorch)
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import grad
# 定义生成器
class Generator(nn.Module):
def __init__(self, latent_dim, img_size):
super(Generator, self).__init__()
self.model = nn.Sequential(
nn.Linear(latent_dim, 128),
nn.ReLU(),
nn.Linear(128, img_size),
nn.Tanh() # 输出范围 [-1, 1]
)
def forward(self, z):
img = self.model(z)
return img
# 定义判别器
class Discriminator(nn.Module):
def __init__(self, img_size):
super(Discriminator, self).__init__()
self.model = nn.Sequential(
nn.Linear(img_size, 128),
nn.ReLU(),
nn.Linear(128, 1),
)
def forward(self, img):
validity = self.model(img)
return validity
# 计算梯度惩罚
def compute_gradient_penalty(D, real_samples, fake_samples, device):
"""Calculates the gradient penalty loss for WGAN-GP"""
# Random weight term for interpolation
alpha = torch.rand((real_samples.size(0), 1)).to(device)
# Get random interpolation between real and fake samples
interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)
d_interpolates = D(interpolates)
fake = torch.ones((real_samples.size(0), 1)).to(device)
# Get gradient w.r.t. interpolates
gradients = grad(
outputs=d_interpolates,
inputs=interpolates,
grad_outputs=fake,
create_graph=True,
retain_graph=True,
only_inputs=True,
)[0]
gradients = gradients.view(gradients.size(0), -1)
gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
return gradient_penalty
# 超参数
latent_dim = 100
img_size = 784 # 例如 MNIST 数据集
lr = 0.0001
b1 = 0.5
b2 = 0.999
n_epochs = 50
batch_size = 64
lambda_gp = 10 # 梯度惩罚系数
n_critic = 5 # 每次更新生成器之前更新判别器的次数
# 初始化生成器和判别器
generator = Generator(latent_dim, img_size)
discriminator = Discriminator(img_size)
# 使用CUDA(如果可用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator.to(device)
discriminator.to(device)
# 优化器
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))
# 训练循环
for epoch in range(n_epochs):
for i in range(len(train_loader) // batch_size): # Assuming train_loader is defined elsewhere
# ---------------------
# 训练判别器
# ---------------------
for _ in range(n_critic):
real_imgs = next(iter(train_loader))[0].view(batch_size, -1).to(device) # Assuming train_loader returns (images, labels)
optimizer_D.zero_grad()
# 判别器对真实图像的输出
real_validity = discriminator(real_imgs)
# 生成假图像
z = torch.randn(batch_size, latent_dim).to(device)
fake_imgs = generator(z)
# 判别器对假图像的输出
fake_validity = discriminator(fake_imgs)
# 计算梯度惩罚
gradient_penalty = compute_gradient_penalty(discriminator, real_imgs.data, fake_imgs.data, device)
# 计算判别器损失
d_loss = torch.mean(fake_validity) - torch.mean(real_validity) + lambda_gp * gradient_penalty
d_loss.backward()
optimizer_D.step()
# -----------------
# 训练生成器
# -----------------
optimizer_G.zero_grad()
# 生成假图像
z = torch.randn(batch_size, latent_dim).to(device)
fake_imgs = generator(z)
# 生成器希望判别器将假图像判定为真
fake_validity = discriminator(fake_imgs)
g_loss = -torch.mean(fake_validity)
g_loss.backward()
optimizer_G.step()
print(
"[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
% (epoch, n_epochs, i, len(train_loader) // batch_size, d_loss.item(), g_loss.item())
)
WGAN-GP的优点
- 更好的稳定性: 梯度惩罚比权重裁剪更稳定,能够更好地保证判别器满足Lipschitz约束。
- 更高的图像质量: WGAN-GP通常能够生成更高质量的图像。
LSGAN:最小二乘GAN(Least Squares GAN)
LSGAN (Least Squares GAN) 提出了使用最小二乘损失函数来替代原始GAN中的sigmoid交叉熵损失函数。
原始GAN的损失函数的问题
原始GAN使用sigmoid交叉熵损失函数,这会导致以下问题:
- 梯度消失: 当生成器生成的样本非常差时,判别器很容易将其判定为假,导致生成器得到的梯度信息非常小,无法有效更新。
- 对离群点的惩罚: 原始GAN对离群点的惩罚过于严厉,导致生成器难以学习真实数据分布的尾部。
LSGAN的损失函数
LSGAN使用最小二乘损失函数,这可以缓解上述问题。LSGAN的损失函数如下:
- 判别器损失: L_D = 0.5 E_{x sim P_r}[(D(x) – b)^2] + 0.5 E_{x sim P_g}[(D(x) – a)^2]
- 生成器损失: L_G = 0.5 * E_{x sim P_g}[(D(x) – c)^2]
其中a, b, c是超参数,通常设置为a = 0, b = 1, c = 1。这意味着判别器希望将真实数据判定为1,将假数据判定为0,而生成器希望将假数据判定为1。
LSGAN的代码示例(PyTorch)
import torch
import torch.nn as nn
import torch.optim as optim
# 定义生成器
class Generator(nn.Module):
def __init__(self, latent_dim, img_size):
super(Generator, self).__init__()
self.model = nn.Sequential(
nn.Linear(latent_dim, 128),
nn.ReLU(),
nn.Linear(128, img_size),
nn.Tanh() # 输出范围 [-1, 1]
)
def forward(self, z):
img = self.model(z)
return img
# 定义判别器
class Discriminator(nn.Module):
def __init__(self, img_size):
super(Discriminator, self).__init__()
self.model = nn.Sequential(
nn.Linear(img_size, 128),
nn.ReLU(),
nn.Linear(128, 1),
nn.Sigmoid() # 输出范围 [0, 1]
)
def forward(self, img):
validity = self.model(img)
return validity
# 超参数
latent_dim = 100
img_size = 784 # 例如 MNIST 数据集
lr = 0.0002
b1 = 0.5
b2 = 0.999
n_epochs = 50
batch_size = 64
# 初始化生成器和判别器
generator = Generator(latent_dim, img_size)
discriminator = Discriminator(img_size)
# 使用CUDA(如果可用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator.to(device)
discriminator.to(device)
# 优化器
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))
# 损失函数
criterion = nn.MSELoss()
# 目标标签
valid = torch.ones(batch_size, 1).to(device)
fake = torch.zeros(batch_size, 1).to(device)
# 训练循环
for epoch in range(n_epochs):
for i, (imgs, _) in enumerate(train_loader): # Assuming train_loader is defined elsewhere
# ---------------------
# 训练判别器
# ---------------------
optimizer_D.zero_grad()
# 判别器对真实图像的输出
real_imgs = imgs.view(batch_size, -1).to(device)
real_pred = discriminator(real_imgs)
d_real_loss = criterion(real_pred, valid)
# 生成假图像
z = torch.randn(batch_size, latent_dim).to(device)
fake_imgs = generator(z)
# 判别器对假图像的输出
fake_pred = discriminator(fake_imgs)
d_fake_loss = criterion(fake_pred, fake)
# 计算判别器损失
d_loss = (d_real_loss + d_fake_loss) / 2
d_loss.backward()
optimizer_D.step()
# -----------------
# 训练生成器
# -----------------
optimizer_G.zero_grad()
# 生成假图像
z = torch.randn(batch_size, latent_dim).to(device)
fake_imgs = generator(z)
# 生成器希望判别器将假图像判定为真
fake_pred = discriminator(fake_imgs)
g_loss = criterion(fake_pred, valid)
g_loss.backward()
optimizer_G.step()
print(
"[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
% (epoch, n_epochs, i, len(train_loader), d_loss.item(), g_loss.item())
)
LSGAN的优点
- 更稳定的训练: 最小二乘损失函数可以缓解梯度消失问题,提高训练稳定性。
- 更好的图像质量: LSGAN通常能够生成更高质量的图像,并且能够生成更逼真的数据分布。
总结:不同GAN变体解决训练稳定性的方法
| GAN变体 | 损失函数修改 | 梯度惩罚/其他约束 | 优点 | 缺点 |
|---|---|---|---|---|
| WGAN | Wasserstein距离(通过Kantorovitch-Rubinstein对偶性转换) | 权重裁剪 | 解决了JS散度不连续的问题,提供了更可靠的梯度信息,训练更稳定。 | 权重裁剪可能导致梯度消失或爆炸,限制了判别器的表达能力。 |
| WGAN-GP | Wasserstein距离 | 梯度惩罚,鼓励判别器的梯度范数接近于1 | 比权重裁剪更稳定,能够更好地保证判别器满足Lipschitz约束,通常生成更高质量的图像。 | 需要调整梯度惩罚系数,计算量相对较大。 |
| LSGAN | 最小二乘损失函数 | 无 | 缓解梯度消失问题,提高训练稳定性,通常能够生成更高质量的图像。 | 可能对离群点更敏感。 |
选择合适的GAN变体
选择合适的GAN变体取决于具体的应用场景和数据集。
- WGAN: 如果计算资源有限,可以尝试WGAN,但需要仔细调整权重裁剪参数。
- WGAN-GP: 如果计算资源充足,建议使用WGAN-GP,因为它通常能够生成更高质量的图像,并且训练更稳定。
- LSGAN: 如果数据集存在离群点,或者需要生成更逼真的数据分布,可以尝试LSGAN。
总的来说,WGAN-GP是目前应用最广泛的GAN变体之一,因为它在稳定性和图像质量方面都表现出色。但是,不同的GAN变体都有其优缺点,需要根据具体情况进行选择。
其他提高GAN训练稳定性的技巧
除了修改损失函数和引入梯度惩罚之外,还有一些其他的技巧可以提高GAN的训练稳定性:
- 使用更好的优化器: Adam是常用的优化器,但有时可以使用其他的优化器,例如RMSprop或SGD,来提高训练稳定性。
- 调整学习率: 合适的学习率对于GAN的训练至关重要。可以使用学习率衰减策略,或者使用自适应学习率优化器,例如AdamW。
- 使用批量归一化(Batch Normalization): 批量归一化可以加速训练,提高稳定性。
- 使用谱归一化(Spectral Normalization): 谱归一化可以限制判别器的Lipschitz常数,提高训练稳定性。
- 使用标签平滑(Label Smoothing): 标签平滑可以缓解判别器过于自信的问题,提高训练稳定性。
- 使用更大的批量大小(Batch Size): 更大的批量大小可以提供更稳定的梯度估计,提高训练稳定性。
- 使用更深的网络结构: 更深的网络结构可以提高生成器和判别器的表达能力,但同时也需要更多的计算资源。
- 使用更复杂的网络结构: 可以使用更复杂的网络结构,例如残差网络(ResNet)或密集连接网络(DenseNet),来提高生成器和判别器的表达能力。
总结一下关键点:GAN训练稳定性的关键因素
GAN的训练稳定性是一个复杂的问题,涉及到损失函数、优化算法、网络结构等多个方面。WGAN、WGAN-GP和LSGAN通过修改损失函数和引入梯度惩罚等方法,有效地提高了GAN的训练稳定性,使得GAN能够生成更高质量的图像。选择合适的GAN变体和训练技巧,对于成功训练GAN至关重要。
更多IT精英技术系列讲座,到智猿学院