Press "Enter" to skip to content

对抗生成网络GAN系列——EGBAD原理及缺陷检测实战

本文为稀土掘金技术社区首发签约文章,14天内禁止转载,14天后未获授权禁止转载,侵权必究!

 

:tangerine:作者简介:秃头小苏,致力于用最通俗的语言描述问题

 

:tangerine:往期回顾: 对抗生成网络GAN系列——GAN原理及手写数字生成小案例 对抗生成网络GAN系列——DCGAN简介及人脸图像生成案例 对抗生成网络GAN系列——AnoGAN原理及缺陷检测实战

 

:tangerine:近期目标:写好专栏的每一篇文章

 

:tangerine:支持小苏:点赞:+1|type_3:、收藏:star:、留言:envelope_with_arrow:

 

对抗生成网络GAN系列——EGBAD原理及缺陷检测实战

 

写在前面

 

​ 在上一篇,我为大家介绍了首次应用在缺陷检测中的GAN网络——ANoGAN。在文末总结了AnoGAN一个显而易见的劣势,即在测试阶段需要花费大量时间来搜索潜在变量z,这在很多应用场景中是难以接受的。本文针对上述所说缺点,介绍一种新的GAN网络——EGBAD,其在训练过程中通过一个巧妙的编码器实现对z的搜索,这样在测试过程中就可以节约大量时间。:blossom::blossom::blossom:

 

​ 阅读本文之前,建议先对AnoGAN有一定了解,可参考下文:

[1] 对抗生成网络GAN系列——AnoGAN原理及缺陷检测实战 :maple_leaf::maple_leaf::maple_leaf:

​ 如果你准备就绪的话,就让我们一起来学学AnoGAN的改进版EGBAD吧!!!:herb::herb::herb:

 

EGBAD原理详解:sparkles::sparkles::sparkles:

 

​ 一直在说EGBAD,大家肯定一脸懵,到底什幺才是EGBAD了?我们先来看看它的英文全称,即 EFFICIENT GAN-BASED ANOMALY DETECTION ,中文译为基于GAN的高效异常检测。通过说明EGBAD的字面含义,相信大家知道了EGBAD是用来干什幺的了。没错,它也是用于缺陷检测的网络,是对AnoGAN的优化。至于具体是怎幺优化的,且听下文分解。:mushroom::mushroom::mushroom:

 

​ 我们先来回顾一下AnoGAN是怎幺设计的?AnoGAN分为训练和测试两个阶段进行,训练阶段使用正常数据训练一个DCGAN网络,在测试阶段,固定训练阶段的网络权重,不断更新潜在变量z,使得由z生成的假图像尽可能接近真实图片。 【如果你对这个过程不熟悉的话,建议看看[1]中内容喔】 在介绍EGBAD是怎幺设计的前,我们先来看看EGBAD主要解决了AnoGAN什幺问题?其实这点我在 写在前面 已经提及,AnoGAN在测试阶段要不断搜索潜在变量z,这消耗了大量时间,EGBAD的提出就是为了解决AnoGAN时间消耗大的问题。接着我们来就来看看EGBAD具体是怎幺做的呢?EGBAD也分为训练和测试两个阶段进行。在训练阶段,不仅要训练生成器和判别器,还会定义一个编码器(encoder)结构并对其训练,encoder主要用于将输入图像通过网络转变成一个潜在变量。在测试阶段,冻结训练阶段的所以权重,之后通过encoder将输入图像变为潜在变量,最后在将潜在变量送入生成器,生成假图像。可以发现EGBAD没有在测试阶段搜索潜在变量,而是直接通过一个encoder结构将输入图像转变成潜在变量,这大大节省了时间成本。

 

​ 关于EGBAD训练过程模型示意图如下: 【测试过程很简单啦,就不介绍了】

 

​ 可以看出判别器的输入有两个,一个是生成器生成的假图像

x

{\rm{x’}}
x ′ ,另一个是编码器生成的

z

{\rm{z’}}
z ′ 。具体生成器、编码器和判别器的结构如何,将在下章代码实战中介绍。:palm_tree::palm_tree::palm_tree:

 

