Python中的生成对抗网络(GAN)训练稳定性:WGAN、LSGAN的损失函数与梯度惩罚

Python中的生成对抗网络(GAN)训练稳定性:WGAN、LSGAN的损失函数与梯度惩罚

大家好,今天我们来深入探讨一下生成对抗网络(GAN)的训练稳定性问题,以及WGAN和LSGAN如何通过修改损失函数和引入梯度惩罚来解决这些问题。GANs虽然在生成逼真数据方面表现出色,但其训练过程以不稳定著称,容易出现模式崩溃、梯度消失等问题。

GAN训练不稳定的原因

GAN的训练本质上是一个minimax博弈,生成器(Generator,G)试图生成逼真的假数据,而判别器(Discriminator,D)试图区分真实数据和假数据。这种对抗性的训练方式容易导致以下问题:

  • 模式崩溃(Mode Collapse): 生成器只学会生成特定几种类型的样本,而忽略了真实数据分布的其他部分。
  • 梯度消失(Vanishing Gradients): 判别器过于强大,能够轻松区分真实数据和假数据,导致生成器得到的梯度信息非常小,无法有效更新。
  • 梯度爆炸(Exploding Gradients): 训练过程中梯度变得非常大,导致训练不稳定。
  • 非凸优化问题: GAN的优化目标是非凸的,这意味着存在许多局部最小值,训练容易陷入局部最优解。

这些问题源于原始GAN使用的损失函数(基于JS散度)以及训练过程中的一些内在特性。

WGAN:Wasserstein距离与Earth Mover’s Distance

WGAN (Wasserstein GAN) 提出了使用Wasserstein距离(也称为Earth Mover’s Distance, EMD)来替代原始GAN中的JS散度。Wasserstein距离的定义如下:

W(P_r, P_g) = inf_{gamma sim Pi(P_r, P_g)} E_{(x, y) sim gamma} [||x – y||]

其中P_r是真实数据分布,P_g是生成器生成的数据分布,Pi(P_r, P_g)是所有以P_r和P_g为边缘分布的联合分布gamma的集合。直观地说,Wasserstein距离衡量了将P_r“移动”到P_g所需的最小“工作量”。

为什么Wasserstein距离更好?

  • 连续性: 即使P_r和P_g完全不重叠,Wasserstein距离仍然是连续的,这意味着判别器可以提供有意义的梯度信息给生成器。而JS散度在不重叠的情况下是恒定的,无法提供梯度信息。
  • 梯度信息: Wasserstein距离提供了更可靠的梯度信息,有助于生成器更好地学习。

WGAN的损失函数

由于直接计算Wasserstein距离很困难,WGAN通过Kantorovitch-Rubinstein对偶性将其转化为一个可以计算的形式:

W(P_r, P_g) = sup_{||f||_L <= 1} E_{x sim P_r}[f(x)] – E_{x sim P_g}[f(x)]

其中f是一个Lipschitz连续函数,||f||_L <= 1 表示f的Lipschitz常数为1。在WGAN中,我们用一个判别器(critic)来近似这个Lipschitz连续函数,并通过训练来最大化上述表达式。

因此,WGAN的损失函数如下:

  • 判别器(critic)损失: L_D = E_{x sim P_g}[D(x)] – E_{x sim P_r}[D(x)]
  • 生成器损失: L_G = -E_{x sim P_g}[D(x)]

权重裁剪(Weight Clipping)

为了保证判别器D是Lipschitz连续的,WGAN最初使用权重裁剪(Weight Clipping)的方法,将判别器的权重限制在一个范围内,例如[-c, c]。

WGAN的代码示例(PyTorch)

import torch
import torch.nn as nn
import torch.optim as optim

# 定义生成器
class Generator(nn.Module):
    def __init__(self, latent_dim, img_size):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.ReLU(),
            nn.Linear(128, img_size),
            nn.Tanh()  # 输出范围 [-1, 1]
        )

    def forward(self, z):
        img = self.model(z)
        return img

