Press "Enter" to skip to content

SRGAN 图像超分辨率重建(Keras)

本站内容均来自兴趣收集,如不慎侵害的您的相关权益,请留言告知,我们将尽快删除.谢谢.

文章目录

 

SRGAN 网络是用GAN网络来实现图像超分辨率重建的网络。训练完网络后。只用生成器来重建低分辨率图像。网络结构主要使用生成器(Generator)和判别器(Discriminator)。训练过程不太稳定。一般用于卫星图像和遥感图像的图像重建。

 

这里我们使用的高分辨率的数据集 (DIV2K)

 

数据集下载链接:链接:https://pan.baidu.com/s/1UBle5Cu74TRifcAVz14cDg 提取码:luly

 

github代码地址:https://github.com/jiantenggei/srgan

 

一、SRGAN

 

1.训练步骤

 

SRGAN 网络的训练思路如下图所示:

训练步骤如下:

 

(1) 将低分辨率输入到生成网络,生成高分辨率图像。

 

(2) 将高分辨率图像输入的判别网络判别真假,与0和1进行对比

 

(3) 将原始高分辨率图像和生成的高分辨率图像分别用VGG19 的前9层提取特征,将提取的特征计算loss。

 

(4). 将loss返回给生成器继续训练。

 

这就是SRGAN 的训练流程了。

接下来我们一一去实现上述步骤。

2.生成器

 

生成器网络结构如下图所示:

生成器主要有两部分构成,第一部分是residual block 残差块(图中红色方块),第二部分是上采样部分(图中蓝色方块)用来放大图像。

 

残差块:包含一个两个3×3的卷积

 

上采样:使用 UpSampling2D 实现

 

生成器代码如下所示:

由于把整个SRGAN 定义成一个类的形式 里面没有出现的参数我会卸载注解中

def build_generator(self):
        def residual_block(layer_input, filters):
            """Residual block described in paper #残差块"""
            d = Conv2D(filters, kernel_size=3, strides=1, padding='same')(layer_input)
            d = Activation('relu')(d)
            d = BatchNormalization(momentum=0.8)(d)
            d = Conv2D(filters, kernel_size=3, strides=1, padding='same')(d)
            d = BatchNormalization(momentum=0.8)(d)
            d = Add()([d, layer_input])
            return d
        def deconv2d(layer_input):
            """Layers used during upsampling #上采样块"""
            u = UpSampling2D(size=2)(layer_input)
            u = Conv2D(256, kernel_size=3, strides=1, padding='same')(u)
            u = Activation('relu')(u)
            return u
        # Low resolution image input
        # self.lr_shape 低分辨率图像的大小
        img_lr = Input(shape=self.lr_shape)
        # Pre-residual block
        c1 = Conv2D(64, kernel_size=9, strides=1, padding='same')(img_lr)
        c1 = Activation('relu')(c1)
        # Propogate through residual blocks
        # self.gf 生成器使用残差快的个数
        r = residual_block(c1, self.gf)
        for _ in range(self.n_residual_blocks - 1):
            r = residual_block(r, self.gf)
        # Post-residual block
        c2 = Conv2D(64, kernel_size=3, strides=1, padding='same')(r)
        c2 = BatchNormalization(momentum=0.8)(c2)
        c2 = Add()([c2, c1])
        # Upsampling
        #上采样
        u1 = deconv2d(c2)
        u2 = deconv2d(u1)
        # Generate high resolution output
        gen_hr = Conv2D(self.channels, kernel_size=9, strides=1, padding='same', activation='tanh')(u2)
        return Model(img_lr, gen_hr)

 

3.判别器

 

判别器主要用于判断生成图片的真假。与0和1比较,1代表真图片,0代表假图片。这里的0和1 是与判别器输出大小想用的向量,而不是单纯的0,1判别器网络结果如下所示:

判别网络由一个个包含卷积、BN、和LeakyRelu 激活函数的块组成,最后输出1或0, 实际上就相当于是一个二分类的分类网络,代码如下所示:

 

