Python生成对抗网络(GAN):使用PyTorch实现CycleGAN和StyleGAN等高级模型。

Python生成对抗网络(GAN):使用PyTorch实现CycleGAN和StyleGAN等高级模型

大家好,今天我们深入探讨生成对抗网络(GANs)在PyTorch中的高级应用,重点是CycleGAN和StyleGAN的实现。我们将从理论基础出发,逐步构建代码,并分析其核心机制。

1. GANs回顾与挑战

GANs由生成器(Generator)和判别器(Discriminator)组成。生成器的目标是从随机噪声中学习生成逼真的数据,判别器的目标是区分真实数据和生成数据。两者相互对抗,最终达到纳什均衡,生成器能够生成以假乱真的数据。

GANs的训练面临诸多挑战:

  • 模式崩溃(Mode Collapse): 生成器可能只学习生成数据集中的少数几种模式,而忽略其他模式。
  • 训练不稳定(Training Instability): 训练过程中,生成器和判别器可能陷入震荡,导致无法收敛。
  • 梯度消失/爆炸(Vanishing/Exploding Gradients): 在训练的早期或晚期,梯度可能变得非常小或非常大,阻碍学习。

为了克服这些挑战,研究人员提出了各种改进的GANs架构,如CycleGAN和StyleGAN。

2. CycleGAN:无配对图像转换

CycleGAN解决的是无配对图像转换问题。例如,将马的图像转换为斑马的图像,而不需要一一对应的马和斑马的训练图像。

2.1 CycleGAN 原理

CycleGAN的核心思想是引入循环一致性损失(Cycle Consistency Loss)。它由两个GAN组成:

  • G: X -> Y: 将域X的图像转换为域Y的图像。
  • F: Y -> X: 将域Y的图像转换为域X的图像。

除了标准的对抗损失外,CycleGAN还使用循环一致性损失:

  • x -> G(x) -> F(G(x)) ≈ x: 将X域的图像x转换为Y域的图像G(x),再将G(x)转换回X域,应该与原始图像x尽可能相似。
  • y -> F(y) -> G(F(y)) ≈ y: 同理,将Y域的图像y转换为X域的图像F(y),再将F(y)转换回Y域,应该与原始图像y尽可能相似。

这种循环一致性约束确保了图像转换的可逆性,并保持了图像的内容不变。

2.2 CycleGAN 网络架构

CycleGAN通常使用ResNet架构作为生成器,PatchGAN作为判别器。

  • 生成器 (Generator): 通常使用带有跳跃连接的ResNet结构。ResNet块有助于解决梯度消失问题,跳跃连接可以保留图像的细节信息。
  • 判别器 (Discriminator): PatchGAN判别器不是判断整个图像的真假,而是判断图像的每个patch的真假。这使得判别器可以关注图像的局部细节,从而生成更逼真的图像。

2.3 CycleGAN 代码实现 (PyTorch)

首先,定义ResNet块:

import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    def __init__(self, in_channels):
        super(ResidualBlock, self).__init__()
        self.conv_block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=0, bias=False),
            nn.InstanceNorm2d(in_channels),
            nn.ReLU(True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=0, bias=False),
            nn.InstanceNorm2d(in_channels),
        )

    def forward(self, x):
        return x + self.conv_block(x)

然后,定义生成器:

class Generator(nn.Module):
    def __init__(self, input_nc, output_nc, n_residual_blocks=9):
        super(Generator, self).__init__()

        # Initial convolution block
        model = [   nn.ReflectionPad2d(3),
                    nn.Conv2d(input_nc, 64, kernel_size=7, padding=0, bias=False),
                    nn.InstanceNorm2d(64),
                    nn.ReLU(True) ]

        # Downsampling
        n_downsampling = 2
        for i in range(n_downsampling):
            mult = 2**i
            model += [   nn.Conv2d(64 * mult, 64 * mult * 2, kernel_size=3, stride=2, padding=1, bias=False),
                        nn.InstanceNorm2d(64 * mult * 2),
                        nn.ReLU(True) ]

        # Residual blocks
        mult = 2**n_downsampling
        for i in range(n_residual_blocks):
            model += [ResidualBlock(64 * mult)]

        # Upsampling
        for i in range(n_downsampling):
            mult = 2**(n_downsampling - i)
            model += [   nn.ConvTranspose2d(64 * mult, int(64 * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1, bias=False),
                        nn.InstanceNorm2d(int(64 * mult / 2)),
                        nn.ReLU(True) ]

        # Output convolution block
        model += [   nn.ReflectionPad2d(3),
                    nn.Conv2d(64, output_nc, kernel_size=7, padding=0),
                    nn.Tanh() ]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)

