Press "Enter" to skip to content

Machine Learning Notes – GAN with Keras: Application to Image Deblurring


I. GAN Overview

In a generative adversarial network, two networks train against each other. The generator misleads the discriminator by creating convincing fake inputs, while the discriminator tells whether an input is real or artificial.

 

 

The GAN training process has three main steps:

1. Use the generator to create fake inputs from noise.

2. Train the discriminator on both real and fake inputs.

3. Train the whole model: the model is built with the discriminator chained to the generator.

The discriminator's weights are frozen during the third step. The reason for chaining the two networks is that there is no direct feedback on the generator's outputs: our only measure is whether the discriminator accepts the generated samples.
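A minimal sketch of this freeze-then-train pattern with toy models (the shapes and layers here are illustrative only; the real architectures come in section III):

from keras.layers import Input, Dense
from keras.models import Model, Sequential

# Toy generator and discriminator, stand-ins for the real networks.
generator = Sequential([Dense(16, activation='relu', input_shape=(8,)),
                        Dense(8)])
discriminator = Sequential([Dense(16, activation='relu', input_shape=(8,)),
                            Dense(1, activation='sigmoid')])

# Step 2: the discriminator is compiled on its own, so it can be trained
# directly on real and fake batches.
discriminator.compile(optimizer='adam', loss='binary_crossentropy')

# Step 3: chain the two networks and freeze the discriminator before
# compiling, so training the chained model only updates the generator.
discriminator.trainable = False
z = Input(shape=(8,))
gan = Model(z, discriminator(generator(z)))
gan.compile(optimizer='adam', loss='binary_crossentropy')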

 

II. Dataset

Ian Goodfellow first applied GAN models to generate MNIST data. In this tutorial, we use a generative adversarial network for image deblurring, so the generator's input is not noise but a blurred image.

The dataset is the GOPRO dataset. You can download a light version (9GB) or the complete version (35GB). It contains artificially blurred images from multiple street views, split into subfolders by scene.

 

Light version: https://drive.google.com/file/d/1H0PIXvJH4c40pk7ou6nAwoxuR4Qh_Sa2/view?usp=sharing
Complete version: https://drive.google.com/file/d/1SlURvdQsokgsoyTosAaELc4zRjQz9T2U/view?usp=sharing

Light version, Baidu Cloud mirror:
Link: https://pan.baidu.com/s/1fS6QIoRhpngyGwn4eHOJDw
Extraction code: hky7

 

We first distribute the images into two folders, A (blurred) and B (sharp). This A&B structure corresponds to the original pix2pix article:

 

Image-to-Image Translation with Conditional Adversarial Networks https://phillipi.github.io/pix2pix/
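A rough sketch of how such A&B folders might be read (a hypothetical helper for illustration; the repo linked at the end ships its own load_images utility):

import os
import numpy as np
from PIL import Image

def load_image_pairs(path, n_images=None, size=(256, 256)):
    """Load blurred (A) and sharp (B) images, rescaled to [-1, 1]."""
    data = {}
    for sub in ('A', 'B'):
        folder = os.path.join(path, sub)
        files = sorted(os.listdir(folder))
        if n_images is not None:
            files = files[:n_images]
        images = [np.array(Image.open(os.path.join(folder, f)).resize(size))
                  for f in files]
        # Rescale from [0, 255] to [-1, 1] to match the generator's tanh range
        data[sub] = (np.array(images, dtype=np.float32) - 127.5) / 127.5
    return data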

 

III. Network Models

1. Generator

 

The generator aims to reproduce sharp images. The network is based on ResNet blocks, which keep track of the evolutions applied to the original blurred image. The publication mentioned in the previous section also uses a UNet-based version; both should perform well for image deblurring.

 

 

The core is a stack of 9 ResNet blocks applied to a downsampled version of the original image.

 

from keras.layers import Input, Conv2D, Activation, BatchNormalization
from keras.layers.merge import Add
from keras.layers.core import Dropout
# ReflectionPadding2D is a custom layer defined in the repo's layer_utils.py
from layer_utils import ReflectionPadding2D

def res_block(input, filters, kernel_size=(3,3), strides=(1,1), use_dropout=False):
    """
    Instantiate a Keras ResNet block using the functional API.
    :param input: Input tensor
    :param filters: Number of filters to use
    :param kernel_size: Shape of the kernel for the convolution
    :param strides: Shape of the strides for the convolution
    :param use_dropout: Boolean value to determine the use of dropout
    :return: Output tensor of the block
    """
    x = ReflectionPadding2D((1,1))(input)
    x = Conv2D(filters=filters,
               kernel_size=kernel_size,
               strides=strides,)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    if use_dropout:
        x = Dropout(0.5)(x)
    x = ReflectionPadding2D((1,1))(x)
    x = Conv2D(filters=filters,
               kernel_size=kernel_size,
               strides=strides,)(x)
    x = BatchNormalization()(x)
    # Two convolution layers followed by a direct connection between input and output
    merged = Add()([input, x])
    return merged

 

This ResNet block is essentially a pair of convolutional layers whose output is summed with the original input to form the final output.
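A quick sanity check (assuming res_block and the repo's layer_utils.py are importable) confirms that the block preserves both spatial shape and filter count, which is what makes the Add() merge valid:

from keras.layers import Input
from keras.models import Model

# Reflection padding adds 2 pixels per spatial dim; the 3x3 valid convolution
# removes them again, so shape and channels are unchanged end to end.
inp = Input(shape=(64, 64, 256))
out = res_block(inp, filters=256)
print(Model(inp, out).output_shape)  # -> (None, 64, 64, 256)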

 

from keras.layers import Input, Activation, Add
from keras.layers.convolutional import Conv2D, Conv2DTranspose
from keras.layers.core import Lambda
from keras.layers.normalization import BatchNormalization
from keras.models import Model
from layer_utils import ReflectionPadding2D, res_block

ngf = 64
input_nc = 3
output_nc = 3
input_shape_generator = (256, 256, input_nc)
n_blocks_gen = 9

def generator_model():
    """Build generator architecture."""
    # Current version : ResNet block
    inputs = Input(shape=input_shape_generator)
    x = ReflectionPadding2D((3, 3))(inputs)
    x = Conv2D(filters=ngf, kernel_size=(7,7), padding='valid')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    # Downsample: increase filter number, halve spatial resolution
    n_downsampling = 2
    for i in range(n_downsampling):
        mult = 2**i
        x = Conv2D(filters=ngf*mult*2, kernel_size=(3,3), strides=2, padding='same')(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
    # Apply 9 ResNet blocks
    mult = 2**n_downsampling
    for i in range(n_blocks_gen):
        x = res_block(x, ngf*mult, use_dropout=True)
    # Upsample back to the input resolution, decreasing the filter number
    for i in range(n_downsampling):
        mult = 2**(n_downsampling - i)
        x = Conv2DTranspose(filters=int(ngf * mult / 2), kernel_size=(3,3), strides=2, padding='same')(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
    # Final convolution down to 3 channels (RGB)
    x = ReflectionPadding2D((3,3))(x)
    x = Conv2D(filters=output_nc, kernel_size=(7,7), padding='valid')(x)
    x = Activation('tanh')(x)
    # Add direct connection from input to output and recenter to [-1, 1]
    outputs = Add()([x, inputs])
    outputs = Lambda(lambda z: z/2)(outputs)
    model = Model(inputs=inputs, outputs=outputs, name='Generator')
    return model

 

The 9 ResNet blocks are applied to the downsampled version of the input. We then add a direct connection from the input to the output and divide by 2: since both the tanh output and the normalized input lie in [-1, 1], their sum lies in [-2, 2], and halving brings it back to [-1, 1]. That concludes the generator.

 

2. Discriminator

 

The goal of the discriminator is to determine whether an input image has been artificially created. Its architecture is therefore convolutional, ending in a single output value.

 

from keras.layers import Input, Activation
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import Conv2D
from keras.layers.core import Dense, Flatten
from keras.layers.normalization import BatchNormalization
from keras.models import Model

ndf = 64
output_nc = 3
input_shape_discriminator = (256, 256, output_nc)

def discriminator_model():
    """Build discriminator architecture."""
    n_layers, use_sigmoid = 3, False
    inputs = Input(shape=input_shape_discriminator)
    x = Conv2D(filters=ndf, kernel_size=(4,4), strides=2, padding='same')(inputs)
    x = LeakyReLU(0.2)(x)
    nf_mult, nf_mult_prev = 1, 1
    # Stack strided convolutions, growing the filter count (capped at 8*ndf)
    for n in range(n_layers):
        nf_mult_prev, nf_mult = nf_mult, min(2**n, 8)
        x = Conv2D(filters=ndf*nf_mult, kernel_size=(4,4), strides=2, padding='same')(x)
        x = BatchNormalization()(x)
        x = LeakyReLU(0.2)(x)
    nf_mult_prev, nf_mult = nf_mult, min(2**n_layers, 8)
    x = Conv2D(filters=ndf*nf_mult, kernel_size=(4,4), strides=1, padding='same')(x)
    x = BatchNormalization()(x)
    x = LeakyReLU(0.2)(x)
    # Collapse to a single decision value
    x = Conv2D(filters=1, kernel_size=(4,4), strides=1, padding='same')(x)
    if use_sigmoid:
        x = Activation('sigmoid')(x)
    x = Flatten()(x)
    x = Dense(1024, activation='tanh')(x)
    x = Dense(1, activation='sigmoid')(x)
    model = Model(inputs=inputs, outputs=x, name='Discriminator')
    return model

 

The last step is to build the full model. A particularity of this GAN is that the inputs are real images rather than noise, so we get direct feedback on the generator's outputs.

 

from keras.layers import Input
from keras.models import Model

image_shape = (256, 256, 3)

def generator_containing_discriminator_multiple_outputs(generator, discriminator):
    """Chain the generator and discriminator, exposing both outputs."""
    inputs = Input(shape=image_shape)
    generated_images = generator(inputs)
    outputs = discriminator(generated_images)
    # Two outputs: the generated images (for the perceptual loss) and the
    # discriminator's decision (for the Wasserstein loss)
    model = Model(inputs=inputs, outputs=[generated_images, outputs])
    return model

 

IV. Training

1. Loss Functions

 

We extract losses at two levels: at the end of the generator and at the end of the full model.

 

The first is a perceptual loss computed directly on the generator's outputs. It ensures that the GAN model is oriented towards the deblurring task: it compares intermediate VGG16 feature maps (the output of block3_conv3) of the real and generated images.

 

import keras.backend as K
from keras.applications.vgg16 import VGG16
from keras.models import Model
image_shape = (256, 256, 3)
def perceptual_loss(y_true, y_pred):
    vgg = VGG16(include_top=False, weights='imagenet', input_shape=image_shape)
    loss_model = Model(inputs=vgg.input, outputs=vgg.get_layer('block3_conv3').output)
    loss_model.trainable = False
    return K.mean(K.square(loss_model(y_true) - loss_model(y_pred)))
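One practical caveat: as written, VGG16 is re-instantiated on every call to the loss. A common variant builds the feature extractor once and closes over it (a sketch of that optimization, not the repo's exact code):

import keras.backend as K
from keras.applications.vgg16 import VGG16
from keras.models import Model

def make_perceptual_loss(image_shape=(256, 256, 3)):
    # Build the VGG16 feature extractor a single time, then reuse it
    vgg = VGG16(include_top=False, weights='imagenet', input_shape=image_shape)
    loss_model = Model(inputs=vgg.input, outputs=vgg.get_layer('block3_conv3').output)
    loss_model.trainable = False
    def perceptual_loss(y_true, y_pred):
        return K.mean(K.square(loss_model(y_true) - loss_model(y_pred)))
    return perceptual_loss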

 

The second loss is the Wasserstein loss, applied to the output of the whole model. It is the mean of the element-wise product between the labels and the predictions, and is known to improve the convergence of generative adversarial networks.

 

import keras.backend as K
def wasserstein_loss(y_true, y_pred):
    return K.mean(y_true*y_pred)
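With this loss, the two classes are labeled +1 and -1 rather than 1 and 0 (the training loop below assumes label tensors of this form). Minimizing mean(y_true * y_pred) pushes the critic's scores for the two classes in opposite directions, which is the Wasserstein critic objective; the sign convention itself is arbitrary. A tiny numeric check:

import numpy as np

scores = np.array([0.3, -0.8, 1.2])   # raw critic scores for one batch
print(np.mean(1.0 * scores))          # loss on a "real" batch (labels = +1)
print(np.mean(-1.0 * scores))         # loss on a "fake" batch (labels = -1)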

 

2. Training Procedure

The first step is to load the data and initialize all the models. We use a custom function to load the dataset and create Adam optimizers for our models. We set the Keras trainable option to prevent the discriminator from training while the full model is updated.

 

from keras.optimizers import Adam

# Load dataset (load_images is the repo's custom loader for the A&B folders;
# n_images controls how many image pairs to load)
data = load_images('./images/train', n_images)
y_train, x_train = data['B'], data['A']

# Initialize models
g = generator_model()
d = discriminator_model()
d_on_g = generator_containing_discriminator_multiple_outputs(g, d)

# Initialize optimizers
g_opt = Adam(lr=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
d_opt = Adam(lr=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
d_on_g_opt = Adam(lr=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)

# Compile models. Keras captures the trainable flag at compile time:
# d is compiled trainable so d.train_on_batch updates its weights, while
# d_on_g is compiled with d frozen so generator updates leave d untouched.
d.trainable = True
d.compile(optimizer=d_opt, loss=wasserstein_loss)
d.trainable = False
loss = [perceptual_loss, wasserstein_loss]
loss_weights = [100, 1]
d_on_g.compile(optimizer=d_on_g_opt, loss=loss, loss_weights=loss_weights)
d.trainable = True

 

Then, we launch the epochs and divide the dataset into random batches.

 

import numpy as np

for epoch in range(epoch_num):
    print('epoch: {}/{}'.format(epoch, epoch_num))
    print('batches: {}'.format(x_train.shape[0] / batch_size))
    # Randomize images into batches
    permutated_indexes = np.random.permutation(x_train.shape[0])
    for index in range(int(x_train.shape[0] / batch_size)):
        batch_indexes = permutated_indexes[index*batch_size:(index+1)*batch_size]
        image_blur_batch = x_train[batch_indexes]
        image_full_batch = y_train[batch_indexes]

 

Finally, we successively train the discriminator and the generator based on the two losses. We generate fake inputs with the generator, train the discriminator to distinguish fake inputs from real ones, and then train the whole model.

 

# In the repo, the labels follow the Wasserstein +1/-1 convention and the
# critic is updated several times per generator step, along the lines of:
# output_true_batch, output_false_batch = np.ones((batch_size, 1)), -np.ones((batch_size, 1))
# critic_updates = 5
for epoch in range(epoch_num):
    for index in range(batches):
        # [Batch preparation, as in the previous snippet]
        # Generate fake inputs
        generated_images = g.predict(x=image_blur_batch, batch_size=batch_size)

        # Train the discriminator multiple times on real and fake inputs
        for _ in range(critic_updates):
            d_loss_real = d.train_on_batch(image_full_batch, output_true_batch)
            d_loss_fake = d.train_on_batch(generated_images, output_false_batch)
            d_loss = 0.5 * np.add(d_loss_fake, d_loss_real)

        d.trainable = False
        # Train the generator only on the discriminator's decision and the generated images
        d_on_g_loss = d_on_g.train_on_batch(image_blur_batch, [image_full_batch, output_true_batch])
        d.trainable = True
 

The full code is available at the link below:

 

https://github.com/raphaelmeudec/deblur-gan
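Once trained, deblurring a test image (such as those in the results below) is a single forward pass through the generator. A sketch, assuming a saved weights file; the rescaling mirrors the [-1, 1] normalization used during training:

import numpy as np
from PIL import Image

# Rebuild the generator and load trained weights (hypothetical file name)
g = generator_model()
g.load_weights('generator.h5')

img = np.array(Image.open('blurred.png').resize((256, 256)), dtype=np.float32)
x = (img - 127.5) / 127.5                 # rescale to [-1, 1]
y = g.predict(x[np.newaxis, ...])[0]      # single forward pass
deblurred = ((y + 1) * 127.5).astype(np.uint8)
Image.fromarray(deblurred).save('deblurred.png')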

 

3. Training Results

From left to right: original image, blurred image, GAN output.

Left: GOPRO test image; right: GAN output.
