Press "Enter" to skip to content

好像还挺好玩的GAN重制版3——Pytorch搭建DCGAN利用深度卷积神经网络实现图片生成

本站内容均来自兴趣收集,如不慎侵害的您的相关权益,请留言告知,我们将尽快删除.谢谢.

好像还挺好玩的GAN重制版2——Pytorch搭建DCGAN利用深度卷积神经网络实现图片生成

 

学习前言

 

我又死了我又死了我又死了!

源码下载地址

 

https://github.com/bubbliiiing/dcgan-pytorch

 

喜欢的可以点个star噢。

 

网络构建

 

一、什幺是DCGAN

 

DCGAN的全称是Deep Convolutional Generative Adversarial Networks,翻译为 深度卷积对抗生成网络 。

 

它是由Alec Radford在论文Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks中提出的。

 

实际上它就是在GAN的基础上增加深度卷积网络结构。

 

论文中给出的DCGAN结构如图所示。其使用反卷积将特征层的高宽不断扩大,整体结构看起来像普通神经网络的逆过程。

二、生成网络的构建

 

对于生成网络来讲,它的 目的是生成假图片 ,它的 输入是正态分布随机数。输出是假图片 。

 

在GAN当中,我们将这个正态分布随机数长度定义为100,在经过处理后,我们会得到一个(64,64,3)的假图片。

 

在处理过程中,我们会使用到反卷积,反卷积的概念是相对于正常卷积的, 在正常卷积下,我们的特征层的高宽会不断被压缩 ; 在反卷积下,我们的特征层的高宽会不断变大 。

在DCGAN的生成网络中,我们首先利用一个全连接,将输入长条全连接到16,384(4x4x1024)这样一个长度上,这样我们才可以对这个全连接的结果进行reshape,使它变成(4,4,1024)的特征层。

 

在获得这个特征层之后,我们就可以利用反卷积进行上采样了。

 

在每次反卷积后,特征层的高和宽会变为原来的两倍,在四次反卷积后,我们特征层的shape变化是这样的:

 

( 4 , 4 , 1024 ) − > ( 8 , 8 , 512 ) − > ( 16 , 16 , 256 ) − > ( 32 , 32 , 128 ) − > ( 64 , 64 , 3 ) 。 (4,4,1024)->(8,8,512)->(16,16,256)->(32,32,128)->(64,64,3)。 ( 4 , 4 , 1 0 2 4 ) − > ( 8 , 8 , 5 1 2 ) − > ( 1 6 , 1 6 , 2 5 6 ) − > ( 3 2 , 3 2 , 1 2 8 ) − > ( 6 4 , 6 4 , 3 ) 。

 

此时我们再进行一次tanh激活函数,我们就可以获得一张假图片了。

 

实现代码如下:

 

def conv_out_size_same(size, stride):
    return int(math.ceil(float(size) / float(stride)))
    
class generator(nn.Module):
    def __init__(self, d = 128, input_shape = [64, 64]):
        super(generator, self).__init__()
        s_h, s_w    = input_shape[0], input_shape[1]
        s_h2, s_w2  = conv_out_size_same(s_h, 2), conv_out_size_same(s_w, 2)
        s_h4, s_w4  = conv_out_size_same(s_h2, 2), conv_out_size_same(s_w2, 2)
        s_h8, s_w8  = conv_out_size_same(s_h4, 2), conv_out_size_same(s_w4, 2)
        self.s_h16, self.s_w16 = conv_out_size_same(s_h8, 2), conv_out_size_same(s_w8, 2)
        self.linear     = nn.Linear(100, self.s_h16 * self.s_w16 * d * 8)
        self.linear_bn  = nn.BatchNorm2d(d * 8)
        self.deconv1    = nn.ConvTranspose2d(d * 8, d * 4, 4, 2, 1)
        self.deconv1_bn = nn.BatchNorm2d(d * 4)
        self.deconv2    = nn.ConvTranspose2d(d * 4, d * 2, 4, 2, 1)
        self.deconv2_bn = nn.BatchNorm2d(d * 2)
        self.deconv3    = nn.ConvTranspose2d(d * 2, d, 4, 2, 1)
        self.deconv3_bn = nn.BatchNorm2d(d)
        
        self.deconv4    = nn.ConvTranspose2d(d, 3, 4, 2, 1)
        self.relu       = nn.ReLU()
    def weight_init(self):
        for m in self.modules():
            if isinstance(m, nn.ConvTranspose2d):
                m.weight.data.normal_(0.0, 0.02)
                m.bias.data.fill_(0)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.normal_(0.1, 0.02)
                m.bias.data.fill_(0)
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0.0, 0.02)
                m.bias.data.fill_(0)
    def forward(self, x):
        bs, _ = x.size()
        x = self.linear(x)
        x = x.view([bs, -1, self.s_h16, self.s_w16])
        x = self.relu(self.linear_bn(x))
        x = self.relu(self.deconv1_bn(self.deconv1(x)))
        x = self.relu(self.deconv2_bn(self.deconv2(x)))
        x = self.relu(self.deconv3_bn(self.deconv3(x)))
        x = torch.tanh(self.deconv4(x))
        return x

 