# 定义判别器
class Discriminator(nn.Module):
    def __init__(self, img_size):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(img_size, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
        )

    def forward(self, img):
        validity = self.model(img)
        return validity

# 超参数
latent_dim = 100
img_size = 784  # 例如 MNIST 数据集
lr = 0.00005
b1 = 0.5
b2 = 0.999
n_epochs = 50
batch_size = 64
clip_value = 0.01 # 权重裁剪参数
n_critic = 5 # 每次更新生成器之前更新判别器的次数

# 初始化生成器和判别器
generator = Generator(latent_dim, img_size)
discriminator = Discriminator(img_size)

# 使用CUDA(如果可用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator.to(device)
discriminator.to(device)

# 优化器
optimizer_G = optim.RMSprop(generator.parameters(), lr=lr)
optimizer_D = optim.RMSprop(discriminator.parameters(), lr=lr)

# 训练循环
for epoch in range(n_epochs):
    for i in range(len(train_loader) // batch_size): # Assuming train_loader is defined elsewhere
        # ---------------------
        # 训练判别器
        # ---------------------

        for _ in range(n_critic):
            real_imgs = next(iter(train_loader))[0].view(batch_size, -1).to(device) # Assuming train_loader returns (images, labels)
            optimizer_D.zero_grad()

            # 判别器对真实图像的输出
            real_validity = discriminator(real_imgs)

            # 生成假图像
            z = torch.randn(batch_size, latent_dim).to(device)
            fake_imgs = generator(z)

            # 判别器对假图像的输出
            fake_validity = discriminator(fake_imgs)

            # 计算判别器损失
            d_loss = torch.mean(fake_validity) - torch.mean(real_validity)
            d_loss.backward()
            optimizer_D.step()

            # 权重裁剪
            for p in discriminator.parameters():
                p.data.clamp_(-clip_value, clip_value)

        # -----------------
        # 训练生成器
        # -----------------

        optimizer_G.zero_grad()

        # 生成假图像
        z = torch.randn(batch_size, latent_dim).to(device)
        fake_imgs = generator(z)

        # 生成器希望判别器将假图像判定为真
        fake_validity = discriminator(fake_imgs)
        g_loss = -torch.mean(fake_validity)

        g_loss.backward()
        optimizer_G.step()

        print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
            % (epoch, n_epochs, i, len(train_loader) // batch_size, d_loss.item(), g_loss.item())
        )

权重裁剪的缺点

权重裁剪虽然简单,但会带来一些问题:

  • 梯度消失: 如果裁剪范围太小,判别器的权重会被限制在很小的范围内,导致梯度消失。
  • 梯度爆炸: 如果裁剪范围太大,判别器可能无法满足Lipschitz约束,导致梯度爆炸。
  • 网络容量降低: 强制权重位于特定范围会限制判别器的表达能力。

WGAN-GP:梯度惩罚(Gradient Penalty)

WGAN-GP (WGAN with Gradient Penalty) 提出了使用梯度惩罚来替代权重裁剪,以保证判别器满足Lipschitz约束。

Lipschitz约束与梯度

一个函数f是Lipschitz连续的,如果存在一个常数K,使得对于任意的x和y,满足:

|f(x) – f(y)| <= K * ||x – y||

对于可微函数,Lipschitz约束等价于:

||∇f(x)|| <= K

这意味着函数f的梯度的范数必须小于等于一个常数K。在WGAN-GP中,我们希望判别器的梯度范数接近于1,以保证其满足1-Lipschitz约束。

梯度惩罚的计算

WGAN-GP在判别器的损失函数中增加了一个梯度惩罚项,鼓励判别器的梯度范数接近于1。梯度惩罚的计算方式如下:

  1. 随机插值: 在真实数据和生成数据之间进行随机插值:x_hat = epsilon x + (1 – epsilon) G(z),其中epsilon是一个[0, 1]之间的随机数。
  2. 计算梯度: 计算判别器在插值点x_hat处的梯度:gradients = ∇D(x_hat)。
  3. 计算梯度范数: 计算梯度的L2范数:grad_norm = ||gradients||_2。
  4. 计算梯度惩罚: 计算梯度惩罚项:gradient_penalty = lambda * (grad_norm – 1)^2,其中lambda是一个超参数,控制梯度惩罚的强度。

WGAN-GP的损失函数

  • 判别器损失: L_D = E_{x sim P_g}[D(x)] – E_{x sim P_r}[D(x)] + gradient_penalty
  • 生成器损失: L_G = -E_{x sim P_g}[D(x)]

WGAN-GP的代码示例(PyTorch)

import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import grad

# 定义生成器
class Generator(nn.Module):
    def __init__(self, latent_dim, img_size):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.ReLU(),
            nn.Linear(128, img_size),
            nn.Tanh()  # 输出范围 [-1, 1]
        )

    def forward(self, z):
        img = self.model(z)
        return img

