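I. GAN Overview
In a generative adversarial network, two networks are trained against each other. The generator tries to mislead the discriminator by creating fake inputs. The discriminator tells whether an input is real or artificial.
GAN training has 3 main steps:
1. Use the generator to create fake inputs from noise
2. Train the discriminator with both real and fake inputs
3. Train the whole model: the model is built with the discriminator chained to the generator.
The discriminator's weights are frozen during the third step. The two networks are chained because there is no direct feedback available on the generator's outputs. Our only measure is whether the discriminator accepted the generated samples.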
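The order of the three steps can be sketched in plain Python. This is only a schematic: `gan_step` and the three `toy_*` functions are hypothetical stand-ins for real networks, and the loss values are placeholders.

```python
# Schematic of one GAN training iteration. The callables passed in are
# hypothetical stand-ins for real networks; only the ordering matters here.
def gan_step(generate, train_discriminator, train_combined, inputs, real_samples):
    # Step 1: the generator creates fake samples from its inputs.
    fake_samples = generate(inputs)
    # Step 2: the discriminator is trained on real and fake samples.
    d_loss = train_discriminator(real_samples, fake_samples)
    # Step 3: the chained model is trained. With the discriminator frozen,
    # only the generator's weights would be updated in a real setup.
    g_loss = train_combined(inputs)
    return d_loss, g_loss


# Toy stand-ins that only record the order in which they are called.
call_log = []


def toy_generate(inputs):
    call_log.append("generate")
    return [0.5 * v for v in inputs]


def toy_train_discriminator(real, fake):
    call_log.append("train discriminator")
    return 0.7  # placeholder loss value


def toy_train_combined(inputs):
    call_log.append("train combined model")
    return 0.3  # placeholder loss value


d_loss, g_loss = gan_step(toy_generate, toy_train_discriminator, toy_train_combined,
                          inputs=[0.1, -0.4], real_samples=[1.0, 0.2])
print(call_log)
```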
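II. Dataset
Ian Goodfellow first applied GAN models to generate MNIST data. In this tutorial, we use generative adversarial networks for image deblurring. The generator's input is therefore not noise but blurred images.
The dataset is the GOPRO dataset. You can download a light version (9GB) or the complete version (35GB). It contains artificially blurred images from multiple street views. The dataset is split into subfolders by scene.
Light version: https://drive.google.com/file/d/1H0PIXvJH4c40pk7ou6nAwoxuR4Qh_Sa2/view?usp=sharing
Complete version: https://drive.google.com/file/d/1SlURvdQsokgsoyTosAaELc4zRjQz9T2U/view?usp=sharing
Light version on Baidu Cloud: https://pan.baidu.com/s/1fS6QIoRhpngyGwn4eHOJDw (extraction code: hky7)
We first distribute the images into two folders, A (blurred) and B (sharp). This A&B layout corresponds to the original pix2pix article.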
Image-to-Image Translation with Conditional Adversarial Networks https://phillipi.github.io/pix2pix/
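A minimal sketch of the A/B reorganization, using only the standard library. It assumes each scene folder holds a `blur` and a `sharp` subfolder with matching file names. That layout, the function name `organize_gopro`, and the scene-name prefix on copied files are assumptions to check against your download.

```python
import shutil
from pathlib import Path


def organize_gopro(src_dir, dst_dir, blur_name="blur", sharp_name="sharp"):
    """Copy paired images into dst_dir/A (blurred) and dst_dir/B (sharp).

    Assumes every scene folder under src_dir contains `blur_name` and
    `sharp_name` subfolders with identically named files.
    Returns the number of pairs copied.
    """
    src, dst = Path(src_dir), Path(dst_dir)
    out_a, out_b = dst / "A", dst / "B"
    out_a.mkdir(parents=True, exist_ok=True)
    out_b.mkdir(parents=True, exist_ok=True)

    copied = 0
    for scene in sorted(p for p in src.iterdir() if p.is_dir()):
        blur_dir, sharp_dir = scene / blur_name, scene / sharp_name
        if not (blur_dir.is_dir() and sharp_dir.is_dir()):
            continue  # skip folders that do not follow the assumed layout
        for blur_file in sorted(blur_dir.iterdir()):
            sharp_file = sharp_dir / blur_file.name
            if not (blur_file.is_file() and sharp_file.is_file()):
                continue  # keep only complete blurred/sharp pairs
            # Prefix with the scene name so files from different scenes do not collide.
            new_name = f"{scene.name}_{blur_file.name}"
            shutil.copy(blur_file, out_a / new_name)
            shutil.copy(sharp_file, out_b / new_name)
            copied += 1
    return copied
```

Calling `organize_gopro("path/to/scenes", "images/train")` would then fill `images/train/A` and `images/train/B`; both paths are placeholders.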
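III. Network Model
1. Generator
The generator aims at reproducing sharp images. The network is based on ResNet blocks. It keeps track of the changes applied to the original blurred image. The publication mentioned in the previous section also used a UNet-based version. Both architectures should perform well for image deblurring.
The core is 9 ResNet blocks applied to a downsampled version of the original image.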
from keras.layers import Activation, Add, BatchNormalization, Conv2D, Dropout

# ReflectionPadding2D is a custom layer (not part of Keras). It lives in the
# same layer_utils module as res_block.


def res_block(input, filters, kernel_size=(3,3), strides=(1,1), use_dropout=False):
    """
    Instantiate a Keras ResNet block using the functional API.
    :param input: Input tensor
    :param filters: Number of filters to use
    :param kernel_size: Shape of the kernel for the convolution
    :param strides: Shape of the strides for the convolution
    :param use_dropout: Boolean value to determine the use of dropout
    :return: Output tensor
    """
    x = ReflectionPadding2D((1,1))(input)
    x = Conv2D(filters=filters, kernel_size=kernel_size, strides=strides,)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    if use_dropout:
        x = Dropout(0.5)(x)

    x = ReflectionPadding2D((1,1))(x)
    x = Conv2D(filters=filters, kernel_size=kernel_size, strides=strides,)(x)
    x = BatchNormalization()(x)

    # Two convolution layers followed by a direct connection between input and output
    merged = Add()([input, x])
    return merged
This ResNet block is basically two convolution layers whose output is added to the block's input to form the final output.
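For the addition to work, the block's output must have the same spatial size as its input, which is what the padding before each 3x3 convolution provides. The pure-Python sketch below illustrates reflection padding on a small grid, assuming the custom `ReflectionPadding2D` layer mirrors pixels without repeating the border. The helper names are hypothetical.

```python
def reflect_index(i, n):
    """Map an out-of-range index to its mirror inside [0, n-1] (border not repeated)."""
    if i < 0:
        return -i
    if i >= n:
        return 2 * (n - 1) - i
    return i


def reflection_pad_2d(image, pad):
    """Pad a 2D list by `pad` pixels on every side (requires pad < height and width)."""
    h, w = len(image), len(image[0])
    return [
        [image[reflect_index(r, h)][reflect_index(c, w)] for c in range(-pad, w + pad)]
        for r in range(-pad, h + pad)
    ]


def valid_conv_size(size, kernel, stride=1):
    """Output size of a convolution without padding ('valid')."""
    return (size - kernel) // stride + 1


grid = [[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]]
for row in reflection_pad_2d(grid, 1):
    print(row)

# Padding by 1 then applying a 3x3 'valid' convolution keeps a 256-pixel side unchanged.
print(valid_conv_size(256 + 2 * 1, kernel=3))
```

The same arithmetic explains the generator's 7x7 convolutions: padding by 3 and convolving with a 7x7 kernel also preserves the size.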
from keras.layers import Activation, Add, BatchNormalization, Conv2D, Conv2DTranspose, Input, Lambda
from keras.models import Model

from layer_utils import ReflectionPadding2D, res_block

ngf = 64
input_nc = 3
output_nc = 3
input_shape_generator = (256, 256, input_nc)
n_blocks_gen = 9


def generator_model():
    """Build generator architecture."""
    # Current version : ResNet block
    inputs = Input(shape=input_shape_generator)

    x = ReflectionPadding2D((3, 3))(inputs)
    x = Conv2D(filters=ngf, kernel_size=(7,7), padding='valid')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    # Downsample twice while increasing the filter number
    n_downsampling = 2
    for i in range(n_downsampling):
        mult = 2**i
        x = Conv2D(filters=ngf*mult*2, kernel_size=(3,3), strides=2, padding='same')(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)

    # Apply 9 ResNet blocks
    mult = 2**n_downsampling
    for i in range(n_blocks_gen):
        x = res_block(x, ngf*mult, use_dropout=True)

    # Upsample back to the input resolution while decreasing the filter number
    for i in range(n_downsampling):
        mult = 2**(n_downsampling - i)
        x = Conv2DTranspose(filters=int(ngf * mult / 2), kernel_size=(3,3), strides=2, padding='same')(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)

    # Reduce to 3 channels (RGB)
    x = ReflectionPadding2D((3,3))(x)
    x = Conv2D(filters=output_nc, kernel_size=(7,7), padding='valid')(x)
    x = Activation('tanh')(x)

    # Add direct connection from input to output and recenter to [-1, 1]
    outputs = Add()([x, inputs])
    outputs = Lambda(lambda z: z/2)(outputs)

    model = Model(inputs=inputs, outputs=outputs, name='Generator')
    return model
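The 9 ResNet blocks are applied to a downsampled version of the input. We add a connection from the input to the output and divide by 2 to keep the output normalized. That's it for the generator.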
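To make the architecture easier to follow, the sketch below traces the spatial size and channel count through each stage, using the same constants as the generator code. `trace_generator` and `merge_with_input` are hypothetical helpers for the arithmetic only, and the range check assumes input pixels are scaled to [-1, 1].

```python
import math


def trace_generator(size=256, ngf=64, n_downsampling=2, n_blocks=9, channels_io=3):
    """Return (stage, spatial size, channels) for each stage of the generator."""
    stages = [("input", size, channels_io), ("conv 7x7", size, ngf)]
    for i in range(n_downsampling):
        size = math.ceil(size / 2)  # stride-2 convolution with 'same' padding
        stages.append((f"downsample {i + 1}", size, ngf * 2**i * 2))
    stages.append((f"{n_blocks} ResNet blocks", size, ngf * 2**n_downsampling))
    for i in range(n_downsampling):
        size *= 2  # stride-2 transposed convolution with 'same' padding
        mult = 2 ** (n_downsampling - i)
        stages.append((f"upsample {i + 1}", size, ngf * mult // 2))
    stages.append(("conv 7x7 + tanh", size, channels_io))
    return stages


def merge_with_input(generated, original):
    """Average the tanh output with the input pixel, as the final Add and Lambda layers do."""
    return (generated + original) / 2


for stage in trace_generator():
    print(stage)

# Both terms lie in [-1, 1], so their sum lies in [-2, 2] and the average in [-1, 1].
print(merge_with_input(1.0, 1.0), merge_with_input(-1.0, -1.0))
```

The ResNet blocks therefore operate on 64x64 feature maps with 256 channels, a quarter of the input resolution on each side.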
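2. Discriminator
The objective is to determine whether an input image was artificially created. The discriminator's architecture is therefore convolutional and outputs a single value.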
from keras.layers import Activation, BatchNormalization, Conv2D, Dense, Flatten, Input, LeakyReLU
from keras.models import Model

ndf = 64
output_nc = 3
input_shape_discriminator = (256, 256, output_nc)


def discriminator_model():
    """Build discriminator architecture."""
    n_layers, use_sigmoid = 3, False
    inputs = Input(shape=input_shape_discriminator)

    x = Conv2D(filters=ndf, kernel_size=(4,4), strides=2, padding='same')(inputs)
    x = LeakyReLU(0.2)(x)

    # Stride-2 convolutions; the filter multiplier is capped at 8
    for n in range(n_layers):
        nf_mult = min(2**n, 8)
        x = Conv2D(filters=ndf*nf_mult, kernel_size=(4,4), strides=2, padding='same')(x)
        x = BatchNormalization()(x)
        x = LeakyReLU(0.2)(x)

    nf_mult = min(2**n_layers, 8)
    x = Conv2D(filters=ndf*nf_mult, kernel_size=(4,4), strides=1, padding='same')(x)
    x = BatchNormalization()(x)
    x = LeakyReLU(0.2)(x)

    x = Conv2D(filters=1, kernel_size=(4,4), strides=1, padding='same')(x)
    if use_sigmoid:
        x = Activation('sigmoid')(x)

    # Collapse the feature map into a single score
    x = Flatten()(x)
    x = Dense(1024, activation='tanh')(x)
    x = Dense(1, activation='sigmoid')(x)

    model = Model(inputs=inputs, outputs=x, name='Discriminator')
    return model
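As a quick check of the discriminator above, this sketch traces the feature-map size and filter count after each convolution, using the same constants. `trace_discriminator` is a hypothetical helper that only reproduces the arithmetic.

```python
import math


def trace_discriminator(size=256, ndf=64, n_layers=3):
    """Return the (stage, spatial size, channels) list and the flattened length."""
    stages = []
    size = math.ceil(size / 2)  # first stride-2 convolution, 'same' padding
    stages.append(("conv 4x4, stride 2", size, ndf))
    for n in range(n_layers):
        size = math.ceil(size / 2)
        stages.append((f"conv 4x4, stride 2 (n={n})", size, ndf * min(2**n, 8)))
    # Stride-1 convolutions keep the spatial size
    stages.append(("conv 4x4, stride 1", size, ndf * min(2**n_layers, 8)))
    stages.append(("conv 4x4, stride 1", size, 1))
    flattened = size * size * 1
    return stages, flattened


stages, flattened = trace_discriminator()
for stage in stages:
    print(stage)
print("flattened length:", flattened)
```

With a 256x256 input, the last convolution yields a 16x16 single-channel map, so `Flatten` feeds 256 values into the `Dense(1024)` layer before the final single-value output.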
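The last step is to build the full model. A particularity of this GAN is that the inputs are real (blurred) images rather than noise. As a result, we have direct feedback on the generator's outputs: they can be compared with the corresponding sharp images.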
from keras.layers import Input
from keras.models import Model

image_shape = (256, 256, 3)


def generator_containing_discriminator_multiple_outputs(generator, discriminator):
    inputs = Input(shape=image_shape)
    generated_images = generator(inputs)
    outputs = discriminator(generated_images)
    # Two outputs: the deblurred image and the discriminator's score
    model = Model(inputs=inputs, outputs=[generated_images, outputs])
    return model
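IV. Training
1. Loss functions
We extract losses at two levels: at the end of the generator and at the end of the full model.
The first is a perceptual loss computed directly on the generator's outputs. This loss ensures that the GAN model is oriented towards the deblurring task. It compares VGG16 feature maps of the two images, taken at the block3_conv3 layer.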
import keras.backend as K
from keras.applications.vgg16 import VGG16
from keras.models import Model

image_shape = (256, 256, 3)


def perceptual_loss(y_true, y_pred):
    # Compare the images in VGG16 feature space rather than in pixel space
    vgg = VGG16(include_top=False, weights='imagenet', input_shape=image_shape)
    loss_model = Model(inputs=vgg.input, outputs=vgg.get_layer('block3_conv3').output)
    loss_model.trainable = False
    # Mean squared error between the block3_conv3 feature maps
    return K.mean(K.square(loss_model(y_true) - loss_model(y_pred)))
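The second loss is the Wasserstein loss, computed on the output of the full model. It takes the mean of the product of the target labels and the discriminator's scores. It is known to improve the convergence of generative adversarial networks.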
import keras.backend as K


def wasserstein_loss(y_true, y_pred):
    return K.mean(y_true*y_pred)
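A small worked example of this formula in plain Python may help. It assumes the two kinds of samples are labelled +1 and -1, which is a common convention for this loss but is not shown in the snippets here. The scores are made-up numbers.

```python
def wasserstein_loss(y_true, y_pred):
    """Plain-Python equivalent of K.mean(y_true * y_pred) for flat lists."""
    return sum(t * p for t, p in zip(y_true, y_pred)) / len(y_true)


# Assumed label convention: +1 for one class of samples, -1 for the other.
scores_a = [0.25, 0.75]  # made-up discriminator scores for a batch labelled +1
scores_b = [0.5, 1.0]    # made-up discriminator scores for a batch labelled -1

loss_a = wasserstein_loss([1, 1], scores_a)    # mean of the scores
loss_b = wasserstein_loss([-1, -1], scores_b)  # minus the mean of the scores
print(loss_a, loss_b)
```

Because the label only flips the sign, minimizing the loss pushes the scores of the two batches in opposite directions.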
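2. Training routine
The first step is to load the data and initialize all the models. We use a custom function to load the dataset, and we add Adam optimizers for our models. We set the Keras trainable option so that the discriminator is frozen inside the combined model.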
from keras.optimizers import Adam

# Load dataset (load_images is a custom helper; n_images is the number of images to load)
data = load_images('./images/train', n_images)
y_train, x_train = data['B'], data['A']

# Initialize models
g = generator_model()
d = discriminator_model()
d_on_g = generator_containing_discriminator_multiple_outputs(g, d)

# Initialize optimizers (the argument is named `lr` in older Keras releases)
g_opt = Adam(learning_rate=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
d_opt = Adam(learning_rate=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
d_on_g_opt = Adam(learning_rate=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)

# Compile models
d.trainable = True
d.compile(optimizer=d_opt, loss=wasserstein_loss)
# Freeze the discriminator before compiling the combined model
d.trainable = False
loss = [perceptual_loss, wasserstein_loss]
loss_weights = [100, 1]
d_on_g.compile(optimizer=d_on_g_opt, loss=loss, loss_weights=loss_weights)
d.trainable = True
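With `loss_weights = [100, 1]`, Keras minimizes the weighted sum of the two losses for the combined model. The arithmetic can be sketched as follows. `combined_loss` is a hypothetical helper and the loss values are made up.

```python
def combined_loss(losses, weights):
    """Weighted sum of individual losses, as used for multi-output models."""
    return sum(w * l for w, l in zip(weights, losses))


perceptual = 0.125   # made-up perceptual loss value
wasserstein = -0.5   # made-up Wasserstein loss value
total = combined_loss([perceptual, wasserstein], weights=[100, 1])
print(total)
```

The weight of 100 makes the perceptual term dominate, which keeps the generator focused on reproducing the sharp image.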
Then we start launching the epochs and divide the dataset into batches.
import numpy as np

for epoch in range(epoch_num):
    print('epoch: {}/{}'.format(epoch, epoch_num))
    print('batches: {}'.format(x_train.shape[0] / batch_size))

    # Randomize images into batches
    permutated_indexes = np.random.permutation(x_train.shape[0])

    for index in range(int(x_train.shape[0] / batch_size)):
        batch_indexes = permutated_indexes[index*batch_size:(index+1)*batch_size]
        image_blur_batch = x_train[batch_indexes]
        image_full_batch = y_train[batch_indexes]
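The batching logic can be tried without any dataset. The sketch below uses the standard library's `random` module in place of `np.random.permutation`. `make_batches` is a hypothetical helper and the sizes are arbitrary.

```python
import random


def make_batches(n_samples, batch_size, seed=None):
    """Shuffle sample indexes and cut them into full batches."""
    rng = random.Random(seed)
    indexes = list(range(n_samples))
    rng.shuffle(indexes)
    n_batches = n_samples // batch_size  # same as int(n_samples / batch_size)
    return [indexes[i * batch_size:(i + 1) * batch_size] for i in range(n_batches)]


batches = make_batches(n_samples=10, batch_size=4, seed=0)
for batch in batches:
    print(batch)
```

Note that only full batches are produced: with 10 samples and a batch size of 4, two samples are left out of each epoch, and a different shuffle leaves out different ones.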
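Finally, we successively train the discriminator and the generator, based on both losses. We generate fake inputs with the generator. We train the discriminator to distinguish fake inputs from real ones, and then we train the whole model.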
for epoch in range(epoch_num):
    for index in range(batches):
        # [Batch Preparation]

        # Generate fake inputs
        generated_images = g.predict(x=image_blur_batch, batch_size=batch_size)

        # Train multiple times discriminator on real and fake inputs
        for _ in range(critic_updates):
            d_loss_real = d.train_on_batch(image_full_batch, output_true_batch)
            d_loss_fake = d.train_on_batch(generated_images, output_false_batch)
            d_loss = 0.5 * np.add(d_loss_fake, d_loss_real)

        d.trainable = False
        # Train generator only on discriminator's decision and generated images
        d_on_g_loss = d_on_g.train_on_batch(image_blur_batch, [image_full_batch, output_true_batch])

        d.trainable = True
For the complete code, see the link below:
https://github.com/raphaelmeudec/deblur-gan
3. Training results
From left to right: original image, blurred image, GAN output
Left: GOPRO test image; right: GAN output