def build_discriminator(self):
        #这里self.df =64
        def d_block(layer_input, filters, strides=1, bn=True):
            """Discriminator layer #判别器主要包含的卷积块"""
            d = Conv2D(filters, kernel_size=3, strides=strides, padding='same')(layer_input)
            d = LeakyReLU(alpha=0.2)(d)
            if bn:
                d = BatchNormalization(momentum=0.8)(d)
            return d
        # Input img
        #self.hr_shape  高分辨率图片的大小
        d0 = Input(shape=self.hr_shape)
        d1 = d_block(d0, self.df, bn=False)
        d2 = d_block(d1, self.df, strides=2)
        d3 = d_block(d2, self.df*2)
        d4 = d_block(d3, self.df*2, strides=2)
        d5 = d_block(d4, self.df*4)
        d6 = d_block(d5, self.df*4, strides=2)
        d7 = d_block(d6, self.df*8)
        d8 = d_block(d7, self.df*8, strides=2)
        d9 = Dense(self.df*16)(d8)
        d10 = LeakyReLU(alpha=0.2)(d9)
        validity = Dense(1, activation='sigmoid')(d10)
        return Model(d0, validity)```

 

网络主要分为生成器和判别器,训练时相互对抗,以达到一个很好的平衡为目的。

 

二、其他准备

 

1.数据读取

 

在训练时,我们会将128×128的图像放大成512×512 的图像。生成网络就是为了保证放大后的图片依然清晰。

 

数据读取过程先将图片reshape 成512×512的图片作为监督数据,将缩小成128×128大小的图片作为训练的数据。读取代码如下:

 

import scipy
from glob import glob
import numpy as np
import matplotlib.pyplot as plt
import scipy.misc
import cv2
#======================
# 用于读取高分辨率数据集
#======================
#数据预处理,把原图像处理成小图和大图
class DataLoader():
    #初始化,重构后清晰图像的大小
    def __init__(self, dataset_name, img_res=(128, 128)):
        self.dataset_name = dataset_name
        self.img_res = img_res
    #从文件夹里读数据
    def load_data(self, batch_size=1, is_testing=False):
        data_type = "train" if not is_testing else "test"
        
        path = glob('./datasets/%s/train/*' % (self.dataset_name))
        #随机选图片训练,可能很多张
        batch_images = np.random.choice(path, size=batch_size)
        imgs_hr = []
        imgs_lr = []
        for img_path in batch_images:
            img = self.imread(img_path)
            #计算缩小的数据
            h, w = self.img_res
            low_h, low_w = int(h / 4), int(w / 4)
            #将图片缩小
            img_hr = cv2.resize(img, self.img_res)
            img_lr = cv2.resize(img, (low_h, low_w))
            # If training => do random flip,如果是训练模式,翻转,做数据增强
            if not is_testing and np.random.random() < 0.5:
                img_hr = np.fliplr(img_hr)
                img_lr = np.fliplr(img_lr)
            imgs_hr.append(img_hr)
            imgs_lr.append(img_lr)
        #归一化 0-255,255/127.5=2,0-2之间,-1就归一化到-1到1之间
        imgs_hr = np.array(imgs_hr) / 127.5 - 1.
        imgs_lr = np.array(imgs_lr) / 127.5 - 1.
        return imgs_hr, imgs_lr #矩阵,列表里放的矩阵
    #读图片,转化到RGB
    def imread(self, path):
         img =cv2.imread(path)
         return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

 

2.VGG19提取特征

 

使用用预训练的VGG19 前9层提取特征,原版代码会有错,错误提示如 这篇博客所示【错误】

 

将代码更改为如下所示即可:

 

def build_vgg(self):
        # 建立VGG模型,只使用第9层的特征
        vgg = VGG19(weights="imagenet",input_shape=self.hr_shape,include_top=False)
        return Model(vgg.input, outputs=vgg.layers[9].output)

 

4.训练过程完整代码

 

from re import S
import cv2
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, Concatenate
from keras.layers import BatchNormalization, Activation, ZeroPadding2D, Add
from keras.layers.advanced_activations import PReLU, LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.applications import VGG19
from keras.models import Sequential, Model
from keras.optimizers import Adam
import datetime
import matplotlib.pyplot as plt
import sys
from data_loader import DataLoader
import numpy as np
import os
class SRGAN():
    def __init__(self):
        # Input shape
        self.channels = 3
        self.lr_height = 128                 # Low resolution height
        self.lr_width = 128                  # Low resolution width
        self.lr_shape = (self.lr_height, self.lr_width, self.channels)
        self.hr_height = self.lr_height*4   # High resolution height
        self.hr_width = self.lr_width*4     # High resolution width
        self.hr_shape = (self.hr_height, self.hr_width, self.channels)
        # Number of residual blocks in the generator
        self.n_residual_blocks = 16
        optimizer = Adam(0.0002, 0.5)
        # VGG19 提取特征
        self.vgg = self.build_vgg()
        self.vgg.trainable = False
        self.vgg.compile(loss='mse',
            optimizer=optimizer,
            metrics=['accuracy'])
        #数据路径
        self.dataset_name = 'DIV'
        self.data_loader = DataLoader(dataset_name=self.dataset_name,
                                      img_res=(self.hr_height, self.hr_width))
        # 输出维度,方便构建 0,1 矩阵,判别器计算损失
        patch = int(self.hr_height / 2**4)
        self.disc_patch = (patch, patch, 1)
        # 生成器和判别器卷积核的个数
        self.gf = 64
        self.df = 64
        # 配置构建判别器 mse 损失
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='mse',
            optimizer=optimizer,
            metrics=['accuracy'])
        # 生成器
        self.generator = self.build_generator()
        # 高低分辨率形状
        img_hr = Input(shape=self.hr_shape)
        img_lr = Input(shape=self.lr_shape)
        # 生成器生成假图片
        fake_hr = self.generator(img_lr)
        # 特征提取
        fake_features = self.vgg(fake_hr)
       #一开始不训练判别器
        self.discriminator.trainable = False
      
        validity = self.discriminator(fake_hr)
        self.combined = Model([img_lr, img_hr], [validity, fake_features])
        self.combined.compile(loss=['binary_crossentropy', 'mse'],
                              loss_weights=[1e-3, 1],
                              optimizer=optimizer)
    def build_vgg(self):
        # 建立VGG模型,只使用第9层的特征
        vgg = VGG19(weights="imagenet",input_shape=self.hr_shape,include_top=False)
        return Model(vgg.input, outputs=vgg.layers[9].output)
    def build_generator(self):
        def residual_block(layer_input, filters):
            """Residual block described in paper"""
            d = Conv2D(filters, kernel_size=3, strides=1, padding='same')(layer_input)
            d = Activation('relu')(d)
            d = BatchNormalization(momentum=0.8)(d)
            d = Conv2D(filters, kernel_size=3, strides=1, padding='same')(d)
            d = BatchNormalization(momentum=0.8)(d)
            d = Add()([d, layer_input])
            return d
        def deconv2d(layer_input):
            """Layers used during upsampling"""
            u = UpSampling2D(size=2)(layer_input)
            u = Conv2D(256, kernel_size=3, strides=1, padding='same')(u)
            u = Activation('relu')(u)
            return u
        # Low resolution image input
        img_lr = Input(shape=self.lr_shape)
        # Pre-residual block
        c1 = Conv2D(64, kernel_size=9, strides=1, padding='same')(img_lr)
        c1 = Activation('relu')(c1)
        # Propogate through residual blocks
        r = residual_block(c1, self.gf)
        for _ in range(self.n_residual_blocks - 1):
            r = residual_block(r, self.gf)
        # Post-residual block
        c2 = Conv2D(64, kernel_size=3, strides=1, padding='same')(r)
        c2 = BatchNormalization(momentum=0.8)(c2)
        c2 = Add()([c2, c1])
        # Upsampling
        u1 = deconv2d(c2)
        u2 = deconv2d(u1)
        # Generate high resolution output
        gen_hr = Conv2D(self.channels, kernel_size=9, strides=1, padding='same', activation='tanh')(u2)
        return Model(img_lr, gen_hr)
    def build_discriminator(self):
        #这里self.df =64
        def d_block(layer_input, filters, strides=1, bn=True):
            """Discriminator layer"""
            d = Conv2D(filters, kernel_size=3, strides=strides, padding='same')(layer_input)
            d = LeakyReLU(alpha=0.2)(d)
            if bn:
                d = BatchNormalization(momentum=0.8)(d)
            return d
        # Input img
        d0 = Input(shape=self.hr_shape)
        d1 = d_block(d0, self.df, bn=False)
        d2 = d_block(d1, self.df, strides=2)
        d3 = d_block(d2, self.df*2)
        d4 = d_block(d3, self.df*2, strides=2)
        d5 = d_block(d4, self.df*4)
        d6 = d_block(d5, self.df*4, strides=2)
        d7 = d_block(d6, self.df*8)
        d8 = d_block(d7, self.df*8, strides=2)
        d9 = Dense(self.df*16)(d8)
        d10 = LeakyReLU(alpha=0.2)(d9)
        validity = Dense(1, activation='sigmoid')(d10)
        return Model(d0, validity)
    def train(self, epochs, batch_size=1, sample_interval=50):
        start_time = datetime.datetime.now()
        for epoch in range(epochs):
            # ----------------------
            #  训练生成器
            # ----------------------
            # 加载数据
            imgs_hr, imgs_lr = self.data_loader.load_data(batch_size)
            # 低分辨率数据生成高分辨率数据
            fake_hr = self.generator.predict(imgs_lr)
            # 0,1
            valid = np.ones((batch_size,) + self.disc_patch)
            fake = np.zeros((batch_size,) + self.disc_patch)
            # 计算loss
            d_loss_real = self.discriminator.train_on_batch(imgs_hr, valid)
            d_loss_fake = self.discriminator.train_on_batch(fake_hr, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
            # ------------------
            #  训练生成器
            # ------------------
            # Sample images and their conditioning counterparts
            imgs_hr, imgs_lr = self.data_loader.load_data(batch_size)
            # The generators want the discriminators to label the generated images as real
            valid = np.ones((batch_size,) + self.disc_patch)
            # vgg19 提取特征
            image_features = self.vgg.predict(imgs_hr)
            # 生成器
            g_loss = self.combined.train_on_batch([imgs_lr, imgs_hr], [valid, image_features])
            elapsed_time = datetime.datetime.now() - start_time
            # Plot the progress
            print ("%d time: %s" % (epoch, elapsed_time))
            # If at save interval => save generated image samples
            if epoch % sample_interval == 0:
                self.sample_images(epoch)
                self.generator.save('weights/epoch%s'%str(epoch)+'.h5')
    def sample_images(self, epoch):
        os.makedirs('images/%s' % self.dataset_name, exist_ok=True)
        r, c = 2, 2
        imgs_hr, imgs_lr = self.data_loader.load_data(batch_size=2, is_testing=True)
        fake_hr = self.generator.predict(imgs_lr)
        #-----------------------------------------------
        # 主要是因为opencv 读取的图片用plt 显示时,
        # 会颜色不对,这里主要解决这一问题
        #--------------------------------------------------
        imgs_lr = 0.5 * imgs_lr + 0.5
        fake_hr = 0.5 * fake_hr + 0.5
        imgs_hr = 0.5 * imgs_hr + 0.5
        # Save generated images and the high resolution originals
        titles = ['Generated', 'Original']
        fig, axs = plt.subplots(r, c)
        cnt = 0
        for row in range(r):
            for col, image in enumerate([fake_hr, imgs_hr]):
                axs[row, col].imshow(image[row])
                axs[row, col].set_title(titles[col])
                axs[row, col].axis('off')
            cnt += 1
        fig.savefig("images/%s/%d.png" % (self.dataset_name, epoch))
        plt.close()
        # Save low resolution images for comparison
        for i in range(r):
            fig = plt.figure()
            plt.imshow(imgs_lr[i])
            fig.savefig('images/%s/%d_lowres%d.png' % (self.dataset_name, epoch, i))
            plt.close()
if __name__ == '__main__':
    gan = SRGAN()
    gan.train(epochs=30000, batch_size=2, sample_interval=200)

 

训练时,在image目录下会出现这样的图片

 

低分辨率图:

(右边原图,左边生成的图):

5. 预测过程

 

预测过程只需要生成器即可,并且不需要限制图片大小,把生成器拿出来单独使用:

 

#生成器代码
from keras.layers import Input
from keras.layers import BatchNormalization, Activation, Add
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import  Model
def build_generator():
    def residual_block(layer_input, filters):
        d = Conv2D(filters, kernel_size=3, strides=1, padding='same')(layer_input)
        d = Activation('relu')(d)
        d = BatchNormalization(momentum=0.8)(d)
        d = Conv2D(filters, kernel_size=3, strides=1, padding='same')(d)
        d = BatchNormalization(momentum=0.8)(d)
        d = Add()([d, layer_input])
        return d
    def deconv2d(layer_input):
        u = UpSampling2D(size=2)(layer_input)
        u = Conv2D(256, kernel_size=3, strides=1, padding='same')(u)
        u = Activation('relu')(u)
        return u
    img_lr = Input(shape=[None,None,3])
    # 第一部分,低分辨率图像进入后会经过一个卷积+RELU函数
    c1 = Conv2D(64, kernel_size=9, strides=1, padding='same')(img_lr)
    c1 = Activation('relu')(c1)
    # 第二部分,经过16个残差网络结构,每个残差网络内部包含两个卷积+标准化+RELU,还有一个残差边。
    r = residual_block(c1, 64)
    for _ in range(15):
        r = residual_block(r, 64)
    # 第三部分,上采样部分,将长宽进行放大,两次上采样后,变为原来的4倍,实现提高分辨率。
    c2 = Conv2D(64, kernel_size=3, strides=1, padding='same')(r)
    c2 = BatchNormalization(momentum=0.8)(c2)
    c2 = Add()([c2, c1])
    u1 = deconv2d(c2)
    u2 = deconv2d(u1)
    gen_hr = Conv2D(3, kernel_size=9, strides=1, padding='same', activation='tanh')(u2)
    return Model(img_lr, gen_hr)

 

预测部分代码:

 

from srgan import SRGAN
from PIL import Image
import cv2
import numpy as np
import matplotlib.pyplot as plt
from generator_model import build_generator
before_image = Image.open(r"Female_person.jpg")
gen_model = build_generator()
gen_model.load_weights('weights\epoch14800.h5')
# gen_model.summary()
new_img = Image.new('RGB', before_image.size, (128, 128, 128))
new_img.paste(before_image)
# plt.imshow(new_img)
# plt.show()
new_image = np.array(new_img)/127.5 - 1
# 三维变4维  因为神经网络的输入是四维的
new_image = np.expand_dims(new_image, axis=0)  # [batch_size,w,h,c]
fake = (gen_model.predict(new_image)*0.5 + 0.5)*255
#将np array 形式的图片转换为unit8  把数据转换为图
fake = Image.fromarray(np.uint8(fake[0]))
fake.save("out.png")
titles = ['Generated', 'Original']
plt.subplot(1, 2, 1)
plt.imshow(before_image)
plt.subplot(1, 2, 2)
plt.imshow(fake)
plt.show()

 

重建效果:

感谢师姐为我在学习该网络时的耐心指点

Be First to Comment

发表回复

您的电子邮箱地址不会被公开。