### 批标准化算法

#### 全连接层

```mean = torch.mean(X, axis=0)
variance = torch.mean((X-mean)**2, axis=0)
X_hat = (X-mean) * 1.0 /torch.sqrt(variance + eps)
out = gamma * X_hat + beta```

#### 卷积层

```N, C, H, W = X.shape
mean = torch.mean(X, axis = (0, 2, 3))
variance = torch.mean((X - mean.reshape((1, C, 1, 1))) ** 2, axis=(0, 2, 3))
X_hat = (X - mean.reshape((1, C, 1, 1))) * 1.0 / torch.sqrt(variance.reshape((1, C, 1, 1)) + eps)
out = gamma.reshape((1, C, 1, 1)) * X_hat + beta.reshape((1, C, 1, 1))```

### 完整的 CustomBatchNorm 模块

```class CustomBatchNorm(nn.Module):
def __init__(self, in_size, momentum=0.9, eps = 1e-5):
super(CustomBatchNorm, self).__init__()

self.momentum = momentum
self.insize = in_size
self.eps = eps

U = uniform.Uniform(torch.tensor([0.0]), torch.tensor([1.0]))
self.gamma = nn.Parameter(U.sample(torch.Size([self.insize])).view(self.insize))
self.beta = nn.Parameter(torch.zeros(self.insize))

self.register_buffer('running_mean', torch.zeros(self.insize))
self.register_buffer('running_var', torch.ones(self.insize))

self.running_mean.zero_()
self.running_var.fill_(1)
def forward(self, input):

X = input
if len(X.shape) not in (2, 4):
raise ValueError("only support dense or 2dconv")

#全连接层
elif len(X.shape) == 2:
if self.training:
mean = torch.mean(X, axis=0)
variance = torch.mean((X-mean)**2, axis=0)

self.running_mean = (self.momentum * self.running_mean) + (1.0-self.momentum) * mean
self.running_var = (self.momentum * self.running_var) + (1.0-self.momentum) * (input.shape[0]/(input.shape[0]-1)*variance)

else:
mean = self.running_mean
variance = self.running_var

X_hat = (X-mean) * 1.0 /torch.sqrt(variance + self.eps)
out = self.gamma * X_hat + self.beta
# 卷积层
elif len(X.shape) == 4:
if self.training:
N, C, H, W = X.shape
mean = torch.mean(X, axis = (0, 2, 3))
variance = torch.mean((X - mean.reshape((1, C, 1, 1))) ** 2, axis=(0, 2, 3))

self.running_mean = (self.momentum * self.running_mean) + (1.0-self.momentum) * mean
self.running_var = (self.momentum * self.running_var) + (1.0-self.momentum) * (input.shape[0]/(input.shape[0]-1)*variance)
else:
mean = self.running_mean
var = self.running_var

X_hat = (X - mean.reshape((1, C, 1, 1))) * 1.0 / torch.sqrt(variance.reshape((1, C, 1, 1)) + self.eps)
out = self.gamma.reshape((1, C, 1, 1)) * X_hat + self.beta.reshape((1, C, 1, 1))

return out```

### 实验MNIST

```class SimpleNet(nn.Module):
def __init__(self):
super(SimpleNet, self).__init__()
self.classifier = nn.Sequential(
nn.Linear(28 * 28, 64),
nn.ReLU(),
nn.Linear(64, 128),
nn.ReLU(),
nn.Linear(128, 10)
)

def forward(self, x):
x = x.view(x.size(0), -1)
x = self.classifier(x)
return x```

```class SimpleNetBN(nn.Module):
def __init__(self):
super(SimpleNetBN, self).__init__()
self.classifier = nn.Sequential(
nn.Linear(28 * 28, 64),
CustomBatchNorm(64),
nn.ReLU(),
nn.Linear(64, 128),
CustomBatchNorm(128),
nn.ReLU(),
nn.Linear(128, 10)
)

def forward(self, x):
x = x.view(x.size(0), -1)
x = self.classifier(x)
return x```