Press "Enter" to skip to content

好像还挺好玩的GAN重制版4——Pytorch搭建SRGAN平台进行图片超分辨率提升

好像还挺好玩的GAN重制版4——Pytorch搭建SRGAN平台进行图片超分辨率提升

 

学习前言

 

我又死了我又死了我又死了!

源码下载地址

 

https://github.com/bubbliiiing/srgan-pytorch

 

喜欢的可以点个star噢。

 

网络构建

 

一、什幺是SRGAN

 

SRGAN出自论文Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network。

 

如果将SRGAN看作一个黑匣子,其主要的功能 就是输入一张低分辨率图片,生成高分辨率图片。

该文章提到, 普通的超分辨率模型训练网络时只用到了均方差作为损失函数,虽然能够获得很高的峰值信噪比,但是恢复出来的图像通常会丢失高频细节 。

 

SRGAN利用 感知损失(perceptual loss)和对抗损失(adversarial loss)来提升恢复出的图片的真实感 。

 

二、生成网络的构建

 

生成网络的构成如上图 所示, 生成网络 的作用是 输入一张低分辨率图片,生成高分辨率图片。 :

 

SRGAN的生成网络由 三个部分 组成。

 

1、 低分辨率图像进入后会经过一个卷积+RELU函数 。

 

2、然后 经过B个残差网络结构 ,每个残差结构都 包含两个卷积+标准化+RELU,还有一个残差边。

 

3、然后进入 上采样部分,在经过两次上采样后,原图的高宽变为原来的4倍,实现分辨率的提升 。

 

前两个部分用于特征提取,第三部分用于提高分辨率。

 

import math
import torch
from torch import nn
class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.prelu = nn.PReLU(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)
    def forward(self, x):
        short_cut = x
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.prelu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        return x + short_cut
class UpsampleBLock(nn.Module):
    def __init__(self, in_channels, up_scale):
        super(UpsampleBLock, self).__init__()
        self.conv = nn.Conv2d(in_channels, in_channels * up_scale ** 2, kernel_size=3, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(up_scale)
        self.prelu = nn.PReLU(in_channels)
    def forward(self, x):
        x = self.conv(x)
        x = self.pixel_shuffle(x)
        x = self.prelu(x)
        return x
class Generator(nn.Module):
    def __init__(self, scale_factor, num_residual=16):
        upsample_block_num = int(math.log(scale_factor, 2))
        super(Generator, self).__init__()
        self.block_in = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=9, padding=4),
            nn.PReLU(64)
        )
        self.blocks = []
        for _ in range(num_residual):
            self.blocks.append(ResidualBlock(64))
        self.blocks = nn.Sequential(*self.blocks)
        
        self.block_out = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64)
        )
        self.upsample = [UpsampleBLock(64, 2) for _ in range(upsample_block_num)]
        self.upsample.append(nn.Conv2d(64, 3, kernel_size=9, padding=4))
        self.upsample = nn.Sequential(*self.upsample)
    def forward(self, x):
        x = self.block_in(x)
        short_cut = x
        x = self.blocks(x)
        x = self.block_out(x)
        upsample = self.upsample(x + short_cut)
        return torch.tanh(upsample)

 

三、判别网络的构建

 

判别网络的构成如上图 所示:

 

SRGAN的判别网络由不断重复的 卷积+LeakyRELU和标准化 组成。

 

对于判断网络来讲,它的 目的是判断输入图片的真假 ,它的 输入是图片,输出是判断结果 。

 

判断结果处于0-1之间,利用 接近1代表判断为真图片,接近0代表判断为假图片。

 

判断网络的构建和普通卷积网络差距不大,都是不断的卷积对图片进行下采用,在多次卷积后,最终接一次全连接判断结果。

 

实现代码如下:

 

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),
            nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(512, 1024, kernel_size=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(1024, 1, kernel_size=1)
        )
    def forward(self, x):
        batch_size = x.size(0)
        return torch.sigmoid(self.net(x).view(batch_size))

 

训练思路

 

SRGAN的训练可以分为生成器训练和判别器训练:

 

每一个step中一般先训练判别器,然后训练生成器。

 

一、判别器的训练

 

在 训练判别器的时候我们希望判别器可以判断输入图片的真伪 ,因此我们的 输入就是真图片、假图片和它们对应的标签 。

 

因此判别器的训练步骤如下:

 

1、随机选取batch_size个真实高分辨率图片。

 

2、利用resize后的低分辨率图片,传入到Generator中生成batch_size个虚假高分辨率图片。

二、生成器的训练

 

在 训练生成器的时候我们希望生成器可以生成极为真实的假图片 。因此我们在训练生成器需要知道 判别器认为什幺图片是真图片。

 

因此生成器的训练步骤如下:

 

1、将低分辨率图像传入生成模型,得到虚假高分辨率图像,将虚假高分辨率图像获得判别结果与1进行对比得到loss。(与1对比的意思是,让生成器根据判别器判别的结果进行训练)。

利用SRGAN生成图片

 

SRGAN的库整体结构如下:

一、数据集的准备

 

在训练前需要准备好数据集,数据集保存在datasets文件夹里面。

二、数据集的处理

 

打开txt_annotation.py,默认指向根目录下的datasets。运行txt_annotation.py。

 

此时生成根目录下面的train_lines.txt。

三、模型训练

 

在完成数据集处理后,运行train.py即可开始训练。

训练过程中,可在results文件夹内查看训练效果:

Be First to Comment

发表回复

您的电子邮箱地址不会被公开。 必填项已用*标注