Press "Enter" to skip to content

原力计划 深度卷积生成对抗网络DCGAN——生成手写数字图片 前言本文使用深度卷积生成对抗网络(DCGA…

本站内容均来自兴趣收集,如不慎侵害的您的相关权益,请留言告知,我们将尽快删除.谢谢.

前言

 

本文使用深度卷积生成对抗网络(DCGAN)生成手写数字图片,代码使用Keras API与tf.GradientTape 编写的,其中tf.GradientTrape是训练模型时用到的。

 

本文用到imageio 库来生成gif图片,如果没有安装的,需要安装下:

 

# 用于生成 GIF 图片
pip install -q imageio

 

 

一、什幺是生成对抗网络?

 

四、定义损失函数和优化器

 

4.1 生成器的损失和优化器

 

4.2 判别器的损失和优化器

 

一、什幺是生成对抗网络?

 

生成对抗网络(GAN),包含生成器和判别器,两个模型通过对抗过程同时训练。

 

生成器 ,可以理解为“艺术家、创造者”,它学习创造看起来真实的图像。

 

判别器 ,可以理解为“艺术评论家、审核者”,它学习区分真假图像。

 

训练过程中,生成器在生成逼真图像方便逐渐变强,而判别器在辨别这些图像的能力上逐渐变强。

 

当判别器不能再区分真实图片和伪造图片时,训练过程达到平衡。

 

本文,在MNIST数据集上演示了该过程。随着训练的进行,生成器所生成的一系列图片,越来越像真实的手写数字。

 

二、加载数据集

 

使用MNIST数据,来训练生成器和判别器。生成器将生成类似于MNIST数据集的手写数字。

 

(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
train_images = (train_images - 127.5) / 127.5 # 将图片标准化到 [-1, 1] 区间内
BUFFER_SIZE = 60000
BATCH_SIZE = 256
# 批量化和打乱数据
train_ = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

 

三、创建模型

 

主要创建两个模型,一个是 生成器 ,另一个是 判别器 。

 

3.1 生成器

 

生成器使用 tf.keras.layers.Conv2DTranspose 层,来从随机噪声中产生图片。

 

然后把从随机噪声中产生图片,作为输入数据,输入到Dense层,开始。

 

后面,经过多次上采样,达到所预期 28x28x1 的图片尺寸。

 

def make_generator_():
    model = tf.keras.Sequential()
    model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    model.add(layers.Reshape((7, 7, 256)))
    assert model.output_shape == (None, 7, 7, 256) # 注意:batch size 没有限制
    model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
    assert model.output_shape == (None, 7, 7, 128)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    assert model.output_shape == (None, 14, 14, 64)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
    assert model.output_shape == (None, 28, 28, 1)
    return model

 

用tf.keras.utils.plot_model( ),看一下模型结构

 

 

用summary(),看一下模型结构和参数

 

 

使用尚未训练的生成器,创建一张图片,这时的图片是随机噪声中产生。

 

generator = make_generator_model()
noise = tf.random.normal([1, 100])
generated_image = generator(noise, training=False)
plt.imshow(generated_image[0, :, :, 0], cmap='gray')

 

 

3.1 判别器

 

判别器是基于 CNN卷积神经网络 的图片分类器。

 

def make_discriminator_model():
    model = tf.keras.Sequential()
    model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
                                     input_shape=[28, 28, 1]))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))
    model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))
    model.add(layers.Flatten())
    model.add(layers.Dense(1))
    return model

 

用tf.keras.utils.plot_model( ),看一下模型结构

 

 

用summary(),看一下模型结构和参数

 

 

四、定义损失函数和优化器

 

由于有两个模型,一个是 生成器 ,另一个是 判别器; 所以要分别为两个模型定义损失函数和优化器。

 

首先定义一个辅助函数,用于计算交叉熵损失的,这个两个模型通用。

 

# 该方法返回计算交叉熵损失的辅助函数
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

 

4.1 生成器的损失和优化器

 

1)生成器损失

 

生成器损失 , 是量化其欺骗判别器的能力;如果生成器表现良好,判别器将会把伪造图片判断为真实图片(或1)。

 

这里我们将把判别器在生成图片上的判断结果,与一个值全为1的数组进行对比。

 

def generator_loss(fake_output):
    return cross_entropy(tf.ones_like(fake_output), fake_output)

 

2)生成器优化器

 

generator_optimizer = tf.keras.optimizers.Adam(1e-4)

 

4.2 判别器的损失和优化器

 

1)判别器损失

 

