## 对抗生成网络的实现

### 反卷积层 (ConvTranspose2d)

```python
>>> import torch
# 生成测试用的矩阵
# 第一个维度代表批次，第二个维度代表通道数量，第三个维度代表长度，第四个维度代表宽度
>>> a = torch.arange(1, 5).float().reshape(1, 1, 2, 2)
>>> a
tensor([[[[1., 2.],
[3., 4.]]]])
# 创建反卷积层
>>> convtranspose2d = torch.nn.ConvTranspose2d(1, 1, kernel_size=2, stride=2, bias=False)
# 手动指定权重 (让计算更好理解)
>>> convtranspose2d.weight = torch.nn.Parameter(torch.tensor([0.1, 0.2, 0.5, 0.8]).reshape(1, 1, 2, 2))
>>> convtranspose2d.weight
Parameter containing:
tensor([[[[0.1000, 0.2000],
          [0.5000, 0.8000]]]], requires_grad=True)
# 测试反卷积层
>>> convtranspose2d(a)
tensor([[[[0.1000, 0.2000, 0.2000, 0.4000],
[0.5000, 0.8000, 1.0000, 1.6000],
[0.3000, 0.6000, 0.4000, 0.8000],
[1.5000, 2.4000, 2.0000, 3.2000]]]], grad_fn=<ConvolutionBackward0>)
```

### 生成器的实现 (Generator)

class GenerationModel(nn.Module):
    """Model that generates fake images from random latent codes."""
    # Length of the latent code
    EmbeddedSize = 128

    def __init__(self):
        super().__init__()
        self.generator = nn.Sequential(
            # 128,1,1 => 512,5,5
            nn.ConvTranspose2d(128, 512, kernel_size=5, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            # => 256,10,10
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            # => 128,20,20
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            # => 64,40,40
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            # => 3,80,80
            nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1, bias=False),
            # Clamp the output to -1 ~ 1; Tanh is used instead of Hardtanh so
            # gradients of out-of-range values can still propagate to upper layers
            nn.Tanh())

    def forward(self, x):
        # Reshape (batch, 128) latent codes to (batch, 128, 1, 1) feature maps
        y = self.generator(x.view(x.shape[0], x.shape[1], 1, 1))
        return y

### 识别器的实现 (Discriminator)

class DiscriminationModel(nn.Module):
    """Model that tells real images from generated ones."""

    def __init__(self):
        super().__init__()
        self.discriminator = nn.Sequential(
            # 3,80,80 => 64,40,40
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # => 128,20,20
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            # => 256,10,10
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            # => 512,5,5
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            # => 1,1,1
            nn.Conv2d(512, 1, kernel_size=5, stride=1, padding=0, bias=False),
            # Flatten to (batch, 1)
            nn.Flatten(),
            # Output whether the input is real data (0 ~ 1)
            nn.Sigmoid())

    def forward(self, x):
        y = self.discriminator(x)
        return y

### 训练生成器和识别器的方法

