本站内容均来自兴趣收集,如不慎侵害的您的相关权益,请留言告知,我们将尽快删除.谢谢.
目录
3.原始GAN和CGAN的区别
1.原始GAN的缺点
生成的图像是随机的,不可预测的,无法控制网络输出特定的图片,生成目标不明确,可控性不强。
针对原始GAN不能生成具有特定属性的图片的问题,Mehdi Mirza等人提出了CGAN,其核心在于将属性信息y融入生成器和判别器中,属性y可以是任何标签的信息,例如图像的类别,人脸图像的面部表情等。
2.CGAN中心思想
CGAN的中心思想是希望可以控制GAN生成的图片,而不是单纯的随机生成图片。具体来说,Conditional GAN在生成器和判别器的输入中添加了额外的条件信息,生成器生成的图片只有足够真实且条件相符,才能够通过判别器。
3.原始GAN和CGAN的区别
从公式上来看,CGAN相当于在原始GAN的基础上对生成器部分和判别器部分都加了一个条件。
从模型上来看,如下图所示
为了实现条件GAN的目的,生成网络和判别网络的原理和训练方式均要有所改变。模型部分,在判别器和生成器中都添加了额外信息y,y可以是类别标签或者是其他类型的数据,可以将y作为一个额外的输入层丢入判别器和生成器。
4.CGAN代码实现
#导入库 import torch import torch.nn as nn import torch.nn.functional as F from torch.utils import data import torchvision #加载图片 from torchvision import transforms #图片变换 import numpy as np import matplotlib.pyplot as plt #绘图 import os import glob from PIL import Image #独热编码 def one_hot(x,class_count=10): return torch.eye(class_count)[x,:] transform = transforms.Compose([ transforms.ToTensor(), #取值范围会被归一化到(0,1)之间 transforms.Normalize(mean=0.5,std=0.5) #设置均值和方差均为0.5 ]) #加载数据集 dataset = torchvision.datasets.MNIST('data', train=True, transform=transform, target_transform = one_hot, download = True) dl = torch.utils.data.DataLoader(dataset,batch_size=64,shuffle = True) #定义生成器 class Generator(nn.Module): def __init__(self): super(Generator,self).__init__() self.linear1 = nn.Linear(100,128*7*7) self.bn1=nn.BatchNorm1d(128*7*7) self.linear2 = nn.Linear(10,128*7*7) self.bn2=nn.BatchNorm1d(128*7*7) self.deconv1 = nn.ConvTranspose2d(256,128, kernel_size=(3,3), stride=1, padding=1) #生成(128,7,7)的二维图像 self.bn3=nn.BatchNorm2d(128) self.deconv2 = nn.ConvTranspose2d(128,64, kernel_size=(4,4), stride=2, padding=1) #生成(64,14,14)的二维图像 self.bn4=nn.BatchNorm2d(64) self.deconv3 = nn.ConvTranspose2d(64,1, kernel_size=(4,4), stride=2, padding=1) #生成(1,28,28)的二维图像 def forward(self,x1,x2): x1=F.relu(self.linear1(x1)) x1=self.bn1(x1) x1=x1.view(-1,128,7,7) x2=F.relu(self.linear2(x2)) x2=self.bn2(x2) x2=x2.view(-1,128,7,7) x=torch.cat([x1,x2],axis=1) #batch, 256, 7, 7 x=x.view(-1,256,7,7) x=F.relu(self.deconv1(x)) x=self.bn3(x) x=F.relu(self.deconv2(x)) x=self.bn4(x) x=torch.tanh(self.deconv3(x)) return x #定义判别器 #输入:1,28,28图片和长度为10的condition class Discriminator(nn.Module): def __init__(self): super(Discriminator,self).__init__() self.linear = nn.Linear(10,1*28*28) self.conv1 = nn.Conv2d(2,64,kernel_size=3,stride=2) self.conv2 = nn.Conv2d(64,128,kernel_size=3,stride=2) self.bn = nn.BatchNorm2d(128) self.fc = nn.Linear(128*6*6,1) def forward(self,x1,x2): #x1代表label,x2代表image x1=F.leaky_relu(self.linear(x1)) x1=x1.view(-1,1,28,28) x=torch.cat([x1,x2],axis=1) #shape:batch,2,28,28 x= F.dropout2d(F.leaky_relu(self.conv1(x))) x= F.dropout2d(F.leaky_relu(self.conv2(x)) ) #(batch,128,6,6) x = self.bn(x) x = x.view(-1,128*6*6) #展平 x = torch.sigmoid(self.fc(x)) return x #模型训练 #设备的配置 device='cuda' if torch.cuda.is_available() else 'cpu' #初化生成器和判别器把他们放到相应的设备上 gen = Generator().to(device) dis = Discriminator().to(device) #交叉熵损失函数 loss_fn = torch.nn.BCELoss() #训练器的优化器 d_optimizer = torch.optim.Adam(dis.parameters(),lr=1e-5) #训练生成器的优化器 g_optimizer = torch.optim.Adam(gen.parameters(),lr=1e-4) #定义可视化函数 def generate_and_save_images(model,epoch,label_input,noise_input): prediction = np.squeeze(model(noise_input,label_input).cpu().numpy()) fig = plt.figure(figsize=(4,4)) for i in range(prediction.shape[0]): plt.subplot(4,4,i+1) plt.imshow((prediction[i]+1)/2,cmap='gray') plt.axis('off') plt.savefig('image_at_epoch_{:04d}.png'.format(epoch)) plt.show() #设置生成绘图图片的随机张量,这里可视化16张图片 #生成16个长度为100的随机正态分布张量 noise_seed = torch.randn(16,100,device=device) label_seed = torch.randint(0,10,size=(16,)) label_seed_onehot = one_hot(label_seed).to(device) D_loss = [] #记录训练过程中判别器的损失 G_loss = [] #记录训练过程中生成器的损失 #训练循环 for epoch in range(10): #初始化损失值 D_epoch_loss = 0 G_epoch_loss = 0 count = len(dl.dataset) #返回批次数 #对数据集进行迭代 for step,(img,label) in enumerate(dl): img =img.to(device) #把数据放到设备上 label = label.to(device) size = img.shape[0] #img的第一位是size,获取批次的大小 random_seed = torch.randn(size,100,device=device) #判别器训练(真实图片的损失和生成图片的损失),损失的构建和优化 d_optimizer.zero_grad()#梯度归零 #判别器对于真实图片产生的损失 real_output = dis(label,img) #判别器输入真实的图片,real_output对真实图片的预测结果 d_real_loss = loss_fn(real_output, torch.ones_like(real_output,device=device) ) d_real_loss.backward()#计算梯度 #在生成器上去计算生成器的损失,优化目标是判别器上的参数 generated_img = gen(random_seed,label) #得到生成的图片 #因为优化目标是判别器,所以对生成器上的优化目标进行截断 fake_output = dis(label,generated_img.detach()) #判别器输入生成的图片,fake_output对生成图片的预测;detach会截断梯度,梯度就不会再传递到gen模型中了 #判别器在生成图像上产生的损失 d_fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output,device=device) ) d_fake_loss.backward() #判别器损失 disc_loss = d_real_loss + d_fake_loss #判别器优化 d_optimizer.step() #生成器上损失的构建和优化 g_optimizer.zero_grad() #先将生成器上的梯度置零 fake_output = dis(label,generated_img) gen_loss = loss_fn(fake_output, torch.ones_like(fake_output,device=device) ) #生成器损失 gen_loss.backward() g_optimizer.step() #累计每一个批次的loss with torch.no_grad(): D_epoch_loss +=disc_loss G_epoch_loss +=gen_loss #求平均损失 with torch.no_grad(): D_epoch_loss /=count G_epoch_loss /=count D_loss.append(D_epoch_loss) G_loss.append(G_epoch_loss) #训练完一个Epoch,打印提示并绘制生成的图片 print("Epoch:",epoch) print(label_seed) generate_and_save_images(gen,epoch,label_seed_onehot,noise_seed)
5.运行结果
因篇幅有限,只展示一部分运行结果
6.CGAN缺陷
CGAN生成的图像虽然有很多缺陷,譬如图像边缘模糊,生成的图像分辨率太低,但是它为后面的pix2pixGAN和CycleGAN开拓了道路,这两个模型转换图像网络时对属性特征的处理方法均受到CGAN启发。
Be First to Comment