## 前言

BN (Batch Normalization) is by now an almost mandatory component of neural networks, but using it comes with two prerequisites:

each minibatch must follow the same distribution as the full dataset;

the batch size must not be too small.

Two papers that tackle the small-batch problem:

BRN: Batch Renormalization: Towards Reducing Minibatch Dependence in Batch-Normalized Models

CBN: Cross-Iteration Batch Normalization

## BN Review

Each minibatch must follow the same distribution as the full dataset. Minibatches are sampled uniformly from the whole dataset during training; if the distributions differ, the minibatch mean and variance can deviate substantially from the population mean and variance, which severely hurts accuracy at test time.

The batch size must not be too small, otherwise quality degrades; the paper's rule-of-thumb lower bound is 32.
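The small-batch problem comes down to how noisy minibatch statistics get as the batch shrinks. A pure-Python sketch (the synthetic population and batch sizes are illustrative assumptions, not from the paper):

```python
import random
import statistics

random.seed(0)
# Stand-in for a training set: 100k samples from N(0, 1)
population = [random.gauss(0.0, 1.0) for _ in range(100_000)]

def mean_estimate_spread(batch_size, trials=200):
    """Std-dev of the minibatch mean across many independently sampled batches."""
    means = [statistics.fmean(random.sample(population, batch_size))
             for _ in range(trials)]
    return statistics.stdev(means)

small = mean_estimate_spread(4)    # tiny batch: very noisy statistics
large = mean_estimate_spread(64)   # larger batch: far more stable
```

The spread shrinks roughly as 1/sqrt(batch_size), which is why BN's batch statistics become unreliable well below the suggested batch size of 32.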

## BRN

#### Code walkthrough:

```python
def forward(self, x):
    if x.dim() > 2:
        x = x.transpose(1, -1)
    if self.training:  # training mode
        dims = [i for i in range(x.dim() - 1)]
        batch_mean = x.mean(dims)  # batch mean
        batch_std = x.std(dims, unbiased=False) + self.eps  # batch standard deviation
        # compute r and d according to the paper's formulas
        r = (batch_std.detach() / self.running_std.view_as(batch_std)).clamp_(1 / self.rmax, self.rmax)
        d = ((batch_mean.detach() - self.running_mean.view_as(batch_mean))
             / self.running_std.view_as(batch_std)).clamp_(-self.dmax, self.dmax)
        # standardize the batch, then apply the renorm correction
        x = (x - batch_mean) / batch_std * r + d
        # track the global mean and standard deviation with a moving average
        self.running_mean += self.momentum * (batch_mean.detach() - self.running_mean)
        self.running_std += self.momentum * (batch_std.detach() - self.running_std)
        self.num_batches_tracked += 1
    else:  # inference mode
        x = (x - self.running_mean) / self.running_std
    return x
```
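The r and d factors can be isolated in plain Python to see what the clamping does; a minimal scalar sketch (the rmax/dmax defaults mirror common settings and are assumptions here):

```python
def renorm_correction(batch_mean, batch_std, running_mean, running_std,
                      rmax=3.0, dmax=5.0):
    """BRN's correction factors. Both are treated as constants during
    backprop, which the PyTorch code above expresses with .detach()."""
    clamp = lambda v, lo, hi: max(lo, min(hi, v))
    r = clamp(batch_std / running_std, 1.0 / rmax, rmax)
    d = clamp((batch_mean - running_mean) / running_std, -dmax, dmax)
    return r, d

# When batch statistics are close to the running statistics, r ≈ 1 and d ≈ 0:
r, d = renorm_correction(batch_mean=0.02, batch_std=1.05,
                         running_mean=0.0, running_std=1.0)

# When they diverge wildly, the clamps cap the correction:
r2, d2 = renorm_correction(batch_mean=10.0, batch_std=8.0,
                           running_mean=0.0, running_std=1.0)
```

Note that forcing rmax=1 and dmax=0 pins r=1, d=0, i.e. plain BN; the paper relaxes these bounds gradually during training.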

## CBN

#### Core walkthrough:

```python
cur_mu = y.mean(dim=1)  # mean of the current layer
cur_meanx2 = torch.pow(y, 2).mean(dim=1)  # mean of the squared values, used to derive the variance
cur_sigma2 = y.var(dim=1)  # variance of the current values
```
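Collecting E[x] and E[x^2] (rather than the variance directly) works because of the identity Var(x) = E[x^2] − E[x]^2, and it makes the per-iteration statistics easy to aggregate across the window later. A quick stdlib check of the identity:

```python
import statistics

xs = [0.5, 1.5, 2.0, 4.0]  # arbitrary sample values
mean = statistics.fmean(xs)
meanx2 = statistics.fmean(x * x for x in xs)

var_identity = meanx2 - mean ** 2        # E[x^2] - E[x]^2
var_direct = statistics.pvariance(xs)    # population variance, computed directly
```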

```python
# Note grad_outputs=self.ones: it weights every output element's gradient equally,
# so this computes the gradient of the sum of the outputs, similar in effect to
# differentiating torch.sum() of them.
dmudw = torch.autograd.grad(cur_mu, weight, self.ones, retain_graph=True)[0]
dmeanx2dw = torch.autograd.grad(cur_meanx2, weight, self.ones, retain_graph=True)[0]
```

```python
# Estimate the statistics of previous iterations under the current weights
# via a first-order Taylor expansion
mu_all = torch.stack(
    [cur_mu, ] + [tmp_mu + (self.rho * tmp_d * (weight.data - tmp_w)).sum(1).sum(1).sum(1)
                  for tmp_mu, tmp_d, tmp_w in zip(self.pre_mu, self.pre_dmudw, self.pre_weight)])
meanx2_all = torch.stack(
    [cur_meanx2, ] + [tmp_meanx2 + (self.rho * tmp_d * (weight.data - tmp_w)).sum(1).sum(1).sum(1)
                      for tmp_meanx2, tmp_d, tmp_w in zip(self.pre_meanx2, self.pre_dmeanx2dw, self.pre_weight)])
```
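The compensation above is a plain first-order Taylor expansion: a statistic computed under old weights w0 is carried over to the current weights w as mu(w) ≈ mu(w0) + dmu/dw · (w − w0). A scalar sketch of the idea (the toy statistic is made up for illustration):

```python
def taylor_estimate(stat_prev, grad_prev, w_prev, w_now):
    """First-order estimate of a past statistic under the current weight:
    mu(w) ≈ mu(w0) + dmu/dw(w0) * (w - w0)."""
    return stat_prev + grad_prev * (w_now - w_prev)

# Toy statistic mu(w) = w**2, so dmu/dw = 2*w.
w0, w1 = 1.0, 1.1
exact = w1 ** 2                                    # statistic recomputed at w1
approx = taylor_estimate(w0 ** 2, 2 * w0, w0, w1)  # estimated from w0's stats
```

In CBN the same idea runs per channel: the `.sum(1).sum(1).sum(1)` contracts the gradient-times-weight-difference product over the weight dimensions, and rho scales the correction.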

```python
# Maintain rolling buffers of length buffer_num holding the means, squared means,
# partial derivatives, and weights of recent iterations
self.pre_mu = [cur_mu.detach(), ] + self.pre_mu[:(self.buffer_num - 1)]
self.pre_meanx2 = [cur_meanx2.detach(), ] + self.pre_meanx2[:(self.buffer_num - 1)]
self.pre_dmudw = [dmudw.detach(), ] + self.pre_dmudw[:(self.buffer_num - 1)]
self.pre_dmeanx2dw = [dmeanx2dw.detach(), ] + self.pre_dmeanx2dw[:(self.buffer_num - 1)]
tmp_weight = torch.zeros_like(weight.data)
tmp_weight.copy_(weight.data)
self.pre_weight = [tmp_weight.detach(), ] + self.pre_weight[:(self.buffer_num - 1)]
```

```python
# Compute the final mean and variance from the collected window of
# means and squared means
sigma2_all = meanx2_all - torch.pow(mu_all, 2)
re_mu_all = mu_all.clone()
re_meanx2_all = meanx2_all.clone()
# Mask out entries whose estimated variance is negative
# (the Taylor estimates are approximations and can be inconsistent)
re_mu_all[sigma2_all < 0] = 0
re_meanx2_all[sigma2_all < 0] = 0
count = (sigma2_all >= 0).sum(dim=0).float()
mu = re_mu_all.sum(dim=0) / count  # average over the valid entries
sigma2 = re_meanx2_all.sum(dim=0) / count - torch.pow(mu, 2)
```
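The masking step can be mirrored in plain Python for a single channel: iterations whose Taylor-estimated statistics imply a negative variance are simply excluded from the window average (function and variable names here are illustrative, not from the CBN code):

```python
def window_stats(mus, meanx2s):
    """Average per-iteration mean and E[x^2] over a window, skipping
    entries whose implied variance E[x^2] - mu^2 is negative."""
    valid = [(m, m2) for m, m2 in zip(mus, meanx2s) if m2 - m * m >= 0]
    count = len(valid)
    mu = sum(m for m, _ in valid) / count
    sigma2 = sum(m2 for _, m2 in valid) / count - mu * mu
    return mu, sigma2

# The second iteration implies variance 0.5 - 1.0 < 0, so it is dropped:
mu, sigma2 = window_stats([1.0, 1.0], [2.0, 0.5])
```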

```python
# Normalization step, identical to plain BN
y = y - mu.view(-1, 1)
if self.out_p:  # out_p only controls where eps enters relative to the square root
    y = y / (sigma2.view(-1, 1) + self.eps) ** .5
else:
    y = y / (sigma2.view(-1, 1) ** .5 + self.eps)
```

#### A final pass at the intuition:

Let mu_0 be the mean computed on the current batch and mu_1 the mean computed on the previous batch. When normalizing the current batch we would like to reuse mu_1 as well, but mu_1 was computed under the previous iteration's network weights, so using it directly would be wrong. The paper therefore uses a Taylor expansion to estimate what mu_1 would look like under the current weights. The variance estimate is handled the same way.