EGBAD代码实战

 

代码下载地址:sparkles::sparkles::sparkles:

 

​ 同样,我将此部分的源码上传到Github上了,大家可以阅读README文件了解代码的使用,Github地址如下:

 

EGBAD-pytorch实现

 

我认为你阅读README文件后已经对这个项目的结构有所了解,我在下文也会帮大家分析分析源码,但更多的时间大家应该自己动手去亲自调试,这样你会有不一样的收获。:ear_of_rice::ear_of_rice::ear_of_rice:

 

数据读取

 

​ 这部分和AnoGAN中完全一致,就不带大家一行行看调试结果了,不明白的可以阅读AnoGAN教程,这里直接上代码:

 

#导入相关包
import numpy as np
import pandas as pd
"""
mnist数据集读取
"""
## 读取训练集数据  (60000,785)
train = pd.read_csv(".\data\mnist_train.csv",dtype = np.float32)
## 读取测试集数据  (10000,785)
test = pd.read_csv(".\data\mnist_test.csv",dtype = np.float32)
# 查询训练数据中标签为7、8的数据,并取前400个
train = train.query("label in [7.0, 8.0]").head(400)
# 查询训练数据中标签为7、8的数据,并取前400个
test = test.query("label in [2.0, 7.0, 8.0]").head(600)
# 取除标签后的784列数据
train = train.iloc[:,1:].values.astype('float32')
test = test.iloc[:,1:].values.astype('float32')
# train:(400,784)-->(400,28,28)
# test:(600,784)-->(600,28,28)
train = train.reshape(train.shape[0], 28, 28)
test = test.reshape(test.shape[0], 28, 28)

 

模型搭建

 

​ 这部分大家就潜心修行,慢慢调试代码吧,我也会给出每个模型的结构图辅助大家,就让我们一起来看看吧☘☘☘

 

生成模型搭建

 

"""定义生成器网络结构"""
class Generator(nn.Module):
  def __init__(self):
    super(Generator, self).__init__()
    def CBA(in_channel, out_channel, kernel_size=4, stride=2, padding=1, activation=nn.ReLU(inplace=True), bn=True):
        seq = []
        seq += [nn.ConvTranspose2d(in_channel, out_channel, kernel_size=kernel_size, stride=stride, padding=padding)]
        if bn is True:
          seq += [nn.BatchNorm2d(out_channel)]
        seq += [activation]
        return nn.Sequential(*seq)
    seq = []
    seq += [CBA(20, 64*8, stride=1, padding=0)]
    seq += [CBA(64*8, 64*4)]
    seq += [CBA(64*4, 64*2)]
    seq += [CBA(64*2, 64)]
    seq += [CBA(64, 1, activation=nn.Tanh(), bn=False)]
    self.generator_network = nn.Sequential(*seq)
  def forward(self, z):
      out = self.generator_network(z)
      return out

 

​ 生成模型的搭建其实很AnoGAN是完全一样的,我也给出生成网络的结构图,如下:

 

编码器模型搭建

 

"""定义编码器结构"""
class encoder(nn.Module):
  def __init__(self):
    super(encoder, self).__init__()
    def CBA(in_channel, out_channel, kernel_size=4, stride=2, padding=1, activation=nn.LeakyReLU(0.1, inplace=True)):
        seq = []
        seq += [nn.Conv2d(in_channel, out_channel, kernel_size=kernel_size, stride=stride, padding=padding)]
        seq += [nn.BatchNorm2d(out_channel)]
        seq += [activation]
        return nn.Sequential(*seq)
    seq = []
    seq += [CBA(1, 64)]
    seq += [CBA(64, 64*2)]
    seq += [CBA(64*2, 64*4)]
    seq += [CBA(64*4, 64*8)]
    seq += [nn.Conv2d(64*8, 512, kernel_size=4, stride=1)]
    self.feature_network = nn.Sequential(*seq)
    self.embedding_network = nn.Linear(512, 20)
  def forward(self, x):
    feature = self.feature_network(x).view(-1, 512)
    z = self.embedding_network(feature)
    return z

 