```# Create the model instances
generation_model = GenerationModel().to(device)
discrimination_model = DiscriminationModel().to(device)
# Create the optimizers
# One is created for the generator and one for the discriminator
# NOTE(review): the optimizer construction lines (optimizer_g / optimizer_d)
# appear to be missing from this excerpt — see the full code below
# Generate random latent codes
def generate_vectors(batch_size):
vectors = torch.randn((batch_size, GenerationModel.EmbeddedSize), device=device)
return vectors
# Start the training process
for epoch in range(0, 10000):
# Enumerate the real data
# NOTE(review): the batch-loading loop that defines batch_x and
# minibatch_size appears to be missing from this excerpt
# Generate random latent codes
training_vectors = generate_vectors(minibatch_size)
# Generate fake data
generated = generation_model(training_vectors)
# Fetch real data
real = batch_x

# Train the discriminator (only adjust the discriminator's parameters)
predicted_t = discrimination_model(real)
predicted_f = discrimination_model(generated)
loss_d = (
nn.functional.binary_cross_entropy(
predicted_t, torch.ones(predicted_t.shape, device=device)) +
nn.functional.binary_cross_entropy(
predicted_f, torch.zeros(predicted_f.shape, device=device)))
loss_d.backward() # differentiate automatically from the loss
optimizer_d.step() # adjust the discriminator's parameters
# Train the generator (only adjust the generator's parameters)
predicted_f = discrimination_model(generated)
loss_g = nn.functional.binary_cross_entropy(
predicted_f, torch.ones(predicted_f.shape, device=device))
loss_g.backward() # differentiate automatically from the loss
optimizer_g.step() # adjust the generator's parameters

## 改进对抗生成网络 (WGAN)

```# Discriminator loss, before the change (DCGAN: binary cross entropy)
loss_d = (
nn.functional.binary_cross_entropy(
predicted_t, torch.ones(predicted_t.shape, device=device)) +
nn.functional.binary_cross_entropy(
predicted_f, torch.zeros(predicted_f.shape, device=device)))
# Discriminator loss, after the change (WGAN: widen the score gap between real and fake)
loss_d = predicted_f.mean() - predicted_t.mean()```

```# Generator loss, before the change (DCGAN: binary cross entropy)
loss_g = nn.functional.binary_cross_entropy(
predicted_f, torch.ones(predicted_f.shape, device=device))
# Generator loss, after the change (WGAN: raise the discriminator's score for fakes)
loss_g = -predicted_f.mean()```

```# Keep the discriminator's parameters within -0.1 ~ 0.1 (WGAN weight clipping)
for p in discrimination_model.parameters():
p.data.clamp_(-0.1, 0.1)```

## 改进对抗生成网络 (WGAN-GP)

WGAN 为了防止梯度爆炸问题对识别器参数的可取范围做出了限制，但这个做法比较粗暴，WGAN-GP (Wasserstein GAN Gradient Penalty) 提出了一个更优雅的方法，即限制导函数值的范围，如果导函数值偏移某个指定的值则通过损失给与模型惩罚。

def gradient_penalty(discrimination_model, real, generated):
    """Gradient penalty for WGAN-GP (https://arxiv.org/pdf/1704.00028.pdf).

    Penalizes the discriminator when the L2 norm of its gradient with respect
    to mixed real/fake samples deviates from 1, keeping parameters under control.
    """
    # Draw one uniform random mixing rate per sample in the batch, in 0 ~ 1.
    # NOTE(review): the original excerpt used torch.randn (normal distribution,
    # not limited to 0 ~ 1); torch.rand matches the comment and the paper
    batch_size = real.shape[0]
    rate = torch.rand(batch_size, 1, 1, 1, device=real.device)
    rate = rate.expand(batch_size, real.shape[1], real.shape[2], real.shape[3])
    # Mix real and generated samples by the random rate; gradients with respect
    # to the mixed samples are required below
    mixed = (rate * real + (1 - rate) * generated).requires_grad_(True)
    # Discriminate the mixed samples
    predicted_m = discrimination_model(mixed)
    # Differentiate predicted_m with respect to mixed; unlike
    # predicted_m.sum().backward() this leaves parameter .grad values untouched
    grad = torch.autograd.grad(
        outputs=predicted_m,
        inputs=mixed,
        grad_outputs=torch.ones_like(predicted_m),
        create_graph=True,
        retain_graph=True)[0]
    # Penalize deviation of the per-sample gradient L2 norm (over all channels)
    # from 1; the tail of this function was truncated in the source and is
    # reconstructed with penalty weight 10 as suggested by the WGAN-GP paper
    return ((grad.reshape(batch_size, -1).norm(2, dim=1) - 1) ** 2).mean() * 10

```# Discriminator loss, before the change (WGAN)
loss_d = predicted_f.mean() - predicted_t.mean()
# Discriminator loss, after the change (WGAN-GP: add the gradient penalty term)
loss_d = (predicted_f.mean() - predicted_t.mean() +
# NOTE(review): the line above is truncated in the source — the gradient
# penalty term completing the expression is missing; see the full code below

## 完整代码

https://www.kaggle.com/atulanandjha/lfwpeople

```import os
import sys
import torch
import gzip
import itertools
import random
import numpy
import math
import json
from PIL import Image
from torch import nn
from matplotlib import pyplot
from functools import lru_cache
# Size of the images to generate or discriminate
IMAGE_SIZE = (80, 80)
# Path of the dataset used for training
DATASET_DIR = "./dataset/lfwpeople/lfw_funneled"
# Model type; supports DCGAN, WGAN, WGAN-GP
MODEL_TYPE = "WGAN-GP"
# Enable GPU support when available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class GenerationModel(nn.Module):
    """Model that generates fake images from random latent codes."""
    # Length of the latent code
    EmbeddedSize = 128

    def __init__(self):
        super().__init__()
        self.generator = nn.Sequential(
            # 128,1,1 => 512,5,5
            nn.ConvTranspose2d(128, 512, kernel_size=5, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            # => 256,10,10
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            # => 128,20,20
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            # => 64,40,40
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            # => 3,80,80
            nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1, bias=False),
            # Clamp the output to -1 ~ 1; Tanh is used instead of Hardtanh so
            # gradients of out-of-range values can still propagate to upper layers
            nn.Tanh())

    def forward(self, x):
        # Reshape (batch, 128) latent codes to (batch, 128, 1, 1) feature maps
        y = self.generator(x.view(x.shape[0], x.shape[1], 1, 1))
        return y

    @staticmethod
    def calc_accuracy(predicted_f):
        """Accuracy: the proportion of fake samples that fooled the discriminator."""
        if MODEL_TYPE == "DCGAN":
            threshold = 0.5
        elif MODEL_TYPE in ("WGAN", "WGAN-GP"):
            # WGAN scores are unbounded; compare against the mean score the
            # discriminator last gave to real samples
            threshold = DiscriminationModel.LastTrueSamplePredictedMean
        else:
            raise ValueError("unknown model type")
        return (predicted_f >= threshold).float().mean().item()
class DiscriminationModel(nn.Module):
    """Model that tells real data from generated data."""
    # Mean output for the last batch of real samples; the WGAN variants use it
    # to estimate how many fake samples fooled the discriminator
    LastTrueSamplePredictedMean = 0.5

    def __init__(self):
        super().__init__()
        # Normalization layer factory
        def norm2d(features):
            if MODEL_TYPE == "WGAN-GP":
                # WGAN-GP should not use BatchNorm, but adding InstanceNorm can
                # improve results; unlike BatchNorm, InstanceNorm computes mean
                # and std separately for each sample in the batch.
                # affine=True makes the scale/shift learnable (BatchNorm2d default)
                return nn.InstanceNorm2d(features, affine=True)
            return nn.BatchNorm2d(features)
        self.discriminator = nn.Sequential(
            # 3,80,80 => 64,40,40
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # => 128,20,20
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
            norm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            # => 256,10,10
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False),
            norm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            # => 512,5,5
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1, bias=False),
            norm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            # => 1,1,1
            nn.Conv2d(512, 1, kernel_size=5, stride=1, padding=0, bias=False),
            # Flatten to (batch, 1)
            nn.Flatten())
        if MODEL_TYPE == "DCGAN":
            # Output whether the data is real (0 or 1);
            # the WGAN variants do not restrict the output to 0 ~ 1.
            # NOTE(review): reconstructed — the Sigmoid wiring was lost from
            # this extract; confirm against the original source
            self.discriminator.add_module("sigmoid", nn.Sigmoid())

    def forward(self, x):
        y = self.discriminator(x)
        return y

    @staticmethod
    def calc_accuracy(predicted_f, predicted_t):
        """Accuracy: the proportion of samples classified correctly."""
        if MODEL_TYPE == "DCGAN":
            return (((predicted_f <= 0.5).float().mean() + (predicted_t > 0.5).float().mean()) / 2).item()
        elif MODEL_TYPE in ("WGAN", "WGAN-GP"):
            # Record the mean real-sample score for GenerationModel.calc_accuracy
            DiscriminationModel.LastTrueSamplePredictedMean = predicted_t.mean()
            return (predicted_t > predicted_f).float().mean().item()
        else:
            raise ValueError("unknown model type")

    def gradient_penalty(self, real, generated):
        """Keep gradient norms near 1 to stop parameters from running away
        (https://arxiv.org/pdf/1704.00028.pdf).

        NOTE(review): the def line of this method was lost from the extract
        and is reconstructed here.
        """
        # Draw one uniform random mixing rate per sample in the batch, in 0 ~ 1.
        # NOTE(review): the extract used torch.randn (normal, not 0 ~ 1);
        # torch.rand matches the comment and the WGAN-GP paper
        batch_size = real.shape[0]
        rate = torch.rand(batch_size, 1, 1, 1, device=real.device)
        rate = rate.expand(batch_size, real.shape[1], real.shape[2], real.shape[3])
        # Mix real and generated samples by the random rate
        mixed = (rate * real + (1 - rate) * generated).requires_grad_(True)
        # Discriminate the mixed samples
        predicted_m = self.forward(mixed)
        # Differentiate predicted_m with respect to mixed; unlike
        # predicted_m.sum().backward() this leaves parameter .grad values untouched
        grad = torch.autograd.grad(
            outputs=predicted_m,
            inputs=mixed,
            grad_outputs=torch.ones_like(predicted_m),
            create_graph=True,
            retain_graph=True)[0]
        # Penalize deviation of the per-sample L2 norm (over all channels) from 1
        # (reconstructed: penalty weight 10 as suggested by the WGAN-GP paper)
        return ((grad.reshape(batch_size, -1).norm(2, dim=1) - 1) ** 2).mean() * 10
def save_tensor(tensor, path):
    """Save a tensor object to a gzip-compressed file."""
    # Use a with-block so the gzip stream is flushed and closed deterministically
    # (the original relied on garbage collection to close the file, which can
    # leave the archive unflushed)
    with gzip.GzipFile(path, "wb") as f:
        torch.save(tensor, f)
# Loaded tensors are cached to cut read time;
# reduce maxsize if memory runs short
@lru_cache(maxsize=200)
def load_tensor(path):
    """Read a tensor object from a gzip-compressed file.

    NOTE(review): the def line and body were lost from this extract; the name
    and implementation are reconstructed as the counterpart of save_tensor —
    confirm against the original source.
    """
    with gzip.GzipFile(path, "rb") as f:
        return torch.load(f)
def image_to_tensor(img):
    """Resize a PIL image and convert it to a tensor (C,W,H, values in -1 ~ 1)."""
    img = img.resize(IMAGE_SIZE) # resize; stretches when the aspect ratio differs
    arr = numpy.asarray(img)
    t = torch.from_numpy(arr)
    t = t.transpose(0, 2) # swap dimensions H,W,C to C,W,H
    t = (t / 255.0) * 2 - 1 # normalize values to -1 ~ 1
    return t
def tensor_to_image(t):
    """Convert a tensor with values in -1 ~ 1 back to a PIL RGB image."""
    t = (t + 1) / 2 * 255.0 # back to the 0 ~ 255 color range
    t = t.transpose(0, 2) # swap dimensions C,W,H to H,W,C
    t = t.int() # convert values to integers
    img = Image.fromarray(t.numpy().astype("uint8"), "RGB")
    return img
def prepare():
    """Prepare the training data."""
    # Converted tensors are stored under the data folder
    if not os.path.isdir("data"):
        os.makedirs("data")
    # Collect the face image list,
    # using at most 2 images per person
    image_paths = []
    for dirname in os.listdir(DATASET_DIR):
        dirpath = os.path.join(DATASET_DIR, dirname)
        if not os.path.isdir(dirpath):
            continue
        for filename in os.listdir(dirpath)[:2]:
            image_paths.append(os.path.join(DATASET_DIR, dirname, filename))
    print(f"found {len(image_paths)} images")
    # Shuffle the face image list
    random.shuffle(image_paths)
    # Limit the number of faces: with too many, the discriminator struggles to
    # memorize concrete features and needs much longer training, or falls
    # straight into mode collapse
    image_paths = image_paths[:2000]
    print(f"only use {len(image_paths)} images")
    # Save the face image data in batches of 200
    for batch, index in enumerate(range(0, len(image_paths), 200)):
        paths = image_paths[index:index+200]
        images = []
        for path in paths:
            img = Image.open(path)
            # Crop the center so the face fills more of the image
            w, h = img.size
            img = img.crop((int(w*0.25), int(h*0.25), int(w*0.75), int(h*0.75)))
            images.append(img)
        tensors = [ image_to_tensor(img) for img in images ]
        tensor = torch.stack(tensors) # dimensions: (image count, 3, width, height)
        save_tensor(tensor, os.path.join("data", f"{batch}.pt"))
        print(f"saved batch {batch}")
    print("done")
def train():
    """Train the models."""
    # Create the model instances
    generation_model = GenerationModel().to(device)
    discrimination_model = DiscriminationModel().to(device)
    # Create the loss calculators (used by DCGAN only); target tensors are
    # cached per batch size to avoid re-allocating them
    ones_map = {}
    zeros_map = {}
    def loss_function_t(predicted):
        """Loss calculator (train the prediction towards 1)"""
        count = predicted.shape[0]
        ones = ones_map.get(count)
        if ones is None:
            ones = torch.ones((count, 1), device=device)
            ones_map[count] = ones
        return nn.functional.binary_cross_entropy(predicted, ones)
    def loss_function_f(predicted):
        """Loss calculator (train the prediction towards 0)"""
        count = predicted.shape[0]
        zeros = zeros_map.get(count)
        if zeros is None:
            zeros = torch.zeros((count, 1), device=device)
            zeros_map[count] = zeros
        return nn.functional.binary_cross_entropy(predicted, zeros)
    # Create the optimizers, one for each model.
    # Learning rates and betas follow the respective papers; they help
    # somewhat but are not decisive
    if MODEL_TYPE == "DCGAN":
        optimizer_g = torch.optim.Adam(generation_model.parameters(), lr=0.0002, betas=(0.5, 0.999))
        optimizer_d = torch.optim.Adam(discrimination_model.parameters(), lr=0.0002, betas=(0.5, 0.999))
    elif MODEL_TYPE == "WGAN":
        optimizer_g = torch.optim.RMSprop(generation_model.parameters(), lr=0.00005)
        optimizer_d = torch.optim.RMSprop(discrimination_model.parameters(), lr=0.00005)
    elif MODEL_TYPE == "WGAN-GP":
        optimizer_g = torch.optim.Adam(generation_model.parameters(), lr=0.0001, betas=(0.0, 0.999))
        optimizer_d = torch.optim.Adam(discrimination_model.parameters(), lr=0.0001, betas=(0.0, 0.999))
    else:
        raise ValueError("unknown model type")
    # Track the accuracy history of both models
    training_accuracy_g_history = []
    training_accuracy_d_history = []
    # Accuracy helper functions
    calc_accuracy_g = generation_model.calc_accuracy
    calc_accuracy_d = discrimination_model.calc_accuracy
    # Generate random latent codes
    def generate_vectors(batch_size):
        vectors = torch.randn((batch_size, GenerationModel.EmbeddedSize), device=device)
        return vectors
    # Write generated sample images to disk
    def output_generated_samples(epoch, samples):
        dir_path = f"./generated_samples/{epoch}"
        if not os.path.isdir(dir_path):
            os.makedirs(dir_path)
        for index, sample in enumerate(samples):
            path = os.path.join(dir_path, f"{index}.png")
            tensor_to_image(sample.cpu()).save(path)
    # Batch reader.
    # NOTE(review): the def line and the tensor-loading line of this generator
    # were lost from the extract; reconstructed — confirm against the original
    def read_batches():
        for batch in itertools.count():
            path = f"data/{batch}.pt"
            if not os.path.isfile(path):
                break
            x = load_tensor(path)
            yield x.to(device)
    # Start the training process
    validating_vectors = generate_vectors(100)
    for epoch in range(0, 10000):
        print(f"epoch: {epoch}")
        # Train against the training set and adjust parameters;
        # switch the models to training mode
        generation_model.train()
        discrimination_model.train()
        training_accuracy_g_list = []
        training_accuracy_d_list = []
        last_accuracy_g = 0
        last_accuracy_d = 0
        minibatch_size = 20
        train_discriminator_count = 0
        # NOTE(review): the enumeration over read_batches() was lost from the
        # extract; reconstructed so that index / batch_x are defined
        for index, batch_x in enumerate(read_batches()):
            # Train with minibatches
            training_batch_accuracy_g = 0.0
            training_batch_accuracy_d = 0.0
            minibatch_count = 0
            for begin in range(0, batch_x.shape[0], minibatch_size):
                # Check which side is currently behind and train that side.
                # The theoretical balance is accuracy_g = 1.0, accuracy_d = 0.5,
                # meaning generated images are indistinguishable from real ones,
                # but training should not be pushed that far
                training_vectors = generate_vectors(minibatch_size) # random latent codes
                generated = generation_model(training_vectors) # fake data from the codes
                real = batch_x[begin:begin+minibatch_size] # real data
                predicted_t = discrimination_model(real)
                predicted_f = discrimination_model(generated)
                accuracy_g = calc_accuracy_g(predicted_f)
                accuracy_d = calc_accuracy_d(predicted_f, predicted_t)
                train_discriminator = (accuracy_g / 2) >= accuracy_d
                if train_discriminator or train_discriminator_count > 0:
                    # Train the discriminator (only its parameters are adjusted).
                    # NOTE(review): no zero_grad appeared in the extract; without
                    # it gradients accumulate across minibatches — confirm
                    optimizer_d.zero_grad()
                    if MODEL_TYPE == "DCGAN":
                        loss_d = loss_function_f(predicted_f) + loss_function_t(predicted_t)
                    elif MODEL_TYPE == "WGAN":
                        loss_d = predicted_f.mean() - predicted_t.mean()
                    elif MODEL_TYPE == "WGAN-GP":
                        # NOTE(review): this expression was truncated in the
                        # extract; the gradient penalty term is reconstructed
                        loss_d = (predicted_f.mean() - predicted_t.mean() +
                            discrimination_model.gradient_penalty(real, generated))
                    else:
                        raise ValueError("unknown model type")
                    loss_d.backward()
                    optimizer_d.step()
                    # Clamp the discriminator's parameters to stop them running
                    # away (WGAN-GP has a better mechanism). The limit is larger
                    # than the paper's 0.01 because this model has more layers
                    # and parameters
                    if MODEL_TYPE == "WGAN":
                        for p in discrimination_model.parameters():
                            p.data.clamp_(-0.1, 0.1)
                    # Train the discriminator more often than the generator
                    if train_discriminator and train_discriminator_count == 0:
                        train_discriminator_count = 5
                    train_discriminator_count -= 1
                else:
                    # Train the generator (only its parameters are adjusted)
                    optimizer_g.zero_grad() # NOTE(review): reconstructed, see above
                    if MODEL_TYPE == "DCGAN":
                        loss_g = loss_function_t(predicted_f)
                    elif MODEL_TYPE in ("WGAN", "WGAN-GP"):
                        loss_g = -predicted_f.mean()
                    else:
                        raise ValueError("unknown model type")
                    loss_g.backward()
                    optimizer_g.step()
                training_batch_accuracy_g += accuracy_g
                training_batch_accuracy_d += accuracy_d
                minibatch_count += 1
            training_batch_accuracy_g /= minibatch_count
            training_batch_accuracy_d /= minibatch_count
            # Report the batch accuracy
            training_accuracy_g_list.append(training_batch_accuracy_g)
            training_accuracy_d_list.append(training_batch_accuracy_d)
            print(f"epoch: {epoch}, batch: {index},",
                f"accuracy_g: {training_batch_accuracy_g}, accuracy_d: {training_batch_accuracy_d}")
        training_accuracy_g = sum(training_accuracy_g_list) / len(training_accuracy_g_list)
        training_accuracy_d = sum(training_accuracy_d_list) / len(training_accuracy_d_list)
        training_accuracy_g_history.append(training_accuracy_g)
        training_accuracy_d_history.append(training_accuracy_d)
        print(f"training accuracy_g: {training_accuracy_g}, accuracy_d: {training_accuracy_d}")
        # Save fake samples to judge training progress
        output_generated_samples(epoch, generation_model(validating_vectors))
        # Save the model state periodically
        if (epoch + 1) % 10 == 0:
            save_tensor(generation_model.state_dict(), "model.generation.pt")
            save_tensor(discrimination_model.state_dict(), "model.discrimination.pt")
            if (epoch + 1) % 100 == 0:
                save_tensor(generation_model.state_dict(), f"model.generation.epoch_{epoch}.pt")
                save_tensor(discrimination_model.state_dict(), f"model.discrimination.epoch_{epoch}.pt")
            print("model saved")
    print("training finished")
    # Plot the training accuracy history
    pyplot.plot(training_accuracy_g_history, label="training_accuracy_g")
    pyplot.plot(training_accuracy_d_history, label="training_accuracy_d")
    pyplot.ylim(0, 1)
    pyplot.legend()
    pyplot.show()
from http.server import HTTPServer, BaseHTTPRequestHandler
from urllib.parse import parse_qs
from io import BytesIO
class RequestHandler(BaseHTTPRequestHandler):
    """A minimal server for testing image generation."""
    # Path of the model state; this checkpoint looked best
    MODEL_STATE_PATH = "model.generation.epoch_2999.pt"
    # Lazily created singleton model instance
    Model = None

    @staticmethod
    def get_model():
        if RequestHandler.Model is None:
            # Create the model, load the trained state, switch to eval mode.
            # NOTE(review): the state-loading line was lost from the extract;
            # reconstructed with load_tensor — confirm against the original
            model = GenerationModel().to(device)
            model.load_state_dict(load_tensor(RequestHandler.MODEL_STATE_PATH))
            model.eval()
            RequestHandler.Model = model
        return RequestHandler.Model

    def do_GET(self):
        parts = self.path.partition("?")
        if parts[0] == "/":
            # Serve the UI page.
            # NOTE(review): the header and body lines were lost from the
            # extract; reconstructed — confirm against the original
            self.send_response(200)
            self.send_header("Content-Type", "text/html")
            self.end_headers()
            with open("gan_eval.html", "rb") as f:
                self.wfile.write(f.read())
        elif parts[0] == "/generate":
            # Generate an image from the passed parameters
            params = parse_qs(parts[-1])
            vector = (torch.tensor([float(x) for x in params["values"][0].split(",")])
                .reshape(1, GenerationModel.EmbeddedSize)
                .to(device))
            generated = RequestHandler.get_model()(vector)[0]
            img = tensor_to_image(generated.cpu())
            bytes_io = BytesIO()
            img.save(bytes_io, format="PNG")
            # Return the image
            self.send_response(200)
            self.send_header("Content-Type", "image/png") # NOTE(review): reconstructed
            self.end_headers()
            self.wfile.write(bytes_io.getvalue())
        else:
            self.send_response(404)
            self.end_headers()
def eval_model():
    """Generate images with the trained model via a local web UI."""
    # Serve on localhost:8666 until interrupted, then shut down cleanly
    server = HTTPServer(("localhost", 8666), RequestHandler)
    try:
        server.serve_forever()
    except KeyboardInterrupt:
        pass
    server.server_close()
    exit()
def main():
    """Entry point: run the operation given on the command line."""
    if len(sys.argv) < 2:
        # No operation given
        exit()
    # Seed the RNGs so every run produces the same random numbers.
    # This makes the process reproducible; you may choose not to do this
    random.seed(0)
    torch.random.manual_seed(0)
    # Choose the operation from the command line argument
    operation = sys.argv[1]
    if operation == "prepare":
        prepare()
    elif operation == "train":
        train()
    elif operation == "eval":
        eval_model()
    else:
        raise ValueError(f"Unsupported operation: {operation}")

if __name__ == "__main__":
    main()

```python3 gan.py prepare
python3 gan.py train```

DCGAN

WGAN

WGAN-GP

WGAN-GP 训练到 3000 次以后输出的样本如下：

WGAN-GP 训练到 10000 次以后输出的样本如下：

# Create the model instance and switch to eval mode.
# NOTE(review): the extract bound the instance to `generation_model` but then
# used `model` (a NameError); one consistent name is used here, and the
# state-loading step appears to be missing — confirm against the original
model = GenerationModel().to(device)
model.eval()
# Generate 100 random faces
vector = torch.randn((100, GenerationModel.EmbeddedSize), device=device)
samples = model(vector)
for index, sample in enumerate(samples):
    img = tensor_to_image(sample.cpu())
    img.save(f"{index}.png")

```html
<!DOCTYPE html>
<html lang="cn">
<meta charset="utf-8">
<title>测试人脸生成</title>
<style>
html, body {
    width: 100%;
    height: 100%;
    margin: 0px;
}
.left-pane {
    width: 50%;
    height: 100%;
    border-right: 1px solid #000;
}
.right-pane {
    position: fixed;
    left: 70%;
    top: 35%;
    width: 25%;
}
.sliders {
}
.slider-container {
    display: inline-block;
    min-width: 25%;
}
#image {
    left: 25%;
    top: 25%;
    width: 50%;
    height: 50%;
}
</style>
<body>
<!-- NOTE(review): the class/id attributes below were lost from the extract and
     are reconstructed to match the selectors used by the script
     (.sliders, #target, .set-random); the #image CSS rule above matches no
     element — confirm against the original -->
<div class="left-pane">
    <div class="sliders">
    </div>
</div>
<div class="right-pane">
    <p><img id="target" src="data:image/png;base64," alt="image" /></p>
    <p><button class="set-random">随机生成</button></p>
</div>
</body>
<script>
(function() {
    // Handler invoked when any slider changes: request a newly generated image
    var onChanged = function() {
        var sliderInputs = document.querySelectorAll(".slider");
        var values = [];
        sliderInputs.forEach(function(s) {
            values.push(s.value);
        });
        var image = document.querySelector("#target");
        image.setAttribute("src", "/generate?values=" + values.join(","));
    };
    // Handler for the "random" button: randomize all sliders in -1 ~ 1
    var setRandomButton = document.querySelector(".set-random");
    setRandomButton.onclick = function() {
        var sliderInputs = document.querySelectorAll(".slider");
        sliderInputs.forEach(function(s) { s.value = Math.random() * 2 - 1; });
        onChanged();
    };
    // Build one slider per latent-code dimension (128 total)
    var sliders = document.querySelector(".sliders");
    for (var n = 0; n < 128; ++n) {
        var container = document.createElement("div");
        container.setAttribute("class", "slider-container");
        var span = document.createElement("span");
        span.innerText = n;
        container.appendChild(span);
        var slider = document.createElement("input");
        slider.setAttribute("type", "range")
        slider.setAttribute("class", "slider");
        slider.setAttribute("min", "-1");
        slider.setAttribute("max", "1");
        slider.setAttribute("step", "0.01");
        slider.value = 0;
        slider.onchange = onChanged;
        slider.oninput = onChanged;
        container.appendChild(slider);
        sliders.appendChild(container);
    }
})();
</script>
</html>
```

`python3 gan.py eval`