Press "Enter" to skip to content

InstanceNorm 与 BatchNorm 的联系

InstanceNorm 训练与预测阶段行为一致，都是利用当前输入中每个样本、每个通道自身的均值和方差计算（而不是整个 batch 的统计量）

BatchNorm 训练阶段利用当前 batch 的均值和方差，测试阶段则利用训练阶段通过移动平均统计的均值和方差

https://arxiv.org/pdf/1803.08494.pdf

梯度推导过程详解

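loss 函数的符号约定为：记 loss 为 $L$，输入为 $x$，输出为 $y$；对每个样本的每个通道，令 $M = H \times W$，$\mu = \frac{1}{M}\sum_{i=1}^{M} x_i$，$\sigma^2 = \frac{1}{M}\sum_{i=1}^{M}(x_i-\mu)^2$，$\hat{x}_i = \frac{x_i-\mu}{\sqrt{\sigma^2+\epsilon}}$，$y_i = \gamma\,\hat{x}_i + \beta$。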

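gamma 和 beta 参数梯度的推导

$\frac{\partial L}{\partial \beta} = \sum_{n}\sum_{i=1}^{M} \frac{\partial L}{\partial y_{n,i}}$，$\frac{\partial L}{\partial \gamma} = \sum_{n}\sum_{i=1}^{M} \frac{\partial L}{\partial y_{n,i}}\,\hat{x}_{n,i}$，其中 $\hat{x}$ 为归一化后的输入，求和遍历 batch 维和空间维（每个通道各自求和）。

输入梯度（下文代码注释中引用的三项之和）：$\frac{\partial L}{\partial x_i} = \frac{\partial L}{\partial y_i}\cdot\frac{\gamma}{\sqrt{\sigma^2+\epsilon}} + \frac{1}{M}\frac{\partial L}{\partial \mu} + \frac{2(x_i-\mu)}{M}\frac{\partial L}{\partial \sigma^2}$，其中 $M = H \times W$，$\mu$、$\sigma^2$ 为每个样本每个通道的均值和方差，$\frac{\partial L}{\partial \sigma^2} = \sum_{i=1}^{M}\frac{\partial L}{\partial y_i}\,\gamma\,(x_i-\mu)\cdot\left(-\frac{1}{2}\right)(\sigma^2+\epsilon)^{-3/2}$，$\frac{\partial L}{\partial \mu} = -\frac{1}{\sqrt{\sigma^2+\epsilon}}\sum_{i=1}^{M}\frac{\partial L}{\partial y_i}\,\gamma + \frac{\partial L}{\partial \sigma^2}\cdot\frac{1}{M}\sum_{i=1}^{M}\left(-2(x_i-\mu)\right)$。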

深度学习框架实现代码解读

PyTorch 前向传播实现

https://github.com/pytorch/pytorch/blob/fa153184c8f70259337777a1fd1d803c7325f758/aten%2Fsrc%2FATen%2Fnative%2FNormalization.cpp#L506

// Excerpt of PyTorch's instance_norm forward pass (parts are elided with
// "......"). The input is viewed as a single sample with b*c channels and the
// work is then handed to batch_norm, so statistics are computed separately for
// every (sample, channel) pair.
//
// Parameters:
//   input            - input tensor; the comments below treat it as (b, c, h, w).
//   weight, bias     - optional affine parameters.
//   running_mean/var - optional running statistics.
//   use_input_stats, momentum, eps, cudnn_enabled - forwarded unchanged to
//                      at::batch_norm.
// Returns: the normalized tensor, viewed back to the original input sizes.
Tensor instance_norm(
    const Tensor& input,
    const Tensor& weight /* optional */,
    const Tensor& bias /* optional */,
    const Tensor& running_mean /* optional */,
    const Tensor& running_var /* optional */,
    bool use_input_stats,
    double momentum,
    double eps,
    bool cudnn_enabled) {
  // ...... (code elided in this excerpt)
  std::vector<int64_t> shape = input.sizes().vec();
  int64_t b = input.size(0);
  int64_t c = input.size(1);
  // The shape changes from (b, c, h, w) to (1, b*c, h, w).
  shape[1] = b * c;
  shape[0] = 1;

  // repeat_if_defined is explained later in this article: each defined
  // per-channel tensor is repeated b times so it matches the b*c channels.
  Tensor weight_ = repeat_if_defined(weight, b);
  Tensor bias_ = repeat_if_defined(bias, b);
  Tensor running_mean_ = repeat_if_defined(running_mean, b);
  Tensor running_var_ = repeat_if_defined(running_var, b);

  // Change the shape of the input tensor.
  auto input_reshaped = input.contiguous().view(shape);

  // The computation actually calls the batch_norm implementation, which
  // explains why the PyTorch front-end InstanceNorm2d has the same interface
  // as BatchNorm2d.
  auto out = at::batch_norm(
      input_reshaped,
      weight_,
      bias_,
      running_mean_,
      running_var_,
      use_input_stats,
      momentum,
      eps,
      cudnn_enabled);

  // ...... (code elided in this excerpt)
  return out.view(input.sizes());
}

repeat_if_defined 的代码：

https://github.com/pytorch/pytorch/blob/fa153184c8f70259337777a1fd1d803c7325f758/aten%2Fsrc%2FATen%2Fnative%2FNormalization.cpp#L27

// Returns `t` repeated `repeat` times when `t` is defined; an undefined
// (absent optional) tensor is returned unchanged.
static inline Tensor repeat_if_defined(const Tensor& t, int64_t repeat) {
  if (t.defined()) {
    // Copy the tensor `repeat` times along dimension 0.
    return t.repeat(repeat);
  }
  return t;
}

MXNet 反向传播实现

https://github.com/apache/incubator-mxnet/blob/4a7282f104590023d846f505527fd0d490b65509/src%2Foperator%2Finstance_norm-inl.h#L112

`template<typename xpu>void InstanceNormBackward(const nnvm::NodeAttrs& attrs,const OpContext &ctx,const std::vector<TBlob> &inputs,const std::vector<OpReqType> &req,const std::vector<TBlob> &outputs) {using namespace mshadow;using namespace mshadow::expr;// ......const InstanceNormParam& param =nnvm::get<InstanceNormParam>(attrs.parsed);Stream<xpu> *s =ctx.get_stream<xpu>();// 获取输入张量的形状mxnet::TShape dshape =inputs[3].shape_;// ......int n = inputs[3].size(0);int c = inputs[3].size(1);// rest_dim 就等于上文的 Mint rest_dim =static_cast<int>(inputs[3].Size() / n / c);Shape<2> s2 = Shape2(n * c, rest_dim);Shape<3> s3 = Shape3(n, c, rest_dim);// scale 就等于上文的 1/Mconst real_t scale =static_cast<real_t>(1) /static_cast<real_t>(rest_dim);// 获取输入张量Tensor<xpu, 2> data = inputs[3].get_with_shape<xpu, 2, real_t>(s2, s);// 保存输入梯度Tensor<xpu, 2> gdata = outputs[kData].get_with_shape<xpu, 2, real_t>(s2, s);// 获取参数 gammaTensor<xpu, 1> gamma =inputs[4].get<xpu, 1, real_t>(s);// 保存参数 gamma 梯度计算结果Tensor<xpu, 1> ggamma = outputs[kGamma].get<xpu, 1, real_t>(s);// 保存参数 beta 梯度计算结果Tensor<xpu, 1> gbeta = outputs[kBeta].get<xpu, 1, real_t>(s);// 获取输出梯度Tensor<xpu, 2> gout = inputs[0].get_with_shape<xpu, 2, real_t>(s2, s);// 获取前向计算好的均值和方差Tensor<xpu, 1> var =inputs[2].FlatTo1D<xpu, real_t>(s);Tensor<xpu, 1> mean =inputs[1].FlatTo1D<xpu, real_t>(s);// 临时空间Tensor<xpu, 2> workspace = //.....// 保存均值的梯度Tensor<xpu, 1> gmean = workspace[0];// 保存方差的梯度Tensor<xpu, 1> gvar = workspace[1];Tensor<xpu, 1> tmp = workspace[2];// 计算方差的梯度，// 对应上文输入梯度公式的第三项// gout 对应输出梯度gvar = sumall_except_dim<0>((gout * broadcast<0>(reshape(repmat(gamma, n),Shape1(n * c)), data.shape_)) *(data - broadcast<0>(mean, data.shape_)) * -0.5f *F<mshadow_op::power>(broadcast<0>(var + param.eps, data.shape_),-1.5f));// 计算均值的梯度，// 对应上文输入梯度公式的第二项gmean = sumall_except_dim<0>(gout * broadcast<0>(reshape(repmat(gamma, n),Shape1(n * c)), data.shape_));gmean *=-1.0f / F<mshadow_op::square_root>(var + param.eps);tmp = scale * 
sumall_except_dim<0>(-2.0f * (data - broadcast<0>(mean, data.shape_)));tmp *= gvar;gmean += tmp;// 计算 beta 的梯度// 记得s3 = Shape3(n, c, rest_dim)// 那幺swapaxis<1, 0>(reshape(gout, s3))// 就表示首先把输出梯度 reshape 成// (n, c, rest_dim)，接着交换第0和1维度// (c, n, rest_dim)，最后求除了第0维度// 之外其他维度的和，// 也就和 beta 的求导公式对应上了Assign(gbeta, req[kBeta],sumall_except_dim<0>(swapaxis<1, 0>(reshape(gout, s3))));// 计算 gamma 的梯度// swapaxis<1, 0> 的作用与上面 beta 一样Assign(ggamma, req[kGamma],sumall_except_dim<0>(swapaxis<1, 0>(reshape(gout *(data - broadcast<0>(mean,data.shape_))/ F<mshadow_op::square_root>(broadcast<0>(var + param.eps,data.shape_)), s3))));// 计算输入的梯度，// 对应上文输入梯度公式三项的相加Assign(gdata, req[kData],(gout * broadcast<0>(reshape(repmat(gamma, n),Shape1(n * c)), data.shape_))* broadcast<0>(1.0f /F<mshadow_op::square_root>(var + param.eps), data.shape_)+ broadcast<0>(gvar, data.shape_)* scale * 2.0f* (data - broadcast<0>(mean, data.shape_))+ broadcast<0>(gmean,data.shape_) * scale);}`

InstanceNorm numpy 实现

# Reference implementation of InstanceNorm forward and input-gradient in numpy,
# checked against torch.nn.InstanceNorm2d.
import numpy as np
import torch

eps = 1e-05
batch = 4
channel = 2
height = 32
width = 32

input_np = np.random.random(
    size=(batch, channel, height, width)).astype(np.float32)
# gamma is initialised to 1.
# beta is initialised to 0, so it is left out of the computation.
gamma = np.ones((1, channel, 1, 1), dtype=np.float32)
# Randomly generated output gradient.
gout = np.random.random(
    size=(batch, channel, height, width)).astype(np.float32)

# Forward result computed with numpy: statistics are taken over the spatial
# axes only, i.e. separately for every (sample, channel) pair.
mean_np = np.mean(input_np, axis=(2, 3), keepdims=True)
in_sub_mean = input_np - mean_np
var_np = np.mean(np.square(in_sub_mean), axis=(2, 3), keepdims=True)
invar_np = 1.0 / np.sqrt(var_np + eps)
out_np = in_sub_mean * invar_np * gamma

# Input gradient computed with numpy.
scale = 1.0 / (height * width)
# Gradient w.r.t. the variance (third term of the input-gradient formula).
gvar = gout * gamma * in_sub_mean * -0.5 * np.power(var_np + eps, -1.5)
gvar = np.sum(gvar, axis=(2, 3), keepdims=True)

# Gradient w.r.t. the mean (second term of the input-gradient formula).
gmean = np.sum(gout * gamma, axis=(2, 3), keepdims=True)
gmean *= -invar_np
tmp = scale * np.sum(-2.0 * in_sub_mean, axis=(2, 3), keepdims=True)
gmean += tmp * gvar

# Sum of the three terms of the input-gradient formula. The parentheses keep
# the multi-line expression a single statement.
gin_np = (
    gout * gamma * invar_np
    + gvar * scale * 2.0 * in_sub_mean
    + gmean * scale
)

# PyTorch implementation.
p_input_tensor = torch.tensor(input_np, requires_grad=True)
trans = torch.nn.InstanceNorm2d(channel, affine=True, eps=eps)
p_output_tensor = trans(p_input_tensor)
p_output_tensor.backward(torch.Tensor(gout))

# Compare against the PyTorch results.
print(np.allclose(out_np, p_output_tensor.detach().numpy(), atol=1e-5))
print(np.allclose(gin_np, p_input_tensor.grad.numpy(), atol=1e-5))
# Console output:
# True
# True

[1]https://medium.com/@drsealks/batch-normalisation-formulas-derivation-253df5b75220

[2]https://kevinzakka.github.io/2016/09/14/batch_normalization/

[3]https://www.zhihu.com/question/68730628

[4]https://arxiv.org/pdf/1607.08022.pdf

[5]https://arxiv.org/pdf/1502.03167v3.pdf

[6]https://arxiv.org/pdf/1803.08494.pdf