判别器损失,是量化判断真伪图片的能力。它将判别器对真实图片的预测值,与全值为1的数组进行对比;将判别器对伪造(生成的)图片的预测值,与全值为0的数组进行对比。

 

def discriminator_loss(real_output, fake_output):
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss

 

2)判别器优化器

 

discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

 

五、训练模型

 

5.1 保存检查点

 

保存检查点,能帮助保存和恢复模型,在长时间训练任务被中断的情况下比较有帮助。

 

checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)

 

5.2 定义训练过程

 

EPOCHS = 50
noise_dim = 100
num_examples_to_generate = 16

# 我们将重复使用该种子(因此在动画 GIF 中更容易可视化进度)
seed = tf.random.normal([num_examples_to_generate, noise_dim])

 

训练过程中,在生成器接收到一个“随机噪声中产生的图片”作为输入开始。

 

判别器随后被用于区分真实图片(训练集的)和伪造图片(生成器生成的)。

 

两个模型都计算损失函数,并且分别计算梯度用于更新生成器与判别器。

 

# 注意 `tf.function` 的使用
# 该注解使函数被“编译”
@tf.function
def train_step(images):
    noise = tf.random.normal([BATCH_SIZE, noise_dim])
    with tf.GradientTape() as _tape, tf.GradientTape() as disc_tape:
      generated_images = generator(noise, training=True)
      real_output = discriminator(images, training=True)
      fake_output = discriminator(generated_images, training=True)
      gen_loss = generator_loss(fake_output)
      disc_loss = discriminator_loss(real_output, fake_output)
    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
def train(dataset, epochs):
  for epoch in range(epochs):
    start = time.time()
    for image_batch in dataset:
      train_step(image_batch)
    # 继续进行时为 GIF 生成图像
    display.clear_output(wait=True)
    generate_and_save_images(generator,
                             epoch + 1,
                             seed)
    # 每 15 个 epoch 保存一次模型
    if (epoch + 1) % 15 == 0:
      checkpoint.save(file_prefix = checkpoint_prefix)
    print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))
  # 最后一个 epoch 结束后生成图片
  display.clear_output(wait=True)
  generate_and_save_images(generator,
                           epochs,
                           seed)
# 生成与保存图片
def generate_and_save_images(model, epoch, test_input):
  # 注意 training` 设定为 False
  # 因此,所有层都在推理模式下运行(batchnorm)。
  predictions = model(test_input, training=False)
  fig = plt.figure(figsize=(4,4))
  for i in range(predictions.shape[0]):
      plt.subplot(4, 4, i+1)
      plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
      plt.axis('off')
  plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
  plt.show()

 

5.3 训练模型

 

调用上面定义的train()函数,来同时训练生成器和判别器。

 

注意,训练GAN可能比较难的;生成器和判别器不能互相压制对方,需要两种达到平衡,它们用相似的学习率训练。

 

%%time
train(train_dataset, EPOCHS)

 

在刚开始训练时,生成的图片看起来很像随机噪声,随着训练过程的进行,生成的数字越来越真实。训练大约50轮后,生成器生成的图片看起来很像MNIST数字了。

 

训练了15轮的效果:

 

 

训练了30轮的效果:

 

 

训练过程:

 

 

恢复最新的检查点

 

checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

 

六、评估模型

 

这里通过直接查看生成的图片,来看模型的效果。使用训练过程中生成的图片,通过imageio生成动态gif。

 

# 使用 epoch 数生成单张图片
def display_image(epoch_no):
  return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))
display_image(EPOCHS)

 

anim_file = 'dcgan.gif'
with imageio.get_writer(anim_file, mode='I') as writer:
  filenames = glob.glob('image*.png')
  filenames = sorted(filenames)
  last = -1
  for i,filename in enumerate(filenames):
    frame = 2*(i**0.5)
    if round(frame) > round(last):
      last = frame
    else:
      continue
    image = imageio.imread(filename)
    writer.append_data(image)
  image = imageio.imread(filename)
  writer.append_data(image)
import IPython
if IPython.version_info > (6,2,0,''):
  display.Image(filename=anim_file)

 

 

完整代码:

 