接下来,定义PatchGAN判别器:

class Discriminator(nn.Module):
    def __init__(self, input_nc):
        super(Discriminator, self).__init__()

        # A bunch of layers arranged in a sequential manner
        model = [   nn.Conv2d(input_nc, 64, kernel_size=4, stride=2, padding=1),
                    nn.LeakyReLU(0.2, inplace=True) ]

        model += [   nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
                    nn.InstanceNorm2d(128),
                    nn.LeakyReLU(0.2, inplace=True) ]

        model += [   nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
                    nn.InstanceNorm2d(256),
                    nn.LeakyReLU(0.2, inplace=True) ]

        model += [   nn.Conv2d(256, 512, kernel_size=4, padding=1),
                    nn.InstanceNorm2d(512),
                    nn.LeakyReLU(0.2, inplace=True) ]

        # FCN classification layer
        model += [nn.Conv2d(512, 1, kernel_size=4, padding=1)]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        x =  self.model(x)
        # Average pooling and flatten
        return torch.sigmoid(x)

最后,定义损失函数和训练循环(简化版):

import torch.optim as optim

# 初始化模型
netG_A2B = Generator(3, 3)
netG_B2A = Generator(3, 3)
netD_A = Discriminator(3)
netD_B = Discriminator(3)

# 定义优化器
optimizer_G = optim.Adam(list(netG_A2B.parameters()) + list(netG_B2A.parameters()), lr=0.0002, betas=(0.5, 0.999))
optimizer_D_A = optim.Adam(netD_A.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D_B = optim.Adam(netD_B.parameters(), lr=0.0002, betas=(0.5, 0.999))

# 定义损失函数
criterion_GAN = nn.MSELoss() # 使用MSE Loss替代原始的BCE Loss,能使训练更加稳定
criterion_cycle = nn.L1Loss()
criterion_identity = nn.L1Loss()

# 训练循环 (简化版)
def train(real_A, real_B):
    # 训练判别器D_A
    optimizer_D_A.zero_grad()
    fake_B = netG_A2B(real_A)
    pred_real = netD_A(real_B)
    pred_fake = netD_A(fake_B.detach()) # .detach() 防止梯度传播到生成器
    loss_D_A_real = criterion_GAN(pred_real, torch.ones_like(pred_real))
    loss_D_A_fake = criterion_GAN(pred_fake, torch.zeros_like(pred_fake))
    loss_D_A = (loss_D_A_real + loss_D_A_fake) * 0.5
    loss_D_A.backward()
    optimizer_D_A.step()

    # 训练判别器D_B
    optimizer_D_B.zero_grad()
    fake_A = netG_B2A(real_B)
    pred_real = netD_B(real_A)
    pred_fake = netD_B(fake_A.detach())
    loss_D_B_real = criterion_GAN(pred_real, torch.ones_like(pred_real))
    loss_D_B_fake = criterion_GAN(pred_fake, torch.zeros_like(pred_fake))
    loss_D_B = (loss_D_B_real + loss_D_B_fake) * 0.5
    loss_D_B.backward()
    optimizer_D_B.step()

    # 训练生成器
    optimizer_G.zero_grad()

    # GAN loss
    pred_fake_B = netD_A(fake_B)
    loss_GAN_A2B = criterion_GAN(pred_fake_B, torch.ones_like(pred_fake_B))
    fake_A = netG_B2A(real_B)
    pred_fake_A = netD_B(fake_A)
    loss_GAN_B2A = criterion_GAN(pred_fake_A, torch.ones_like(pred_fake_A))

    # Cycle loss
    recovered_A = netG_B2A(fake_B)
    loss_cycle_A = criterion_cycle(recovered_A, real_A)
    recovered_B = netG_A2B(fake_A)
    loss_cycle_B = criterion_cycle(recovered_B, real_B)

    # Identity loss (可选)
    identity_A = netG_B2A(real_A)
    loss_identity_A = criterion_identity(identity_A, real_A) * 5.0 # 论文建议乘以5.0
    identity_B = netG_A2B(real_B)
    loss_identity_B = criterion_identity(identity_B, real_B) * 5.0

    # Total loss
    loss_G = loss_GAN_A2B + loss_GAN_B2A + loss_cycle_A * 10.0 + loss_cycle_B * 10.0 + loss_identity_A + loss_identity_B # 论文建议 cycle loss 乘以 10.0
    loss_G.backward()
    optimizer_G.step()

    return loss_D_A.item(), loss_D_B.item(), loss_G.item()

2.4 CycleGAN 训练技巧

  • 学习率衰减(Learning Rate Decay): 在训练过程中逐渐降低学习率可以帮助模型更好地收敛。
  • 使用历史生成图像(Use of History Generated Images): 将最近生成的图像存储在一个历史缓冲区中,并在训练判别器时使用这些历史图像。这可以减少模式崩溃。
  • Instance Normalization: 使用 Instance Normalization 代替 Batch Normalization。Instance Normalization 在图像转换任务中表现更好。

3. StyleGAN:控制图像风格的生成器

StyleGAN 是一种专门为生成高分辨率人脸图像而设计的GAN架构。它通过控制生成器的不同层级的风格,可以生成具有不同风格特征的图像。

3.1 StyleGAN 原理

StyleGAN的核心思想是将噪声映射到一个中间隐空间W,然后将W中的风格向量注入到生成器的不同层级。

  • 映射网络(Mapping Network): 将噪声z映射到中间隐空间w。映射网络是一个多层感知机(MLP)。
  • 风格调制(Adaptive Instance Normalization, AdaIN): 将风格向量w注入到生成器的每个卷积层之后。AdaIN通过调整每个通道的均值和方差来控制图像的风格。
  • 噪声输入(Noise Input): 在生成器的每个卷积层之后添加噪声。噪声可以增加图像的细节,并防止生成器记住训练数据。
  • 恒定输入(Constant Input): 生成器从一个恒定的输入开始,而不是从随机噪声开始。

3.2 StyleGAN 网络架构

StyleGAN的生成器由多个卷积层组成,每个卷积层之后都进行风格调制和噪声输入。

  • 映射网络: 通常由8层MLP组成。
  • 生成器: 由多个StyleBlock组成。每个StyleBlock包含一个卷积层、一个AdaIN层和一个噪声输入层。

3.3 StyleGAN 代码实现 (PyTorch)

首先,定义映射网络:

class MappingNetwork(nn.Module):
    def __init__(self, z_dim, w_dim, n_layers=8):
        super(MappingNetwork, self).__init__()
        layers = []
        for i in range(n_layers):
            layers.append(nn.Linear(z_dim if i == 0 else w_dim, w_dim))
            layers.append(nn.LeakyReLU(0.2))
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        w = self.net(z)
        return w

然后,定义AdaIN层:

class AdaptiveInstanceNorm(nn.Module):
    def __init__(self, num_features, latent_dim):
        super().__init__()
        self.norm = nn.InstanceNorm2d(num_features)
        self.style_scale_transform = nn.Linear(latent_dim, num_features)
        self.style_shift_transform = nn.Linear(latent_dim, num_features)

    def forward(self, x, w):
        x = self.norm(x)
        style_scale = self.style_scale_transform(w)[:, :, None, None]
        style_shift = self.style_shift_transform(w)[:, :, None, None]
        return style_scale * x + style_shift

接下来,定义StyleBlock:

class StyleBlock(nn.Module):
    def __init__(self, in_channels, out_channels, latent_dim, resolution, upsample=False):
        super().__init__()
        self.upsample = upsample
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.noise_strength = nn.Parameter(torch.zeros([]))
        self.adain = AdaptiveInstanceNorm(out_channels, latent_dim)
        self.noise_injection = NoiseInjection(resolution) # resolution 决定了噪声图的大小

    def forward(self, x, w, noise):
        if self.upsample:
            x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
        x = self.conv(x)
        x = self.noise_injection(x, noise)
        x = self.adain(x, w)
        x = F.leaky_relu(x, 0.2)
        return x

class NoiseInjection(nn.Module):
    def __init__(self, resolution):
        super().__init__()
        self.noise_strength = nn.Parameter(torch.zeros([])) #可学习的噪声强度
        self.resolution = resolution

    def forward(self, image, noise):
        batch, _, h, w = image.shape
        if noise is None:
            noise = torch.randn(batch, 1, self.resolution, self.resolution, device=image.device).float()
        return image + self.noise_strength * noise

最后,定义生成器:

import torch.nn.functional as F

class Generator(nn.Module):
    def __init__(self, latent_dim, img_resolution=64, img_channels=3):
        super().__init__()
        self.latent_dim = latent_dim
        self.img_resolution = img_resolution
        self.img_channels = img_channels

        self.mapping_network = MappingNetwork(latent_dim, latent_dim)

        # Initial constant input
        self.constant_input = nn.Parameter(torch.randn(1, 512, 4, 4)) # 512 是初始特征图的通道数, 4x4 是初始分辨率

        # Define StyleBlocks ( adjust channels and resolutions accordingly )
        self.style_blocks = nn.ModuleList([
            StyleBlock(512, 512, latent_dim, resolution=4), # 4x4
            StyleBlock(512, 512, latent_dim, resolution=4, upsample=True), # 8x8
            StyleBlock(512, 512, latent_dim, resolution=8), # 8x8
            StyleBlock(512, 512, latent_dim, resolution=8, upsample=True), # 16x16
            StyleBlock(512, 512, latent_dim, resolution=16), # 16x16
            StyleBlock(512, 512, latent_dim, resolution=16, upsample=True), # 32x32
            StyleBlock(512, 512, latent_dim, resolution=32), # 32x32
            StyleBlock(512, 512, latent_dim, resolution=32, upsample=True), # 64x64
            StyleBlock(512, 512, latent_dim, resolution=64)  # 64x64
        ])

        # Output convolution
        self.to_rgb = nn.Sequential(
            nn.Conv2d(512, img_channels, kernel_size=1),
            nn.Tanh()
        )

    def forward(self, z, noise=None):
        w = self.mapping_network(z) # (batch_size, latent_dim)
        x = self.constant_input.repeat(z.shape[0], 1, 1, 1) # (batch_size, 512, 4, 4)

        for style_block in self.style_blocks:
            if isinstance(style_block, StyleBlock):
                resolution = style_block.noise_injection.resolution
                noise_for_block = torch.randn(x.shape[0], 1, resolution, resolution, device=x.device) if noise is None else noise # 为每个block生成不同的噪声
                x = style_block(x, w, noise_for_block)
            else:
                x = style_block(x)

        img = self.to_rgb(x)
        return img

3.4 StyleGAN 训练技巧

  • 渐进式增长(Progressive Growing): 从低分辨率开始训练,然后逐渐增加分辨率。这可以稳定训练,并加速收敛。
  • 混合正则化(Mixing Regularization): 在训练过程中,随机选择两个不同的风格向量,并将它们混合在一起。这可以防止生成器过度拟合。
  • 路径长度正则化(Path Length Regularization): 约束生成器的输出相对于隐空间的变化。这可以提高生成图像的质量。

4. CycleGAN与StyleGAN对比

特性 CycleGAN StyleGAN
应用场景 无配对图像转换 高分辨率图像生成,风格控制
核心思想 循环一致性损失 风格调制,中间隐空间
网络架构 ResNet生成器,PatchGAN判别器 映射网络,StyleBlock生成器
训练难度 中等 较高
可控性 有限,主要控制图像的整体风格 较高,可以控制图像的各个层级的风格
图像质量 依赖于数据集和训练,可能出现伪影 较高,可以生成逼真的高分辨率图像
对数据集的要求 两个域的数据集,不需要配对 大规模数据集,通常是人脸图像

5. 总结:两种GAN的应用场景与实现

我们深入探讨了CycleGAN和StyleGAN的原理和实现。CycleGAN擅长无配对图像转换,通过循环一致性损失实现跨域转换。StyleGAN则专注于高分辨率图像生成,通过风格调制实现对图像风格的精细控制。选择哪种GAN取决于具体的应用场景和对图像质量和可控性的需求。掌握这些高级GAN模型,能帮助我们解决更复杂的生成式建模问题。

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注