三、判断网络的构建

 

对于生成网络来讲,它的 目的是生成假图片 ,它的 输入是正态分布随机数。输出是假图片 。

 

对于判断网络来讲,它的 目的是判断输入图片的真假 ,它的 输入是图片,输出是判断结果 。

 

判断结果处于0-1之间,利用 接近1代表判断为真图片,接近0代表判断为假图片。

 

判断网络的构建和普通卷积网络差距不大,都是不断的卷积对图片进行下采用,在多次卷积后,最终接一次全连接判断结果。

 

实现代码如下:

 

class discriminator(nn.Module):
    def __init__(self, d = 128, input_shape = [64, 64]):
        super(discriminator, self).__init__()
        s_h, s_w    = input_shape[0], input_shape[1]
        s_h2, s_w2  = conv_out_size_same(s_h, 2), conv_out_size_same(s_w, 2)
        s_h4, s_w4  = conv_out_size_same(s_h2, 2), conv_out_size_same(s_w2, 2)
        s_h8, s_w8  = conv_out_size_same(s_h4, 2), conv_out_size_same(s_w4, 2)
        self.s_h16, self.s_w16 = conv_out_size_same(s_h8, 2), conv_out_size_same(s_w8, 2)
        # 64,64,3 -> 32,32,128
        self.conv1      = nn.Conv2d(3, d, 4, 2, 1)
        # 32,32,128 -> 16,16,256
        self.conv2      = nn.Conv2d(d, d * 2, 4, 2, 1)
        self.conv2_bn   = nn.BatchNorm2d(d * 2)
        # 16,16,256 -> 8,8,512
        self.conv3      = nn.Conv2d(d * 2, d * 4, 4, 2, 1)
        self.conv3_bn   = nn.BatchNorm2d(d * 4)
        # 8,8,512 -> 4,4,1024
        self.conv4      = nn.Conv2d(d * 4, d * 8, 4, 2, 1)
        self.conv4_bn   = nn.BatchNorm2d(d * 8)
        # 4,4,1024 -> 1,1,1
        self.linear     = nn.Linear(self.s_h16 * self.s_w16 * d * 8, 1)
        self.leaky_relu = nn.LeakyReLU(negative_slope=0.2)
        self.sigmoid    = nn.Sigmoid()
    def weight_init(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                m.weight.data.normal_(0.0, 0.02)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.normal_(0.1, 0.02)
                m.bias.data.fill_(0)
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0.0, 0.02)
                m.bias.data.fill_(0)
    def forward(self, x):
        bs, _, _, _ = x.size()
        x = self.leaky_relu(self.conv1(x))
        x = self.leaky_relu(self.conv2_bn(self.conv2(x)))
        x = self.leaky_relu(self.conv3_bn(self.conv3(x)))
        x = self.leaky_relu(self.conv4_bn(self.conv4(x)))
        x = x.view([bs,-1])
        x = self.sigmoid(self.linear(x))
        return x.squeeze()

 

训练思路

 

DCGAN的训练可以分为生成器训练和判别器训练:

 

每一个step中一般先训练判别器,然后训练生成器。

 

一、判别器的训练

 

在 训练判别器的时候我们希望判别器可以判断输入图片的真伪 ,因此我们的 输入就是真图片、假图片和它们对应的标签 。

 

因此判别器的训练步骤如下:

 

1、随机选取batch_size个真实的图片。

 

2、随机生成batch_size个N维向量,传入到Generator中生成batch_size个虚假图片。

二、生成器的训练

 

在 训练生成器的时候我们希望生成器可以生成极为真实的假图片 。因此我们在训练生成器需要知道 判别器认为什幺图片是真图片。

 

因此生成器的训练步骤如下:

 

1、随机生成batch_size个N维向量,传入到Generator中生成batch_size个虚假图片。

利用DCGAN生成图片

 

DCGAN的库整体结构如下:

一、数据集的准备

 

在训练前需要准备好数据集,数据集保存在datasets文件夹里面。

二、数据集的处理

 

打开txt_annotation.py,默认指向根目录下的datasets。运行txt_annotation.py。

 

此时生成根目录下面的train_lines.txt。

三、模型训练

 

在完成数据集处理后,运行train.py即可开始训练。

训练过程中,可在results文件夹内查看训练效果:

Be First to Comment

发表回复

您的电子邮箱地址不会被公开。