​ 这部分其实也很简单,就是一系列卷积的堆积,编码器的结构图如下:

 

判别模型搭建

 

"""定义判别器网络结构"""
class Discriminator(nn.Module):
  def __init__(self):
    super(Discriminator, self).__init__()
    def CBA(in_channel, out_channel, kernel_size=4, stride=2, padding=1, activation=nn.LeakyReLU(0.1, inplace=True)):
        seq = []
        seq += [nn.Conv2d(in_channel, out_channel, kernel_size=kernel_size, stride=stride, padding=padding)]
        seq += [nn.BatchNorm2d(out_channel)]
        seq += [activation]
        return nn.Sequential(*seq)
    seq = []
    seq += [CBA(1, 64)]
    seq += [CBA(64, 64*2)]
    seq += [CBA(64*2, 64*4)]
    seq += [CBA(64*4, 64*8)]
    seq += [nn.Conv2d(64*8, 512, kernel_size=4, stride=1)]
    self.feature_network = nn.Sequential(*seq)
    seq = []
    seq += [nn.Linear(20, 512)]
    seq += [nn.BatchNorm1d(512)]
    seq += [nn.LeakyReLU(0.1, inplace=True)]
    self.latent_network = nn.Sequential(*seq)
    self.critic_network = nn.Linear(1024, 1)
  def forward(self, x, z):
      feature = self.feature_network(x)
      feature = feature.view(feature.size(0), -1)
      latent = self.latent_network(z)
      out = self.critic_network(torch.cat([feature, latent], dim=1))
      return out, feature

 

​ 虽然判别器有两个输入,两个输出,但是结构也非常清晰,如下图所示:

 

​ 在模型搭建部分我还想提一点我们需要注意的地方,一般我们设计好一个网络结构后,我们往往会先设计一个tensor来作为网络的输入,看看网络输出是否是是我们预期的,如果是,我们再进行下一步,否则我们需要调整我们的结构以适应我们的输入。通常情况下,tensor的batch维度设为1就行,但是这里设置成1就会报错,提示我们需要设置一个batch大于1的整数,当将batch设置为2时,程序正常,至于产生这种现象的原因我目前也不是很清楚,大家注意一下,知道的也烦请告知一下。关于调试网络结构是否正常的代码如下,仅供参考:

 

if __name__ == '__main__':
    x = torch.ones((2, 1, 64, 64))
    z = torch.ones((2, 20, 1, 1))
    Generator = Generator()
    Discriminator = Discriminator()
    encoder = encoder()
    output_G = Generator(z)
    output_D1, output_D2= Discriminator(x, z.view(2, -1))
    output_E = encoder(x)
    print(output_G.shape)
    print(output_D1.shape)
    print(output_D2.shape)
    print(output_E.shape)

 

模型训练

 

数据集加载

 

​ 这部分和AnoGAN一致,注意最终输入网络的图片尺寸都上采样成了64×64.

 

class image_data_set(Dataset):
    def __init__(self, data):
        self.images = data[:,:,:,None]
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize(64, interpolation=InterpolationMode.BICUBIC),
            transforms.Normalize((0.1307,), (0.3081,))
        ])
    def __len__(self):
        return len(self.images)
    def __getitem__(self, idx):
        return self.transform(self.images[idx])
        
 # 加载训练数据
 train_set = image_data_set(train)
 train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)

 

加载模型、定义优化器、损失函数等参数

 

​ 这部分也基本和AnoGAN类似,只不过添加了encoder网络的定义和优化器定义部分,如下:

 

# 指定设备
device = torch.device(args.device if torch.cuda.is_available() else "cpu")
# batch_size默认128
batch_size = args.batch_size
# 加载模型
G = Generator().to(device)
D = Discriminator().to(device)
E = Encoder().to(device)
# 训练模式
G.train()
D.train()
E.train()
# 设置优化器
optimizerG = torch.optim.Adam(G.parameters(), lr=0.0001, betas=(0.0, 0.9))
optimizerD = torch.optim.Adam(D.parameters(), lr=0.0001, betas=(0.0, 0.9))
optimizerE = torch.optim.Adam(E.parameters(), lr=0.0004, betas=(0.0,0.9))
# 定义损失函数
criterion = nn.BCEWithLogitsLoss(reduction='mean')

 