# 定义判别器
class Discriminator(nn.Module):
    def __init__(self, img_size):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(img_size, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
        )

    def forward(self, img):
        validity = self.model(img)
        return validity

# 计算梯度惩罚
def compute_gradient_penalty(D, real_samples, fake_samples, device):
    """Calculates the gradient penalty loss for WGAN-GP"""
    # Random weight term for interpolation
    alpha = torch.rand((real_samples.size(0), 1)).to(device)
    # Get random interpolation between real and fake samples
    interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)
    d_interpolates = D(interpolates)
    fake = torch.ones((real_samples.size(0), 1)).to(device)
    # Get gradient w.r.t. interpolates
    gradients = grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=fake,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]
    gradients = gradients.view(gradients.size(0), -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty

# 超参数
latent_dim = 100
img_size = 784  # 例如 MNIST 数据集
lr = 0.0001
b1 = 0.5
b2 = 0.999
n_epochs = 50
batch_size = 64
lambda_gp = 10 # 梯度惩罚系数
n_critic = 5 # 每次更新生成器之前更新判别器的次数

# 初始化生成器和判别器
generator = Generator(latent_dim, img_size)
discriminator = Discriminator(img_size)

# 使用CUDA(如果可用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator.to(device)
discriminator.to(device)

# 优化器
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))

