## 一、GAN对抗生成神经网络简介

GAN的用处很广，可以生成虚假图像、文本等数据，当模型训练的数据量很少的时候，也同样可以利用GAN生成数据进行训练，所以GAN也是一种数据增强的方式，可以提高模型的鲁棒性。

## 三、判别器D和生成器G的代码实现

### 3.1 判别器D

```#判别器
class CNN_Discriminator(nn.Module):
def __init__(self):
super(CNN_Discriminator, self).__init__()
self.conv1 = nn.Sequential(
nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, padding=2),  # batch, 32, 96，96,
nn.LeakyReLU(0.2, True),
nn.AvgPool2d(2, stride=2),  # batch, 32, 48, 48
)
self.conv2 = nn.Sequential(
nn.Conv2d(32, 64, 5, padding=2),  # batch, 64, 48, 48
nn.LeakyReLU(0.2, True),
nn.AvgPool2d(2, stride=3)  # batch, 64, 16, 16
)
self.fc = nn.Sequential(
nn.Linear(64 * 16 * 16, 1024),
nn.LeakyReLU(0.2, True),
nn.Linear(1024, 1),
nn.Sigmoid()
)
def forward(self, x):#输入为3通道的96x96大小的图像数据矩阵
'''
x: batch, width, height, channel=3
'''
x = self.conv1(x)
x = self.conv2(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x```

### 3.2 生成器G

```#生成器
class CNN_Generator(nn.Module):
def __init__(self):
super(CNN_Generator, self).__init__()
self.br = nn.Sequential(
nn.BatchNorm2d(15),
nn.ReLU(True)
)
self.downsample1 = nn.Sequential(
nn.Conv2d(15, 50, 3, stride=1, padding=1),  # batch, 50, 192, 192
nn.BatchNorm2d(50),
nn.ReLU(True)
)
self.downsample2 = nn.Sequential(
nn.Conv2d(50, 25, 3, stride=1, padding=1),  # batch, 25, 192, 192
nn.BatchNorm2d(25),
nn.ReLU(True)
)
self.downsample3 = nn.Sequential(
nn.Conv2d(25, 3, 2, stride=2),  # batch, 3, 96, 96
nn.Tanh()
)
def forward(self, x):
x = self.fc(x)#经过线性映射，得到一个很大的一维输出数据
x = x.view(x.size(0), 15, 192, 192)#把一维数据转换成15个通道的192x192大小的矩阵数据
x = self.br(x)
x = self.downsample1(x)
x = self.downsample2(x)
x = self.downsample3(x)#卷积操作最后输出一个3通道的96x96大小的数据矩阵
return x```

## 四、原始模型的构建与优化

### 4.1 导入判别器和生成器

```python
D = Model.CNN_Discriminator() #load the discriminator
D.to(args.device)
G = Model.CNN_Generator(args.z_dimension,15*192*192)#load the generator
G.to(args.device)#move both models onto the configured device (e.g. the GPU)
```

### 4.2 开始按原始论文方式训练

```for epoch in range(args.num_epoch):
if torch.cuda.is_available(): #清空显卡缓存
torch.cuda.empty_cache()
num_img = img.size(0)
#train discriminator
# compute loss of real_matched_img
img = img.view(num_img,3,96,96)
real_img = Variable(img).to(args.device)

#----------------------------训练判定器--------------------------------
matched_real_out = -1.0 * torch.log(D(real_img).squeeze(-1).sum())
# compute loss of fake_matched_img
z = Variable(torch.randn(num_img, args.z_dimension)).to(args.device)
fake_img = G(z)
matched_fake_out = -1.0 * torch.log((1.0 - D(fake_img).squeeze(-1)).sum())
# bp and optimize
d_loss = matched_real_out + matched_fake_out
d_loss.backward()
d_optimizer.step()
# ============================train generator================================
# compute loss of fake_img
# compute loss of fake_matched_img
z = Variable(torch.randn(num_img, args.z_dimension)).to(args.device)
fake_img = G(z)
matched_fake_out = torch.log((1.0 - D(fake_img).squeeze(-1)).sum())
g_loss = matched_fake_out
# bp and optimize
g_loss.backward()
g_optimizer.step()
print('Epoch [{}/{}], Batch {},d_loss: {:.6f}, g_loss: {:.6f} '
.format(
epoch, args.num_epoch,i,d_loss.data, g_loss.data,
))```

### 4.4 利用交叉熵损失函数代替原论文损失函数

```python
D = Model.CNN_Discriminator()
D.to(args.device)
G = Model.CNN_Generator(args.z_dimension,15*192*192)#load the generator
G.to(args.device)#move both models onto the configured device
criterion = nn.BCELoss()#binary cross-entropy loss, replacing the hand-written log loss
```

```for epoch in range(args.num_epoch):
if torch.cuda.is_available(): #清空显卡缓存
torch.cuda.empty_cache()
num_img = img.size(0)
#train discriminator
# compute loss of real_matched_img
img = img.view(num_img,3,96,96)
real_img = Variable(img).to(args.device)
real_label = Variable(torch.ones(num_img)).to(args.device)
fake_label = Variable(torch.zeros(num_img)).to(args.device)

#----------------------------训练判定器--------------------------------
matched_real_out = D(real_img)
#matched_real_out = -1.0 * torch.log(D(real_img).squeeze(-1).sum())
d_loss_matched_real = criterion(matched_real_out.squeeze(-1), real_label)
# compute loss of fake_matched_img
z = Variable(torch.randn(num_img, args.z_dimension)).to(args.device)
fake_img = G(z)
matched_fake_out = D(fake_img)
#matched_fake_out = -1.0 * torch.log((1.0 - D(fake_img).squeeze(-1)).sum())
d_loss_matched_fake = criterion(matched_fake_out.squeeze(-1), fake_label)
# bp and optimize
#d_loss = matched_real_out + matched_fake_out
d_loss = d_loss_matched_real + d_loss_matched_fake
d_loss.backward()
d_optimizer.step()
# ============================train generator================================
# compute loss of fake_img
# compute loss of fake_matched_img
z = Variable(torch.randn(num_img, args.z_dimension)).to(args.device)
fake_img = G(z)
matched_fake_out =  D(fake_img)
#matched_fake_out = torch.log((1.0 - D(fake_img).squeeze(-1)).sum())
#matched_fake_out_scores = matched_fake_out
#g_loss = matched_fake_out
g_loss = criterion(matched_fake_out.squeeze(-1),real_label)
# bp and optimize
g_loss.backward()
g_optimizer.step()```

## 五、总结

1、尽量采用交叉熵损失函数，训练效果较好。
2、尽量保证生成模型G和判别模型D的复杂度一致，避免导致某个模型被另外一个模型单方面碾压的情况，这样无法有效形成对抗训练的过程。