训练GAN网络

 

"""
训练
"""
# 开始训练
for epoch in range(args.epochs):
    # 定义初始损失
    log_g_loss, log_d_loss, log_e_loss = 0.0, 0.0, 0.0
    for images in train_loader:
        images = images.to(device)
        ## 训练判别器 Discriminator
        # 定义真标签(全1)和假标签(全0)   维度:(batch_size)
        label_real = torch.full((images.size(0),), 1.0).to(device)
        label_fake = torch.full((images.size(0),), 0.0).to(device)
        # 定义潜在变量z    维度:(batch_size,20,1,1)
        z = torch.randn(images.size(0), 20).to(device).view(images.size(0), 20, 1, 1).to(device)
        # 潜在变量喂入生成网络--->fake_images:(batch_size,1,64,64)
        fake_images = G(z)
        # 使用编码器将真实图像变成潜在变量   image:(batch_size, 1, 64, 64)-->z_real:(batch_size, 20)
        z_real = E(images)
        # 真图像和假图像送入判别网络,得到d_out_real、d_out_fake   维度:都为(batch_size,1)
        d_out_real, _ = D(images, z_real)
        d_out_fake, _ = D(fake_images, z.view(images.size(0), 20))
        # 损失计算
        d_loss_real = criterion(d_out_real.view(-1), label_real)
        d_loss_fake = criterion(d_out_fake.view(-1), label_fake)
        d_loss = d_loss_real + d_loss_fake
        # 误差反向传播,更新损失
        optimizerD.zero_grad()
        d_loss.backward()
        optimizerD.step()
        ## 训练生成器 Generator
        # 定义潜在变量z    维度:(batch_size,20,1,1)
        z = torch.randn(images.size(0), 20).to(device).view(images.size(0), 20, 1, 1).to(device)
        fake_images = G(z)
        # 假图像喂入判别器,得到d_out_fake   维度:(batch_size,1)
        d_out_fake, _ = D(fake_images, z.view(images.size(0), 20))
        # 损失计算
        g_loss = criterion(d_out_fake.view(-1), label_real)
        # 误差反向传播,更新损失
        optimizerG.zero_grad()
        g_loss.backward()
        optimizerG.step()
        ## 训练编码器Encode
        # 使用编码器将真实图像变成潜在变量    image:(batch_size, 1, 64, 64)-->z_real:(batch_size, 20)
        z_real = E(images)
        # 真图像送入判别器,记录结果d_out_real:(128, 1)
        d_out_real, _ = D(images, z_real)
        # 损失计算
        e_loss = criterion(d_out_real.view(-1), label_fake)
        # 误差反向传播,更新损失
        optimizerE.zero_grad()
        e_loss.backward()
        optimizerE.step()
        ## 累计一个epoch的损失,判别器损失、生成器损失、编码器损失分别存放到log_d_loss、log_g_loss、log_e_loss中
        log_d_loss += d_loss.item()
        log_g_loss += g_loss.item()
        log_e_loss += e_loss.item()
    ## 打印损失
    print(f'epoch {epoch}, D_Loss:{log_d_loss/128:.4f}, G_Loss:{log_g_loss/128:.4f}, E_Loss:{log_e_loss/128:.4f}')

 

这里总结一下上述训练的步骤,不断循环下列过程:

 

 

    1. 使用生成器从潜在变量z中创建假图像

 

    1. 使用编码器从真实图像中创建潜在变量

 

    1. 生成器和编码器结果送入判别器,进行训练

 

    1. 使用生成器从潜在变量z中创建假图像

 

    1. 训练生成器

 

    1. 使用编码器从真实图像中创建潜在变量

 

    1. 训练编码器

 

 