# 训练循环
for epoch in range(n_epochs):
    for i in range(len(train_loader) // batch_size): # Assuming train_loader is defined elsewhere
        # ---------------------
        # 训练判别器
        # ---------------------

        for _ in range(n_critic):
            real_imgs = next(iter(train_loader))[0].view(batch_size, -1).to(device) # Assuming train_loader returns (images, labels)
            optimizer_D.zero_grad()

            # 判别器对真实图像的输出
            real_validity = discriminator(real_imgs)

            # 生成假图像
            z = torch.randn(batch_size, latent_dim).to(device)
            fake_imgs = generator(z)

            # 判别器对假图像的输出
            fake_validity = discriminator(fake_imgs)

            # 计算梯度惩罚
            gradient_penalty = compute_gradient_penalty(discriminator, real_imgs.data, fake_imgs.data, device)

            # 计算判别器损失
            d_loss = torch.mean(fake_validity) - torch.mean(real_validity) + lambda_gp * gradient_penalty
            d_loss.backward()
            optimizer_D.step()

        # -----------------
        # 训练生成器
        # -----------------

        optimizer_G.zero_grad()

        # 生成假图像
        z = torch.randn(batch_size, latent_dim).to(device)
        fake_imgs = generator(z)

        # 生成器希望判别器将假图像判定为真
        fake_validity = discriminator(fake_imgs)
        g_loss = -torch.mean(fake_validity)

        g_loss.backward()
        optimizer_G.step()

        print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
            % (epoch, n_epochs, i, len(train_loader) // batch_size, d_loss.item(), g_loss.item())
        )

WGAN-GP的优点

  • 更好的稳定性: 梯度惩罚比权重裁剪更稳定,能够更好地保证判别器满足Lipschitz约束。
  • 更高的图像质量: WGAN-GP通常能够生成更高质量的图像。

LSGAN:最小二乘GAN(Least Squares GAN)

LSGAN (Least Squares GAN) 提出了使用最小二乘损失函数来替代原始GAN中的sigmoid交叉熵损失函数。

原始GAN的损失函数的问题

原始GAN使用sigmoid交叉熵损失函数,这会导致以下问题:

  • 梯度消失: 当生成器生成的样本非常差时,判别器很容易将其判定为假,导致生成器得到的梯度信息非常小,无法有效更新。
  • 对离群点的惩罚: 原始GAN对离群点的惩罚过于严厉,导致生成器难以学习真实数据分布的尾部。

LSGAN的损失函数

LSGAN使用最小二乘损失函数,这可以缓解上述问题。LSGAN的损失函数如下:

  • 判别器损失: L_D = 0.5 E_{x sim P_r}[(D(x) – b)^2] + 0.5 E_{x sim P_g}[(D(x) – a)^2]
  • 生成器损失: L_G = 0.5 * E_{x sim P_g}[(D(x) – c)^2]

其中a, b, c是超参数,通常设置为a = 0, b = 1, c = 1。这意味着判别器希望将真实数据判定为1,将假数据判定为0,而生成器希望将假数据判定为1。

LSGAN的代码示例(PyTorch)

import torch
import torch.nn as nn
import torch.optim as optim

# 定义生成器
class Generator(nn.Module):
    def __init__(self, latent_dim, img_size):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.ReLU(),
            nn.Linear(128, img_size),
            nn.Tanh()  # 输出范围 [-1, 1]
        )

    def forward(self, z):
        img = self.model(z)
        return img

# 定义判别器
class Discriminator(nn.Module):
    def __init__(self, img_size):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(img_size, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
            nn.Sigmoid() # 输出范围 [0, 1]
        )

    def forward(self, img):
        validity = self.model(img)
        return validity

# 超参数
latent_dim = 100
img_size = 784  # 例如 MNIST 数据集
lr = 0.0002
b1 = 0.5
b2 = 0.999
n_epochs = 50
batch_size = 64

# 初始化生成器和判别器
generator = Generator(latent_dim, img_size)
discriminator = Discriminator(img_size)

# 使用CUDA(如果可用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator.to(device)
discriminator.to(device)

# 优化器
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))

# 损失函数
criterion = nn.MSELoss()

# 目标标签
valid = torch.ones(batch_size, 1).to(device)
fake = torch.zeros(batch_size, 1).to(device)

# 训练循环
for epoch in range(n_epochs):
    for i, (imgs, _) in enumerate(train_loader): # Assuming train_loader is defined elsewhere
        # ---------------------
        # 训练判别器
        # ---------------------

        optimizer_D.zero_grad()

        # 判别器对真实图像的输出
        real_imgs = imgs.view(batch_size, -1).to(device)
        real_pred = discriminator(real_imgs)
        d_real_loss = criterion(real_pred, valid)

        # 生成假图像
        z = torch.randn(batch_size, latent_dim).to(device)
        fake_imgs = generator(z)

        # 判别器对假图像的输出
        fake_pred = discriminator(fake_imgs)
        d_fake_loss = criterion(fake_pred, fake)

        # 计算判别器损失
        d_loss = (d_real_loss + d_fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

        # -----------------
        # 训练生成器
        # -----------------

        optimizer_G.zero_grad()

        # 生成假图像
        z = torch.randn(batch_size, latent_dim).to(device)
        fake_imgs = generator(z)

        # 生成器希望判别器将假图像判定为真
        fake_pred = discriminator(fake_imgs)
        g_loss = criterion(fake_pred, valid)

        g_loss.backward()
        optimizer_G.step()

        print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
            % (epoch, n_epochs, i, len(train_loader), d_loss.item(), g_loss.item())
        )

