Press "Enter" to skip to content

CGAN理论讲解及代码实现

本站内容均来自兴趣收集,如不慎侵害的您的相关权益,请留言告知,我们将尽快删除.谢谢.

目录

 

3.原始GAN和CGAN的区别

 

1.原始GAN的缺点

 

生成的图像是随机的,不可预测的,无法控制网络输出特定的图片,生成目标不明确,可控性不强。

 

针对原始GAN不能生成具有特定属性的图片的问题,Mehdi Mirza等人提出了CGAN,其核心在于将属性信息y融入生成器和判别器中,属性y可以是任何标签的信息,例如图像的类别,人脸图像的面部表情等。

 

2.CGAN中心思想

 

CGAN的中心思想是希望可以控制GAN生成的图片,而不是单纯的随机生成图片。具体来说,Conditional  GAN在生成器和判别器的输入中添加了额外的条件信息,生成器生成的图片只有足够真实且条件相符,才能够通过判别器。

 

3.原始GAN和CGAN的区别

 

从公式上来看,CGAN相当于在原始GAN的基础上对生成器部分和判别器部分都加了一个条件。

 

 

从模型上来看,如下图所示

 

 

为了实现条件GAN的目的,生成网络和判别网络的原理和训练方式均要有所改变。模型部分,在判别器和生成器中都添加了额外信息y,y可以是类别标签或者是其他类型的数据,可以将y作为一个额外的输入层丢入判别器和生成器。

 

4.CGAN代码实现

 

#导入库
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
import torchvision #加载图片
from torchvision import transforms #图片变换
import numpy as np
import matplotlib.pyplot as plt #绘图
import os
import glob
from PIL import Image
#独热编码
def one_hot(x,class_count=10):
    return torch.eye(class_count)[x,:]
transform = transforms.Compose([
    transforms.ToTensor(), #取值范围会被归一化到(0,1)之间
    transforms.Normalize(mean=0.5,std=0.5) #设置均值和方差均为0.5
])
#加载数据集
dataset = torchvision.datasets.MNIST('data',
                                    train=True,
                                    transform=transform,
                                    target_transform = one_hot,
                                    download = True)
dl = torch.utils.data.DataLoader(dataset,batch_size=64,shuffle = True)
#定义生成器
class Generator(nn.Module):
    def __init__(self):
        super(Generator,self).__init__()
        self.linear1 = nn.Linear(100,128*7*7)
        self.bn1=nn.BatchNorm1d(128*7*7)
        self.linear2 = nn.Linear(10,128*7*7)
        self.bn2=nn.BatchNorm1d(128*7*7)
        
        self.deconv1 = nn.ConvTranspose2d(256,128,
                                         kernel_size=(3,3),
                                         stride=1,
                                         padding=1)  #生成(128,7,7)的二维图像
        self.bn3=nn.BatchNorm2d(128)
        self.deconv2 = nn.ConvTranspose2d(128,64,
                                         kernel_size=(4,4),
                                         stride=2,
                                         padding=1)  #生成(64,14,14)的二维图像
        self.bn4=nn.BatchNorm2d(64)
        self.deconv3 = nn.ConvTranspose2d(64,1,
                                         kernel_size=(4,4),
                                         stride=2,
                                         padding=1)  #生成(1,28,28)的二维图像
        
    def forward(self,x1,x2):
        x1=F.relu(self.linear1(x1))
        x1=self.bn1(x1)
        x1=x1.view(-1,128,7,7)
        x2=F.relu(self.linear2(x2))
        x2=self.bn2(x2)
        x2=x2.view(-1,128,7,7)
        x=torch.cat([x1,x2],axis=1)  #batch, 256, 7, 7
        x=x.view(-1,256,7,7)
        x=F.relu(self.deconv1(x))
        x=self.bn3(x)
        x=F.relu(self.deconv2(x))
        x=self.bn4(x)
        x=torch.tanh(self.deconv3(x))
        return x
#定义判别器
#输入:1,28,28图片和长度为10的condition
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator,self).__init__()
        self.linear = nn.Linear(10,1*28*28)
        self.conv1 = nn.Conv2d(2,64,kernel_size=3,stride=2)
        self.conv2 = nn.Conv2d(64,128,kernel_size=3,stride=2)
        self.bn = nn.BatchNorm2d(128)
        self.fc = nn.Linear(128*6*6,1)
    def forward(self,x1,x2): #x1代表label,x2代表image
        x1=F.leaky_relu(self.linear(x1))
        x1=x1.view(-1,1,28,28)
        x=torch.cat([x1,x2],axis=1)  #shape:batch,2,28,28                
        x= F.dropout2d(F.leaky_relu(self.conv1(x)))
        x= F.dropout2d(F.leaky_relu(self.conv2(x)) )  #(batch,128,6,6)
        x = self.bn(x)
        x = x.view(-1,128*6*6) #展平
        x = torch.sigmoid(self.fc(x))
        return x