关于第3步,我也简单画了个图帮大家理解下,如下:

 

​ 最后我们来展示一下生成图片的效果,如下图所示:

 

 

缺陷检测

 

​ EGBAD缺陷检测非常简单,首先定义一个就算损失的函数,如下:

 

## 定义缺陷计算的得分
def anomaly_score(input_image, fake_image, z_real, D):
    # Residual loss 计算
    residual_loss = torch.sum(torch.abs(input_image - fake_image), (1, 2, 3))
    # Discrimination loss 计算
    _, real_feature = D(input_image, z_real)
    _, fake_feature = D(fake_image, z_real)
    discrimination_loss = torch.sum(torch.abs(real_feature - fake_feature), (1))
    # 结合Residual loss和Discrimination loss计算每张图像的损失
    total_loss_by_image = 0.9 * residual_loss + 0.1 * discrimination_loss
    return total_loss_by_image

 

​ 接着我们只需要用Encoder网络生成潜在变量,在再用生成器即可得到假图像,最后计算假图像和真图像的损失即可,如下:

 

# 加载测试数据
test_set = image_data_set(test)
test_loader = DataLoader(test_set, batch_size=5, shuffle=False)
input_images = next(iter(test_loader)).to(device)
# 通过编码器获取潜在变量,并用生成器生成假图像
z_real = E(input_images)
fake_images = G(z_real.view(input_images.size(0), 20, 1, 1))
# 异常计算
anomality = anomaly_score(input_images, fake_images, z_real, D)
print(anomality.cpu().detach().numpy())

 

​ 最后可以保存一下真实图像和假图像的结果,如下:

 

torchvision.utils.save_image(input_images, f"result/Nomal.jpg")
torchvision.utils.save_image(fake_images, f"result/ANomal.jpg")

 

​ 我们来看一下结果:

 

​ 通过上图你发现了什幺呢?是不是发现输入图像为7的图片的生成图像不是7而变成了8呢,究其原因,应该是生成器学到了更多关于数据8的特征,也就是说这个网络的生成效果并没有很好。

 

​ 我做了很多实验,发现EGBAD虽然测试时间上比AnoGAN快很多,但是它的稳定性似乎并没有很理想,很容易出现模式崩溃的问题。其实啊,GAN网络普遍存在着训练不稳定的现象,这也是一些大牛不断探索的方向,后面的文章我也会给大家介绍一些增加GAN训练稳定性的文章,敬请期待吧!:mushroom::mushroom::mushroom:

 

AnoGAN和EGBAD测试时间对比:sparkles::sparkles::sparkles:

 

​ 我们一直说EGBAD的测试时间相较AnoGAN短,从原理上来说确实是这样,但是具体是不是这样我们还要以实验为准。测试代码也很简单,只需要在测试过程中使用 time.time() 函数即可,具体可以参考我上传github中的源码,这里给出我测试两种网络在测试阶段所用时间(以秒为单位),如下图所示:

 

 

​ 通过上图数据可以看出,EGBAD比AnoGAN快的不是一点点,EGBAD的速度将近是AnoGAN的10000倍,这个数字还是很恐怖的。:custard::custard::custard:

 

总结

 

​ 到此,EGBAD的全部内容就为大家介绍完了,如果你明白了AnoGAN的话,这篇文章对你来说应该是小菜一碟了。EGBAD大大的减少了测试所有时间,但是GAN网络普遍存在易模式崩溃、训练不稳定的现象, 下一篇博文我将为大家介绍一些让GAN训练更稳定的技巧,敬请期待吧。:rice::rice::rice:

 

参考链接

 

EFFICIENT GAN-BASED ANOMALY DETECTION :maple_leaf::maple_leaf::maple_leaf:

 

GAN 使用 Pytorch 进行异常检测的方法 :maple_leaf::maple_leaf::maple_leaf:

 

如若文章对你有所帮助,那就

 

咻咻咻咻~~duang~~点个赞呗

Be First to Comment

发表回复

您的电子邮箱地址不会被公开。 必填项已用*标注