Press "Enter" to skip to content

Batchsize不够大,如何发挥BN性能?探讨神经网络在小Batch下的训练方法

作者丨皮特潘

前言

 

BN(Batch Normalization)几乎是目前神经网络的必选组件,但是使用BN有两个前提要求:

 

 

batchsize不能太小;

 

每一个minibatch和整体数据集同分布。

 

 

不然的话,非但不能发挥BN的优势,甚至会适得其反。但是由于算力的限制,有时我们无法使用足够大的batchsize,此时该如何使用BN呢?本文介绍两篇在小batchsize也可以发挥BN性能的方法。解决思路为:既然batchsize太小的情况下,无法保证当前minibatch收集到的数据和整体数据同分布。那幺能否多收集几个batch的数据进行统计呢?这两篇工作分别分别是:

 

BRN:Batch Renormalization: Towards Reducing Minibatch Dependence in Batch-Normalized Models

 

CBN:Cross-Iteration Batch Normalization

 

另外,本文也会给出代码解析,帮助大家理解。

 

batchsize过小的场景

 

通常情况下,大家对CNN任务的研究一般为公开的数据集指标负责。分类任务为ImageNet数据集负责,其尺度为224X224。检测任务为coco数据集负责,其尺度为640X640左右。分割任务一般为coco或PASCAL VOC数据集负责,后者的尺度大概在500X500左右。再加上例如resize的前处理操作,真正送入网络的图片的分辨率都不算太大。一般性能的GPU也很容易实现大的batchsize(例如大于32)的支持。

 

但是实际的项目中,经常遇到需要处理的图片尺度过大的场景,例如我们使用500w像素甚至2000w像素的工业相机进行数据采集,500w的相机采集的图片尺度就是2500X2000左右。而对于微小的缺陷检测、高精度的关键点检测或小物体的目标检测等任务,我们一般不太想粗暴降低输入图片的分辨率,这样违背了我们使用高分辨率相机的初衷,也可能导致丢失有用特征。在算力有限的情况下,我们的batchsize就无法设置太大,甚至只能为1或2。小的batchsize会带来很多训练上的问题,其中BN问题就是最突出的。虽然大batchsize训练是一个共识,但是现实中可能无法具有充足的资源,因此我们需要一些处理手段。

 

BN回顾

首先Batch Normalization 中的Normalization被称为标准化,通过将数据进行平和缩放拉到一个特定的分布。BN就是在batch维度上进行数据的标准化。BN的引入是用来解决 internal covariate shift 问题,即训练迭代中网络激活的分布的变化对网络训练带来的破坏。BN通过在每次训练迭代的时候,利用minibatch计算出的当前batch的均值和方差,进行标准化来缓解这个问题。虽然How Does Batch Normalization Help Optimization 这篇文章探究了BN其实和Internal Covariate Shift (ICS)问题关系不大,本文不深入讨论,这个会在以后的文章中细说。

 

一般来说,BN有两个优点:

 

降低对初始化、学习率等超参的敏感程度,因为每层的输入被BN拉成相对稳定的分布,也能加速收敛过程。

 

应对梯度饱和和梯度弥散,主要是对于使用sigmoid和tanh的激活函数的网络。

 

当然,BN的使用也有两个前提:

 

minibatch和全部数据同分布。因为训练过程每个minibatch从整体数据中均匀采样,不同分布的话minibatch的均值和方差和训练样本整体的均值和方差是会存在较大差异的,在测试的时候会严重影响精度。

 

batchsize不能太小,否则效果会较差,论文给的一般性下限是32。

 

再来回顾一下BN的具体做法:

训练的时候:使用当前batch统计的均值和方差对数据进行标准化,同时优化优化gamma和beta两个参数。另外利用指数滑动平均收集全局的均值和方差。

 

测试的时候:使用训练时收集全局均值和方差以及优化好的gamma和beta进行推理。

 

可以看出,要想BN真正work,就要保证训练时当前batch的均值和方差逼近全部数据的均值和方差。

 

BRN

 

论文题目:Batch Renormalization: Towards Reducing Minibatch Dependence in Batch-Normalized Models

 

论文地址:https://arxiv.org/pdf/1702.03275.pdf

 

代码地址:https://github.com/ludvb/batchrenorm

 

核心解析:

 

本文的核心思想就是:训练过程中,由于batchsize较小,当前minibatch统计到的均值和方差与全部数据有差异,那幺就对当前的均值和方差进行修正。修正的方法主要是利用到通过滑动平均收集到的全局均值和标准差。看公式:

 