import  as tf
import glob
import imageio
import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
from tensorflow.keras import layers
import time
from IPython import display
(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
 
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
train_images = (train_images - 127.5) / 127.5 # 将图片标准化到 [-1, 1] 区间内
 
BUFFER_SIZE = 60000
BATCH_SIZE = 256
 
# 批量化和打乱数据
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
# 创建模型--生成器
def make_generator_model():
    model = tf.keras.Sequential()
    model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
 
    model.add(layers.Reshape((7, 7, 256)))
    assert model.output_shape == (None, 7, 7, 256) # 注意:batch size 没有限制
 
    model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
    assert model.output_shape == (None, 7, 7, 128)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
 
    model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    assert model.output_shape == (None, 14, 14, 64)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
 
    model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
    assert model.output_shape == (None, 28, 28, 1)
 
    return model
# 使用尚未训练的生成器,创建一张图片,这时的图片是随机噪声中产生。
generator = make_generator_model()
 
noise = tf.random.normal([1, 100])
generated_image = generator(noise, training=False)
 
plt.imshow(generated_image[0, :, :, 0], cmap='gray')
tf.keras.utils.plot_model(generator)
# 判别器
def make_discriminator_model():
    model = tf.keras.Sequential()
    model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
                                     input_shape=[28, 28, 1]))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))
 
    model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))
 
    model.add(layers.Flatten())
    model.add(layers.Dense(1))
 
    return model
# 使用(尚未训练的)判别器来对图片的真伪进行判断。模型将被训练为为真实图片输出正值,为伪造图片输出负值。
discriminator = make_discriminator_model()
decision = discriminator(generated_image)
print (decision)
# 首先定义一个辅助函数,用于计算交叉熵损失的,这个两个模型通用。
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
# 生成器的损失和优化器
def generator_loss(fake_output):
    return cross_entropy(tf.ones_like(fake_output), fake_output)
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
# 判别器的损失和优化器
def discriminator_loss(real_output, fake_output):
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)
# 保存检查点
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)
# 定义训练过程
EPOCHS = 50
noise_dim = 100
num_examples_to_generate = 16
 
# 我们将重复使用该种子(因此在动画 GIF 中更容易可视化进度)
seed = tf.random.normal([num_examples_to_generate, noise_dim])
# 注意 `tf.function` 的使用
# 该注解使函数被“编译”
@tf.function
def train_step(images):
    noise = tf.random.normal([BATCH_SIZE, noise_dim])
 
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
      generated_images = generator(noise, training=True)
 
      real_output = discriminator(images, training=True)
      fake_output = discriminator(generated_images, training=True)
 
      gen_loss = generator_loss(fake_output)
      disc_loss = discriminator_loss(real_output, fake_output)
 
    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
 
    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
 
def train(dataset, epochs):
  for epoch in range(epochs):
    start = time.time()
 
    for image_batch in dataset:
      train_step(image_batch)
 
    # 继续进行时为 GIF 生成图像
    display.clear_output(wait=True)
    generate_and_save_images(generator,
                             epoch + 1,
                             seed)
 
    # 每 15 个 epoch 保存一次模型
    if (epoch + 1) % 15 == 0:
      checkpoint.save(file_prefix = checkpoint_prefix)
 
    print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))
 
  # 最后一个 epoch 结束后生成图片
  display.clear_output(wait=True)
  generate_and_save_images(generator,
                           epochs,
                           seed)
 
# 生成与保存图片
def generate_and_save_images(model, epoch, test_input):
  # 注意 training` 设定为 False
  # 因此,所有层都在推理模式下运行(batchnorm)。
  predictions = model(test_input, training=False)
 
  fig = plt.figure(figsize=(4,4))
 
  for i in range(predictions.shape[0]):
      plt.subplot(4, 4, i+1)
      plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
      plt.axis('off')
 
  plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
  plt.show()
# 训练模型
train(train_dataset, EPOCHS)
# 恢复最新的检查点
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
# 评估模型
# 使用 epoch 数生成单张图片
def display_image(epoch_no):
  return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))
 
display_image(EPOCHS)
anim_file = 'dcgan.gif'
 
with imageio.get_writer(anim_file, mode='I') as writer:
  filenames = glob.glob('image*.png')
  filenames = sorted(filenames)
  last = -1
  for i,filename in enumerate(filenames):
    frame = 2*(i**0.5)
    if round(frame) > round(last):
      last = frame
    else:
      continue
    image = imageio.imread(filename)
    writer.append_data(image)
  image = imageio.imread(filename)
  writer.append_data(image)
 
import IPython
if IPython.version_info > (6,2,0,''):
  display.Image(filename=anim_file)

 

参考: https://www.tensorflow.org/tutorials/generative/dcgan

 

一篇文章“简单”认识《生成对抗网络》(GAN)

Be First to Comment

发表评论

您的电子邮箱地址不会被公开。 必填项已用*标注