1. Understanding GAN Concepts

## 3. Sample Code Walkthrough

### 3.1 Downloading the Dataset

```python
# You may replace the workspace directory if you want.
workspace_dir = '.'

# Training progress bar
!pip install -q qqdm
!gdown --id 1IGrTr308mGAaCKotpkkm8wTKlWs9Jq-p --output "{workspace_dir}/crypko_data.zip"
```

### 3.2 Importing Packages and Functions

```python
import os
import glob
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch import optim
import matplotlib.pyplot as plt
from qqdm.notebook import qqdm
```

### 3.3 Dataset: Data Preprocessing

Transforms:

1. transforms.Compose(): chains a sequence of transforms; at run time they are applied to the image in order.

2. transforms.ToPILImage: converts the data to a PIL Image.

3. transforms.Resize: resizes the image.

4. transforms.ToTensor: converts to a tensor and scales values into [0, 1].

5. transforms.Normalize: normalizes the data channel-wise, where

   - mean: per-channel mean
   - std: per-channel standard deviation

6. Load the data in the main function:

```python
workspace_dir = 'D://机器学习//Jupyter//GAN学习//函数'
dataset = get_dataset(os.path.join(workspace_dir, 'faces'))
```

### 3.4 Model: Building the DCGAN

#### 3.4.1 Weight Initialization

The DCGAN paper initializes all weights from a normal distribution with mean 0 and standard deviation 0.02. The weights_init function takes an already-constructed model and re-initializes its convolutional, transposed-convolutional, and batch-normalization layers; it is applied right after the model is instantiated.

```python
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
```
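A quick sanity check of the initializer (weights_init is repeated so the snippet runs on its own): nn.Module.apply walks every submodule recursively, so each Conv and BatchNorm layer is reinitialized.

```python
import torch
import torch.nn as nn


def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


net = nn.Sequential(nn.Conv2d(3, 64, 5), nn.BatchNorm2d(64))
net.apply(weights_init)  # recursively applies weights_init to every submodule

# Conv weights now have std ~0.02; BatchNorm weights cluster around 1.0.
print(float(net[0].weight.std()), float(net[1].weight.mean()))
```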

#### 3.4.2 Generator Model Analysis

```python
def dconv_bn_relu(in_dim, out_dim):
    return nn.Sequential(
        nn.ConvTranspose2d(in_dim, out_dim, 5, 2,
                           padding=2, output_padding=1, bias=False),
        nn.BatchNorm2d(out_dim),
        nn.ReLU()
    )
```
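Each dconv_bn_relu block doubles the spatial size: a transposed convolution produces out = (in − 1)·stride − 2·padding + kernel + output_padding, so with kernel 5, stride 2, padding 2, and output_padding 1 a 4×4 map becomes (4 − 1)·2 − 4 + 5 + 1 = 8×8. A quick check:

```python
import torch
import torch.nn as nn

layer = nn.ConvTranspose2d(512, 256, 5, 2, padding=2, output_padding=1, bias=False)
x = torch.randn(1, 512, 4, 4)
y = layer(x)
print(tuple(y.shape))  # (1, 256, 8, 8): spatial size doubled
```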

```python
self.l1 = nn.Sequential(
    nn.Linear(in_dim, dim * 8 * 4 * 4, bias=False),
    nn.BatchNorm1d(dim * 8 * 4 * 4),
    nn.ReLU()
)
```

```python
self.l2_5 = nn.Sequential(
    dconv_bn_relu(dim * 8, dim * 4),
    dconv_bn_relu(dim * 4, dim * 2),
    dconv_bn_relu(dim * 2, dim),
    nn.ConvTranspose2d(dim, 3, 5, 2, padding=2, output_padding=1),
    nn.Tanh()
)
```

`self.apply(weights_init)`

```python
def forward(self, x):
    y = self.l1(x)
    y = y.view(y.size(0), -1, 4, 4)
    y = self.l2_5(y)
    return y
```

```python
netG = Generator(100)
print(netG)
```

```
Generator(
  (l1): Sequential(
    (0): Linear(in_features=100, out_features=8192, bias=False)
    (1): BatchNorm1d(8192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (l2_5): Sequential(
    (0): Sequential(
      (0): ConvTranspose2d(512, 256, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), output_padding=(1, 1), bias=False)
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (1): Sequential(
      (0): ConvTranspose2d(256, 128, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), output_padding=(1, 1), bias=False)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (2): Sequential(
      (0): ConvTranspose2d(128, 64, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), output_padding=(1, 1), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (3): ConvTranspose2d(64, 3, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), output_padding=(1, 1))
    (4): Tanh()
  )
)
```
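Assembling the generator snippets into one runnable class (the final ConvTranspose2d mapping dim channels to a 3-channel image follows the printed model above), a 100-dim z produces a 3×64×64 image, upsampled 4 → 8 → 16 → 32 → 64:

```python
import torch
import torch.nn as nn


def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


class Generator(nn.Module):
    """Input shape: (N, in_dim); output shape: (N, 3, 64, 64)."""

    def __init__(self, in_dim, dim=64):
        super().__init__()

        def dconv_bn_relu(in_dim, out_dim):
            return nn.Sequential(
                nn.ConvTranspose2d(in_dim, out_dim, 5, 2,
                                   padding=2, output_padding=1, bias=False),
                nn.BatchNorm2d(out_dim),
                nn.ReLU(),
            )

        self.l1 = nn.Sequential(
            nn.Linear(in_dim, dim * 8 * 4 * 4, bias=False),
            nn.BatchNorm1d(dim * 8 * 4 * 4),
            nn.ReLU(),
        )
        self.l2_5 = nn.Sequential(
            dconv_bn_relu(dim * 8, dim * 4),  # 4x4 -> 8x8
            dconv_bn_relu(dim * 4, dim * 2),  # 8x8 -> 16x16
            dconv_bn_relu(dim * 2, dim),      # 16x16 -> 32x32
            nn.ConvTranspose2d(dim, 3, 5, 2,
                               padding=2, output_padding=1),  # 32x32 -> 64x64
            nn.Tanh(),                        # squash pixels into [-1, 1]
        )
        self.apply(weights_init)

    def forward(self, x):
        y = self.l1(x)                   # (N, 100) -> (N, 8192)
        y = y.view(y.size(0), -1, 4, 4)  # (N, 8192) -> (N, 512, 4, 4)
        return self.l2_5(y)              # (N, 512, 4, 4) -> (N, 3, 64, 64)


netG = Generator(100)
imgs = netG(torch.randn(2, 100))  # batch of 2 latent vectors
print(tuple(imgs.shape))  # (2, 3, 64, 64)
```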

#### 3.4.3 Discriminator Model

```python
from torch import nn
from 函数.weights_inition import weights_init


class Discriminator(nn.Module):
    """
    Input shape: (N, 3, 64, 64)
    Output shape: (N, )
    """
    def __init__(self, in_dim, dim=64):
        super(Discriminator, self).__init__()

        def conv_bn_lrelu(in_dim, out_dim):
            return nn.Sequential(
                nn.Conv2d(in_dim, out_dim, 5, 2, 2),
                nn.BatchNorm2d(out_dim),
                nn.LeakyReLU(0.2),
            )

        """ Medium: Remove the last sigmoid layer for WGAN. """
        self.ls = nn.Sequential(
            nn.Conv2d(in_dim, dim, 5, 2, 2),
            nn.LeakyReLU(0.2),
            conv_bn_lrelu(dim, dim * 2),
            conv_bn_lrelu(dim * 2, dim * 4),
            conv_bn_lrelu(dim * 4, dim * 8),
            nn.Conv2d(dim * 8, 1, 4),
            nn.Sigmoid(),
        )
        self.apply(weights_init)

    def forward(self, x):
        y = self.ls(x)
        y = y.view(-1)
        return y
```

```
Discriminator(
  (ls): Sequential(
    (0): Conv2d(3, 64, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (1): LeakyReLU(negative_slope=0.2)
    (2): Sequential(
      (0): Conv2d(64, 128, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (3): Sequential(
      (0): Conv2d(128, 256, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (4): Sequential(
      (0): Conv2d(256, 512, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (5): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1))
    (6): Sigmoid()
  )
)
```
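The discriminator mirrors the generator, halving the spatial size 64 → 32 → 16 → 8 → 4 → 1. Below is a self-contained shape check of the class above, with the project-specific import replaced by the weights_init function from 3.4.1:

```python
import torch
import torch.nn as nn


def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


class Discriminator(nn.Module):
    """Input shape: (N, 3, 64, 64); output shape: (N,)."""

    def __init__(self, in_dim, dim=64):
        super().__init__()

        def conv_bn_lrelu(in_dim, out_dim):
            return nn.Sequential(
                nn.Conv2d(in_dim, out_dim, 5, 2, 2),
                nn.BatchNorm2d(out_dim),
                nn.LeakyReLU(0.2),
            )

        self.ls = nn.Sequential(
            nn.Conv2d(in_dim, dim, 5, 2, 2),  # 64 -> 32
            nn.LeakyReLU(0.2),
            conv_bn_lrelu(dim, dim * 2),      # 32 -> 16
            conv_bn_lrelu(dim * 2, dim * 4),  # 16 -> 8
            conv_bn_lrelu(dim * 4, dim * 8),  # 8 -> 4
            nn.Conv2d(dim * 8, 1, 4),         # 4 -> 1
            nn.Sigmoid(),                     # probability that the input is real
        )
        self.apply(weights_init)

    def forward(self, x):
        return self.ls(x).view(-1)


D = Discriminator(3)
out = D(torch.randn(2, 3, 64, 64))
print(tuple(out.shape))  # (2,): one score per image
```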

### 3.5 Training the DCGAN

#### 3.5.1 Building the Networks

```python
G = Generator(in_dim=z_dim).to(device)
D = Discriminator(3).to(device)
G.train()
D.train()

# Loss
criterion = nn.BCELoss()

# Optimizer
opt_D = torch.optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))
opt_G = torch.optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
```

Here in_dim = z_dim = 100: the latent vector z is drawn from a Gaussian distribution and has 100 dimensions.

#### 3.5.2 Loading the Data

`dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)`


#### 3.5.3 Training D (the Discriminator)

z is a randomly generated 64 × 100 batch of Gaussian noise (mean 0, variance 1); it is the input to the generator.

`z = Variable(torch.randn(bs, z_dim)).to(device)`

f_imgs has shape 64 × 3 × 64 × 64, i.e. 64 generated fake images.

```python
r_imgs = Variable(imgs).to(device)
f_imgs = G(z)
```

```python
r_label = torch.ones((bs)).to(device)
f_label = torch.zeros((bs)).to(device)
```

```python
r_logit = D(r_imgs.detach())
f_logit = D(f_imgs.detach())
```

```python
r_loss = criterion(r_logit, r_label)
f_loss = criterion(f_logit, f_label)
loss_D = (r_loss + f_loss) / 2
```
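Since BCELoss(p, 1) = −log p and BCELoss(p, 0) = −log(1 − p), loss_D averages −log D(x) over real images and −log(1 − D(G(z))) over fakes, which is exactly the discriminator term of the GAN objective. A toy check (the score values are illustrative only):

```python
import math

import torch
import torch.nn as nn

criterion = nn.BCELoss()
r_logit = torch.tensor([0.9, 0.8])  # D's sigmoid scores on two real images
f_logit = torch.tensor([0.1, 0.2])  # D's sigmoid scores on two fakes

r_loss = criterion(r_logit, torch.ones(2))   # -mean(log D(x))
f_loss = criterion(f_logit, torch.zeros(2))  # -mean(log(1 - D(G(z))))
loss_D = (r_loss + f_loss) / 2

expected = -(math.log(0.9) + math.log(0.8)) / 2  # both terms coincide here
print(float(loss_D), expected)
```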

```python
D.zero_grad()
loss_D.backward()
opt_D.step()
```

#### 3.5.4 Training G (the Generator)

A fresh z (again a 64 × 100 batch of Gaussian noise with mean 0 and variance 1) is sampled as the generator's input.

`z = Variable(torch.randn(bs, z_dim)).to(device)`

```python
f_imgs = G(z)
f_logit = D(f_imgs)
loss_G = criterion(f_logit, r_label)
```

```python
G.zero_grad()
loss_G.backward()
opt_G.step()
```
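The D and G updates above alternate inside the epoch loop. A minimal runnable sketch of that alternation, using toy 1-D "data" and tiny linear networks as stand-ins for the DCGAN models (all sizes here are illustrative):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
z_dim, data_dim, bs = 4, 8, 16
G = nn.Sequential(nn.Linear(z_dim, data_dim), nn.Tanh())
D = nn.Sequential(nn.Linear(data_dim, 1), nn.Sigmoid())
criterion = nn.BCELoss()
opt_D = torch.optim.Adam(D.parameters(), lr=1e-3, betas=(0.5, 0.999))
opt_G = torch.optim.Adam(G.parameters(), lr=1e-3, betas=(0.5, 0.999))

for step in range(10):
    r_imgs = torch.randn(bs, data_dim) * 0.1 + 0.5  # stand-in "real" data
    r_label, f_label = torch.ones(bs), torch.zeros(bs)

    # ---- update D: push real scores toward 1, fake scores toward 0 ----
    z = torch.randn(bs, z_dim)
    f_imgs = G(z)
    loss_D = (criterion(D(r_imgs).view(-1), r_label) +
              criterion(D(f_imgs.detach()).view(-1), f_label)) / 2
    D.zero_grad()
    loss_D.backward()
    opt_D.step()

    # ---- update G: make D label freshly generated fakes as real ----
    z = torch.randn(bs, z_dim)
    loss_G = criterion(D(G(z)).view(-1), r_label)
    G.zero_grad()
    loss_G.backward()
    opt_G.step()
```

Note that f_imgs is detached in the D step so the generator receives no gradient there, while the G step backpropagates through D into G.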

### 3.6 Displaying the Results

```python
G.eval()
f_imgs_sample = (G(z_sample).data + 1) / 2.0
filename = os.path.join(log_dir, f'Epoch_{epoch + 1:03d}.jpg')
torchvision.utils.save_image(f_imgs_sample, filename, nrow=10)
print(f' | Save some samples to {filename}.')

# Show generated images in the jupyter notebook.
grid_img = torchvision.utils.make_grid(f_imgs_sample.cpu(), nrow=10)
plt.figure(figsize=(10, 10))
plt.imshow(grid_img.permute(1, 2, 0))
plt.show()
G.train()

if (epoch + 1) % 5 == 0 or epoch == 0:
    # Save the checkpoints.
    torch.save(G.state_dict(), os.path.join(ckpt_dir, 'G.pth'))
    torch.save(D.state_dict(), os.path.join(ckpt_dir, 'D.pth'))
```
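Because the generator ends in Tanh, its outputs lie in [-1, 1]; the expression (x + 1) / 2 above maps them back into [0, 1], the range save_image and imshow expect. For example:

```python
import torch

x = torch.tensor([-1.0, 0.0, 1.0])  # extreme and midpoint Tanh outputs
y = (x + 1) / 2.0                   # undo the Normalize from 3.3
print(y)  # tensor([0.0000, 0.5000, 1.0000])
```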

Note: the sample grid is saved with nrow = 10 (10 images per row), while the training batch_size is 64.

### 3.7 Code Files

Change the imported module names and the data paths (in the main function) to your own.