Press "Enter" to skip to content

对抗网络(GAN)手写数字生成

本站内容均来自兴趣收集,如不慎侵害的您的相关权益,请留言告知,我们将尽快删除.谢谢.

目录

 

(2条消息) tensorflow零基础入门学习_重邮研究森的博客-CSDN博客_tensorflow 学习 https://blog.csdn.net/m0_60524373/article/details/124143223 https://blog.csdn.net/m0_60524373/article/details/124143223​>- 本文为[365天深度学习训练营](https://mp.weixin.qq.com/s/k-vYaC8l7uxX51WoypLkTw) 中的学习记录博客

 

>- 参考文章地址: (1条消息) 深度学习100例-生成对抗网络(GAN)手写数字生成 | 第18天_K同学啊的博客-CSDN博客 https://mtyjkh.blog.csdn.net/article/details/118995896

 

本文开发环境:tensorflowgpu2.5,经过验证,2.4也可以运行

 

1.跑通代码

 

我这个人对于任何代码,我都会先去跑通之和才会去观看内容,哈哈哈,所以第一步我们先不管37=21,直接把博主的代码复制黏贴一份运行结果。(PS:做了一些修改,因为原文是jupyter,而我在pycharm)

 

import tensorflow as tf
gpus = tf.config.list_physical_devices("GPU")
if gpus:
    tf.config.experimental.set_memory_growth(gpus[0], True)  # 设置GPU显存用量按需使用
    tf.config.set_visible_devices([gpus[0]], "GPU")
# 打印显卡信息,确认GPU可用
print(gpus)
from tensorflow.keras import layers, datasets, Sequential, Model, optimizers
from tensorflow.keras.layers import LeakyReLU, UpSampling2D, Conv2D
import matplotlib.pyplot as plt
import numpy             as np
import sys,os,pathlib
img_shape  = (28, 28, 1)
latent_dim = 200
def build_generator():
    # ======================================= #
    #     生成器,输入一串随机数字生成图片
    # ======================================= #
    model = Sequential([
        layers.Dense(256, input_dim=latent_dim),
        layers.LeakyReLU(alpha=0.2),  # 高级一点的激活函数
        layers.BatchNormalization(momentum=0.8),  # BN 归一化
        layers.Dense(512),
        layers.LeakyReLU(alpha=0.2),
        layers.BatchNormalization(momentum=0.8),
        layers.Dense(1024),
        layers.LeakyReLU(alpha=0.2),
        layers.BatchNormalization(momentum=0.8),
        layers.Dense(np.prod(img_shape), activation='tanh'),
        layers.Reshape(img_shape)
    ])
    noise = layers.Input(shape=(latent_dim,))
    img = model(noise)
    return Model(noise, img)
def build_discriminator():
    # ===================================== #
    #   鉴别器,对输入的图片进行判别真假
    # ===================================== #
    model = Sequential([
        layers.Flatten(input_shape=img_shape),
        layers.Dense(512),
        layers.LeakyReLU(alpha=0.2),
        layers.Dense(256),
        layers.LeakyReLU(alpha=0.2),
        layers.Dense(1, activation='sigmoid')
    ])
    img = layers.Input(shape=img_shape)
    validity = model(img)
    return Model(img, validity)
# 创建判别器
discriminator = build_discriminator()
# 定义优化器
optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator.compile(loss='binary_crossentropy',
                      optimizer=optimizer,
                      metrics=['accuracy'])
# 创建生成器
generator = build_generator()
gan_input = layers.Input(shape=(latent_dim,))
img = generator(gan_input)
# 在训练generate的时候不训练discriminator
discriminator.trainable = False
# 对生成的假图片进行预测
validity = discriminator(img)
combined = Model(gan_input, validity)
combined.compile(loss='binary_crossentropy', optimizer=optimizer)
def sample_images(epoch):
    """
    保存样例图片
    """
    row, col = 4, 4
    noise = np.random.normal(0, 1, (row*col, latent_dim))
    gen_imgs = generator.predict(noise)
    fig, axs = plt.subplots(row, col)
    cnt = 0
    for i in range(row):
        for j in range(col):
            axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
            axs[i,j].axis('off')
            cnt += 1
    fig.savefig("images/%05d.png" % epoch)
   # fig.savefig(" E:/2021_Project_YanYiXia/AI/21/对抗网络(GAN)手写数字生成/images/%05d.png" % epoch)
    plt.close()
def train(epochs, batch_size=128, sample_interval=50):
    # 加载数据
    (train_images, _), (_, _) = tf.keras.datasets.mnist.load_data()
    # 将图片标准化到 [-1, 1] 区间内
    train_images = (train_images - 127.5) / 127.5
    # 数据
    train_images = np.expand_dims(train_images, axis=3)
    # 创建标签
    true = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))
    # 进行循环训练
    for epoch in range(epochs):
        # 随机选择 batch_size 张图片
        idx = np.random.randint(0, train_images.shape[0], batch_size)
        imgs = train_images[idx]
        # 生成噪音
        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        # 生成器通过噪音生成图片,gen_imgs的shape为:(128, 28, 28, 1)
        gen_imgs = generator.predict(noise)
        # 训练鉴别器
        d_loss_true = discriminator.train_on_batch(imgs, true)
        d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)
        # 返回loss值
        d_loss = 0.5 * np.add(d_loss_true, d_loss_fake)
        # 训练生成器
        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        g_loss = combined.train_on_batch(noise, true)
        print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100 * d_loss[1], g_loss))
        # 保存样例图片
        if epoch % sample_interval == 0:
            sample_images(epoch)