LSGAN的优点

  • 更稳定的训练: 最小二乘损失函数可以缓解梯度消失问题,提高训练稳定性。
  • 更好的图像质量: LSGAN通常能够生成更高质量的图像,并且能够生成更逼真的数据分布。

总结:不同GAN变体解决训练稳定性的方法

GAN变体 损失函数修改 梯度惩罚/其他约束 优点 缺点
WGAN Wasserstein距离(通过Kantorovitch-Rubinstein对偶性转换) 权重裁剪 解决了JS散度不连续的问题,提供了更可靠的梯度信息,训练更稳定。 权重裁剪可能导致梯度消失或爆炸,限制了判别器的表达能力。
WGAN-GP Wasserstein距离 梯度惩罚,鼓励判别器的梯度范数接近于1 比权重裁剪更稳定,能够更好地保证判别器满足Lipschitz约束,通常生成更高质量的图像。 需要调整梯度惩罚系数,计算量相对较大。
LSGAN 最小二乘损失函数 缓解梯度消失问题,提高训练稳定性,通常能够生成更高质量的图像。 可能对离群点更敏感。

选择合适的GAN变体

选择合适的GAN变体取决于具体的应用场景和数据集。

  • WGAN: 如果计算资源有限,可以尝试WGAN,但需要仔细调整权重裁剪参数。
  • WGAN-GP: 如果计算资源充足,建议使用WGAN-GP,因为它通常能够生成更高质量的图像,并且训练更稳定。
  • LSGAN: 如果数据集存在离群点,或者需要生成更逼真的数据分布,可以尝试LSGAN。

总的来说,WGAN-GP是目前应用最广泛的GAN变体之一,因为它在稳定性和图像质量方面都表现出色。但是,不同的GAN变体都有其优缺点,需要根据具体情况进行选择。

其他提高GAN训练稳定性的技巧

除了修改损失函数和引入梯度惩罚之外,还有一些其他的技巧可以提高GAN的训练稳定性:

  • 使用更好的优化器: Adam是常用的优化器,但有时可以使用其他的优化器,例如RMSprop或SGD,来提高训练稳定性。
  • 调整学习率: 合适的学习率对于GAN的训练至关重要。可以使用学习率衰减策略,或者使用自适应学习率优化器,例如AdamW。
  • 使用批量归一化(Batch Normalization): 批量归一化可以加速训练,提高稳定性。
  • 使用谱归一化(Spectral Normalization): 谱归一化可以限制判别器的Lipschitz常数,提高训练稳定性。
  • 使用标签平滑(Label Smoothing): 标签平滑可以缓解判别器过于自信的问题,提高训练稳定性。
  • 使用更大的批量大小(Batch Size): 更大的批量大小可以提供更稳定的梯度估计,提高训练稳定性。
  • 使用更深的网络结构: 更深的网络结构可以提高生成器和判别器的表达能力,但同时也需要更多的计算资源。
  • 使用更复杂的网络结构: 可以使用更复杂的网络结构,例如残差网络(ResNet)或密集连接网络(DenseNet),来提高生成器和判别器的表达能力。

总结一下关键点:GAN训练稳定性的关键因素

GAN的训练稳定性是一个复杂的问题,涉及到损失函数、优化算法、网络结构等多个方面。WGAN、WGAN-GP和LSGAN通过修改损失函数和引入梯度惩罚等方法,有效地提高了GAN的训练稳定性,使得GAN能够生成更高质量的图像。选择合适的GAN变体和训练技巧,对于成功训练GAN至关重要。

更多IT精英技术系列讲座,到智猿学院

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注