#模型训练
#设备的配置
device='cuda' if torch.cuda.is_available() else 'cpu'
#初化生成器和判别器把他们放到相应的设备上
gen = Generator().to(device)
dis = Discriminator().to(device)
#交叉熵损失函数
loss_fn = torch.nn.BCELoss()
#训练器的优化器
d_optimizer = torch.optim.Adam(dis.parameters(),lr=1e-5)
#训练生成器的优化器
g_optimizer = torch.optim.Adam(gen.parameters(),lr=1e-4)
#定义可视化函数
def generate_and_save_images(model,epoch,label_input,noise_input):
    prediction = np.squeeze(model(noise_input,label_input).cpu().numpy())
    fig = plt.figure(figsize=(4,4))
    for i in range(prediction.shape[0]):
        plt.subplot(4,4,i+1)
        plt.imshow((prediction[i]+1)/2,cmap='gray')
        plt.axis('off')
    plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
    plt.show()
#设置生成绘图图片的随机张量,这里可视化16张图片
#生成16个长度为100的随机正态分布张量
noise_seed = torch.randn(16,100,device=device)
label_seed = torch.randint(0,10,size=(16,))
label_seed_onehot = one_hot(label_seed).to(device)
D_loss = [] #记录训练过程中判别器的损失
G_loss = [] #记录训练过程中生成器的损失
#训练循环
for epoch in range(10):
    #初始化损失值
    D_epoch_loss = 0
    G_epoch_loss = 0
    count = len(dl.dataset) #返回批次数
    #对数据集进行迭代
    for step,(img,label) in enumerate(dl):
        img =img.to(device) #把数据放到设备上
        label = label.to(device)
        size = img.shape[0] #img的第一位是size,获取批次的大小
        random_seed = torch.randn(size,100,device=device)
        
        #判别器训练(真实图片的损失和生成图片的损失),损失的构建和优化
        d_optimizer.zero_grad()#梯度归零
        #判别器对于真实图片产生的损失
        real_output = dis(label,img) #判别器输入真实的图片,real_output对真实图片的预测结果
        d_real_loss = loss_fn(real_output,
                              torch.ones_like(real_output,device=device)
                              )
        d_real_loss.backward()#计算梯度
        
        #在生成器上去计算生成器的损失,优化目标是判别器上的参数
        generated_img = gen(random_seed,label) #得到生成的图片
        #因为优化目标是判别器,所以对生成器上的优化目标进行截断
        fake_output = dis(label,generated_img.detach()) #判别器输入生成的图片,fake_output对生成图片的预测;detach会截断梯度,梯度就不会再传递到gen模型中了
        #判别器在生成图像上产生的损失
        d_fake_loss = loss_fn(fake_output,
                              torch.zeros_like(fake_output,device=device)
                              )
        d_fake_loss.backward()
        #判别器损失
        disc_loss = d_real_loss + d_fake_loss
        #判别器优化
        d_optimizer.step()
        
        
        #生成器上损失的构建和优化
        g_optimizer.zero_grad() #先将生成器上的梯度置零
        fake_output = dis(label,generated_img)
        gen_loss = loss_fn(fake_output,
                              torch.ones_like(fake_output,device=device)
                          )  #生成器损失
        gen_loss.backward()
        g_optimizer.step()
        #累计每一个批次的loss
        with torch.no_grad():
            D_epoch_loss +=disc_loss
            G_epoch_loss +=gen_loss
    #求平均损失
    with torch.no_grad():
            D_epoch_loss /=count
            G_epoch_loss /=count
            D_loss.append(D_epoch_loss)
            G_loss.append(G_epoch_loss)
            #训练完一个Epoch,打印提示并绘制生成的图片
            print("Epoch:",epoch)
            print(label_seed)
            generate_and_save_images(gen,epoch,label_seed_onehot,noise_seed)

 

5.运行结果

 

因篇幅有限,只展示一部分运行结果

 

 

 

 

 

 

 

6.CGAN缺陷

 

CGAN生成的图像虽然有很多缺陷,譬如图像边缘模糊,生成的图像分辨率太低,但是它为后面的pix2pixGAN和CycleGAN开拓了道路,这两个模型转换图像网络时对属性特征的处理方法均受到CGAN启发。

 

Be First to Comment

发表回复

您的电子邮箱地址不会被公开。