#train(epochs=30000, batch_size=256, sample_interval=200)
import imageio
def compose_gif():
    # 图片地址
    data_dir = "E:/2021_Project_YanYiXia/AI/21/对抗网络(GAN)手写数字生成/images"
    data_dir = pathlib.Path(data_dir)
    paths = list(data_dir.glob('*'))
    gif_images = []
    for path in paths:
        print(path)
        gif_images.append(imageio.imread(path))
    imageio.mimsave("test.gif", gif_images, fps=2)
compose_gif()

 

点击pycharm运行即可得到结果,此图为对抗网络生成的手写数字

 

 

2.代码分析

 

神经网络的整个过程我分为如下六部分,而我们也会对这六部分进行逐部分分析。那幺这6部分分别是: 但是 :这里是对抗网络,和传统方法有区别,因此六步法不适用了,我们将重新分析。

 

五步法:

 

1->import

 

2->设置生成器和判别器

 

3->创建生成器和判别器

 

4->训练模型

 

5->验证

 

2.1

 

导入:这里很容易理解,也就是导入本次实验内容所需要的各种库。在本案例中主要包括以下部分:

 

 

蓝框1:

 

设置电脑gpu工作,如果你的电脑没有gpu就不设置,或者你的gpu显存不够,训练时出问题了,那幺就设在为cpu模式

 

蓝框2:

 

导入各种库

 

对于这里的话我们可以直接复制黏贴,当需要一些其他函数时,只需要添加对应的库文件即可。

 

2.2

 

这里是设置生成器和判别器。其中对于生成器和判别器的详细解释可以参考下面这个链接。

 

四天搞懂生成对抗网络(一)——通俗理解经典GAN – 知乎 (zhihu.com) https://zhuanlan.zhihu.com/p/307527293 总之,我们需要清楚的一点是,对抗网络中:

 

生成器:根据随机数生成一些“以假乱真”的数据集

 

判别器:对生成器“以假乱真”的数据集和真实的数据集进行判别

 

两者在训练过程中都会不断进行优化,生成器会不断生成更多“更真”的数据,判别器会“检测”的更仔细。

 

下面进行详细代码解释:

 

 

蓝框 1:

 

这里设置我们的图片格式和输入的维度

 

蓝框2:

 

这里引入了alpha激活函数,批标准化,prod函数。

 

先设置网络层,然后把噪音当作输入层,img当作输出层。

 

 

这部分为判别器。

 

输入是之前噪音生成的img,输出是真或者假

 

很有趣的一点,生成器和判别器的网络模型基本上是对折的!

 

2.3

 

在生成器和判别器都定义好之后,我们可以创建它们

 

 

蓝框1:

 

设置优化器,关于优化器的参数设置可以参考文章开头我之前写的一篇基础文章

 

蓝框2:

 

创建生成器,生成器输入为噪音维度的数,输出是图片数据

 

蓝框3:

 

在训练生成器的时候不训练判别器

 

2.4

 

在完成基础准备工作之后,就可以开始训练了

 

 

重点!!!

 

现在我们来对如何对生成器和判别器训练进行代码解读!!!

 

蓝框1:

 

这里就是我们之前文章调用minist数据集制作datset的方法,包括加载数据,数据处理,归一化,修改维度。 同时这里的区别是 :在标签方面,1为真,0为假,都是根据batch_size来生产的一个列表。

 

蓝框2:

 

返回一个随机整型数,范围从低0(包括)到高 train_images.shape[0](不包括).另外输出随机数的尺寸为batch_size

 

总结:这里就是随机获取官方数据集中任意一组图片

 

蓝框3:

 

从正态(高斯)分布中抽取随机样本。其中样本尺寸为(batch_size, latent_dim)

 

根据噪音的随机情况可以生成随机的一个图片数据

 

蓝框4:

 

discriminator.train_on_batch(imgs, true)的意思是,输入为img,输出为true,返回结果为loss

 

把真实数据和假数据的loss计算出来为总loss

 

蓝框5

 

从正态(高斯)分布中抽取随机样本。其中样本尺寸为(batch_size, latent_dim)

 

根据噪音的随机情况然后利用 combined.train_on_batch(noise, true)的意思是,输入为noise,输出为true,返回结果为loss

 

蓝框6

 

打印每轮的损失函数信息

 

蓝框7

 

执行训练

 

2.5

 

训练结束后,我们可以观察训练的结果

 

 

上面是生成tif动图代码

 

​ 上面是保存样例图片代码

Be First to Comment

发表回复

您的电子邮箱地址不会被公开。