### How GANs Work
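For reference, the training procedure later in this article optimizes the standard GAN minimax objective, in which the generator $G$ and the discriminator $D$ play a two-player game:

$$
\min_G \max_D \; \mathbb{E}_{x \sim p_{\text{data}}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]
$$

Here $z$ is the noise vector fed to the generator, and $D(\cdot)$ outputs the probability that its input is a real sample.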


### Libraries

```python
"""
Import the necessary libraries to create a generative adversarial network.
The code is mainly developed using the PyTorch library.
"""
import time

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets
from torchvision import transforms

import numpy as np
import matplotlib.pyplot as plt

from model import discriminator, generator
```

### Hardware Requirements

```python
"""
Determine whether a GPU is available.
"""
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
```


### Network Architecture

```python
"""
Network architectures.
The following are the discriminator and generator architectures.
"""
class discriminator(nn.Module):
    def __init__(self):
        super(discriminator, self).__init__()
        self.fc1 = nn.Linear(784, 512)
        self.fc2 = nn.Linear(512, 1)
        self.activation = nn.LeakyReLU(0.1)

    def forward(self, x):
        # Flatten each 1x28x28 image into a 784-dimensional vector
        x = x.view(-1, 784)
        x = self.activation(self.fc1(x))
        x = self.fc2(x)
        return nn.Sigmoid()(x)


class generator(nn.Module):
    def __init__(self):
        super(generator, self).__init__()
        self.fc1 = nn.Linear(128, 1024)
        self.fc2 = nn.Linear(1024, 2048)
        self.fc3 = nn.Linear(2048, 784)
        self.activation = nn.ReLU()

    def forward(self, x):
        x = self.activation(self.fc1(x))
        x = self.activation(self.fc2(x))
        x = self.fc3(x)
        # Reshape the flat 784-dimensional output back into 1x28x28 images
        x = x.view(-1, 1, 28, 28)
        return nn.Tanh()(x)
```
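To make the shape flow of the generator concrete, here is a minimal NumPy sketch of the same pipeline (128 → 1024 → 2048 → 784 → 1×28×28). The weights are random placeholders, not the trained PyTorch parameters; only the shapes and activations mirror the model above.

```python
import numpy as np

rng = np.random.default_rng(0)

def linear(x, in_dim, out_dim):
    """A stand-in for nn.Linear: x @ W.T + b with placeholder parameters."""
    W = rng.standard_normal((out_dim, in_dim)) * 0.01
    b = np.zeros(out_dim)
    return x @ W.T + b

z = rng.random((4, 128))                   # a batch of 4 noise vectors
x = np.maximum(linear(z, 128, 1024), 0)    # ReLU, as in the generator
x = np.maximum(linear(x, 1024, 2048), 0)
x = linear(x, 2048, 784)
imgs = np.tanh(x).reshape(-1, 1, 28, 28)   # Tanh squashes outputs into [-1, 1]
print(imgs.shape)                          # (4, 1, 28, 28)
```

The final `Tanh` matters because MNIST inputs are typically normalized into [-1, 1], so real and generated images live in the same value range.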

### Training

```python
"""
Network training procedure.
At every step, the losses of both the discriminator and the generator are updated.
The discriminator aims to classify reals and fakes;
the generator aims to generate images that are as realistic as possible.
"""
for epoch in range(epochs):
    for idx, (imgs, _) in enumerate(train_loader):
        idx += 1

        # Training the discriminator
        # Real inputs are actual images from the MNIST dataset
        # Fake inputs come from the generator
        # Real inputs should be classified as 1 and fakes as 0
        real_inputs = imgs.to(device)
        real_outputs = D(real_inputs)
        real_label = torch.ones(real_inputs.shape[0], 1).to(device)

        # Map uniform noise from [0, 1) to [-1, 1)
        noise = (torch.rand(real_inputs.shape[0], 128) - 0.5) / 0.5
        noise = noise.to(device)
        fake_inputs = G(noise)
        # Detach so the generator is not updated during the discriminator step
        fake_outputs = D(fake_inputs.detach())
        fake_label = torch.zeros(fake_inputs.shape[0], 1).to(device)

        outputs = torch.cat((real_outputs, fake_outputs), 0)
        targets = torch.cat((real_label, fake_label), 0)

        D_optimizer.zero_grad()  # clear gradients from the previous step
        D_loss = loss(outputs, targets)
        D_loss.backward()
        D_optimizer.step()

        # Training the generator
        # The generator's goal is to make the discriminator believe everything is 1
        noise = (torch.rand(real_inputs.shape[0], 128) - 0.5) / 0.5
        noise = noise.to(device)
        fake_inputs = G(noise)
        fake_outputs = D(fake_inputs)
        fake_targets = torch.ones([fake_inputs.shape[0], 1]).to(device)

        G_optimizer.zero_grad()
        G_loss = loss(fake_outputs, fake_targets)
        G_loss.backward()
        G_optimizer.step()

        if idx % 100 == 0 or idx == len(train_loader):
            print('Epoch {} Iteration {}: discriminator_loss {:.3f} generator_loss {:.3f}'.format(
                epoch, idx, D_loss.item(), G_loss.item()))

    if (epoch + 1) % 10 == 0:
        torch.save(G, 'Generator_epoch_{}.pth'.format(epoch))
        print('Model saved.')
```
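Both updates in the loop reduce to binary cross-entropy against different targets, which matches the discriminator's Sigmoid output (assuming the `loss` object above is `nn.BCELoss`; it is defined outside this excerpt). A small NumPy sketch of what each step optimizes:

```python
import numpy as np

def bce(outputs, targets, eps=1e-7):
    """Binary cross-entropy, the same quantity nn.BCELoss computes."""
    outputs = np.clip(outputs, eps, 1 - eps)
    return -np.mean(targets * np.log(outputs) + (1 - targets) * np.log(1 - outputs))

real_outputs = np.array([0.9, 0.8])   # D's scores on real images (target: 1)
fake_outputs = np.array([0.2, 0.1])   # D's scores on fakes (target: 0)

# Discriminator step: concatenate outputs with matching targets, as in the loop
outputs = np.concatenate([real_outputs, fake_outputs])
targets = np.concatenate([np.ones(2), np.zeros(2)])
d_loss = bce(outputs, targets)

# Generator step: the same fake outputs, but now the targets are all ones
g_loss = bce(fake_outputs, np.ones(2))

print(round(d_loss, 3), round(g_loss, 3))
```

With these illustrative scores the discriminator is winning, so its loss is small while the generator's loss is large; training pushes the two in opposite directions until neither can easily improve.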



GANs are unlike anything computer-vision experts had proposed before, and their concrete real-world applications have led many to marvel at the boundless potential of deep networks. Below, we look at two of the best-known extensions of the GAN idea.