上面公式中,i表示网络的第i层。μ和σ表示网络推理时的均值和标准差,也就是训练过程通过滑动平均收集的到均值和方差。μB和σb表示当前训练迭代过程中的实际统计到的均值和标准差。BN在小batch不work的根本原因就是这两组参数存在较大的差异。通过r和d对训练过程中数据进行线性变换,在该变化下,上公式左右两端就严格相等了。其实标准的BN就是r=1,d=0的一种情况。对于某一个特定的minibatch,其中r和d可以看成是固定的,是直接计算出来的,不需要梯度优化的。

 

具体流程:

统计当前batch数据的均值和标注差,和标准BN做法一致。

 

根据当前batch的均值和标准差结合全局的均值和标准差利用上面的公式计算r和d;注意该运算是不参与梯度反向传播的。另外,r和d需要增加一个限制,直接clip操作就好。

 

利用当前的均值和标准差对当前数据执行Normalization操作,利用上面计算得到的r和d对当前batch进行线性变换。

 

滑动平均收集全局均值和标注差。

 

测试过程和标准BN一样。其实本质上,就是训练的过程中使用全局的信息进行更新当前batch的数据。间接利用了全局的信息,而非当前这一个batch的信息。

 

实验效果:

 

在较大的batchsize(32)的时候,与标准BN相比,不会丢失效果,训练过程一如既往稳定高效。如下:

在小的batchsize(4)下, 本文做法依然接近batchsize为32的时候,可见在小batchsize下是work的。

代码解析:

 

def forward(self, x):
    if x.dim() > 2:
      x = x.transpose(1, -1)
    if self.training: # 训练过程
      dims = [i for i in range(x.dim() - 1)
      batch_mean = x.mean(dims) # 计算均值
      batch_std = x.std(dims, unbiased=False) + self.eps # 计算标准差
      # 按照公式计算r和d
      r = (batch_std.detach() / self.running_std.view_as(batch_std)).clamp_(1 / self.rmax, self.rmax)
      d = ((batch_mean.detach() - self.running_mean.view_as(batch_mean))
            / self.running_std.view_as(batch_std)).clamp_(-self.dmax, self.dmax)
      # 对当前数据进行标准化和线性变换
      x = (x - batch_mean) / batch_std * r + d
      # 滑动平均收集全局均值和标注差
      self.running_mean += self.momentum * (batch_mean.detach() - self.running_mean)
      self.running_std += self.momentum * (batch_std.detach() - self.running_std)
      self.num_batches_tracked += 1
    else: # 测试过程
    x = (x - self.running_mean) / self.running_std
  return x


CBN

 

论文题目:Cross-Iteration Batch Normalization

 

论文地址:https://arxiv.org/abs/2002.05712

 

代码地址:https://github.com/Howal/Cross-iterationBatchNorm

本文认为BRN的问题在于它使用的全局均值和标准差不是当前网络权重下获取的,因此不是exactly正确的,所以batchsize再小一点,例如为1或2时就不太work了。本文使用泰勒多项式逼近原理来修正当前的均值和标准差,同样也是间接利用了全局的均值和方差信息。简述就是:当前batch的均值和方差来自之前的K次迭代均值和方差的平均,由于网络权重一直在更新,所以不能直接粗暴求平均。本文而是利用泰勒公式估计前面的迭代在当前权重下的数值。

 

泰勒公式:

 

泰勒公式是 一 个用函数在某点的信息描述其附近取值的公式。如果函数满足 一 定的条件,泰勒公式可以用函数在某 一 点的各阶导数值做系数构建 一 个多项式来近似表达这个函数。教科书介绍如下:

核心解析:

 

本文做法,由于网络一般使用SGD更新权重,因此网络权重的变化是平滑的,所以适用泰勒公式。如下,t为训练过程中当前迭代时刻,t-τ为t时刻向前τ时刻。θ为网络权重,权重下标代表该权重的时刻。μ为当前minibatch均值,v为当强minibatch平方的均值,是为了计算标准差。因此直接套用泰勒公式得到:

 

上面这两个公式就是为了估计在t-τ时刻,t时刻的权重下的均值和方差的参数估计。BRN可以看作没有进行该方法估计,使用的依然是t-τ时刻权重的参数估计。其中O为高阶项,因为该式主要由一阶项控制,因此高阶项目可以忽略。上面的公式还要进一步简化,主要是偏导项的求法。假设当前层为l,实际上∂μ/ ∂θ 和 ∂ν/∂θ依赖与所有l层之前层的权重,求导计算量极大。不过经验观察到,l层之前层的偏数下降很快,因此可以忽略掉,仅仅计算当前层的权重偏导。

因此化简为如下,可以看出,求偏导的部分,只考虑对当前层的偏导数,注意上标l表示网络层的意思。至此,之前时刻在当前权重下的均值和方差已经估计出来了。

 

下面穿插代码解析整个计算过程。

 

首先是统计计算当前batch的数据,和标准BN没有差别。代码为:

 

cur_mu = y.mean(dim=1)  # 当前层的均值
cur_meanx2 = torch.pow(y, 2).mean(dim=1)  # 当前值平方的均值,计算标准差使用
cur_sigma2 = y.var(dim=1)  # 当前值的方差

 

对当前网络层求偏导,直接使用torch的内置函数。代码:

 

# 注意 grad_outputs = self.ones : 不同值的梯度对结果影响程度不同,类似torch.sum()的作用。
dmudw = torch.autograd.grad(cur_mu, weight, self.ones, retain_graph=True)[0]
dmeanx2dw = torch.autograd.grad(cur_meanx2, weight, self.ones, retain_graph=True)[0]

 

使用公式(7)和(8)继续下面的计算,也就是向前累计K次估计数值,更新到当前batch的均值和方差的计算上,这里引入了一个超参就是k的大小,它表示当前的迭代向后回溯到多长的步长的迭代。实验探究k=8是一个比较折中的选择。k=1的时候,RBN退化成了原始的BN:

 

代码如下,其中这里的self.pre_mu, self.pre_dmudw, self.pre_weight是前面每次迭代收集到了窗口k大小的数值,分别代表均值、均值对权重的偏导、权重。self.pre_meanx2, self.pre_dmeanx2dw, self.pre_weight同理,是对应平方均值的。

 

# 利用泰勒公式估计
mu_all = torch.stack \
  ([cur_mu, ] + [tmp_mu + (self.rho * tmp_d * (weight.data - tmp_w)).sum(1).sum(1).sum(1) for tmp_mu, tmp_d, tmp_w in zip(self.pre_mu, self.pre_dmudw, self.pre_weight)])


meanx2_all = torch.stack \
  ([cur_meanx2, ] + [tmp_meanx2 + (self.rho * tmp_d * (weight.data - tmp_w)).sum(1).sum(1).sum(1) for tmp_meanx2, tmp_d, tmp_w in zip(self.pre_meanx2, self.pre_dmeanx2dw, self.pre_weight)])

 

上面所说的变量收集迭代过程如下:

 

# 动态维护buffer_num长度的均值、均值平方、偏导、权重
self.pre_mu = [cur_mu.detach(), ] + self.pre_mu[:(self.buffer_num - 1)]
self.pre_meanx2 = [cur_meanx2.detach(), ] + self.pre_meanx2[:(self.buffer_num - 1)]
self.pre_dmudw = [dmudw.detach(), ] + self.pre_dmudw[:(self.buffer_num - 1)]
self.pre_dmeanx2dw = [dmeanx2dw.detach(), ] + self.pre_dmeanx2dw[:(self.buffer_num - 1)]
tmp_weight = torch.zeros_like(weight.data)
tmp_weight.copy_(weight.data)
self.pre_weight = [tmp_weight.detach(), ] + self.pre_weight[:(self.buffer_num - 1)]

 

计算获取当前batch的均值和方差,取修正后的K次迭代数据的平均即可。

 

# 利用收集到的一定窗口长度的均值和平方均值,计算当前均值和方差
sigma2_all = meanx2_all - torch.pow(mu_all, 2)
re_mu_all = mu_all.clone()
re_meanx2_all = meanx2_all.clone()
re_mu_all[sigma2_all < 0] = 0
re_meanx2_all[sigma2_all < 0] = 0
count = (sigma2_all >= 0).sum(dim=0).float()
mu = re_mu_all.sum(dim=0) / count # 平均操作
sigma2 = re_meanx2_all.sum(dim=0) / count - torch.pow(mu, 2)

 

均值和方差使用过程,和标准BN没有区别。

 

# 标准化过程,和原始BN没有区别
y = y - mu.view(-1, 1)
if self.out_p:  # 仅仅控制开平方的位置
  y = y / (sigma2.view(-1, 1) + self.eps) ** .5
else:
  y = y / (sigma2.view(-1, 1) ** .5 + self.eps)

 

最后再理解一下:

 

mu_0是当前batch统计获取的均值,mu_1是上一batch统计获取的均值。当前batch计算BN的时候也想利用到mu_1,但是统计mu_1的时候利用到网络的权重也是上一次的,直接使用肯定有问题,所以本文使用泰勒公式估计出mu_1在当前权重下应该是什幺样子。方差估计同理。

 

实验效果:

 

这里的Naive CBN 是上一篇论文BRN的做法,可以认为是CBN不使用泰勒估计的一种特例。在batchsize下降的过程中,CBN指标依然坚挺,甚至超过了GN(不过也侧面反应了GN确实厉害)。而原始BN和其改进版BRN在batchsize更小的时候都不太work了。

◎ 作者档案

 

皮特潘,致力于AI落地而上下求索

 

Be First to Comment

发表回复

您的电子邮箱地址不会被公开。 必填项已用*标注