## InstanceNorm 梯度公式推导

【GiantPandaCV导语】本文的主要内容是推导 InstanceNorm 关于输入和参数的梯度公式，同时还会结合 PyTorch 和 MXNet 里面 InstanceNorm 的代码来分析。

## InstanceNorm 与 BatchNorm 的联系

InstanceNorm 训练与预测阶段行为一致，都是利用当前 batch 的均值和方差计算；

BatchNorm 训练阶段利用当前 batch 的均值和方差，测试阶段则利用训练阶段通过移动平均统计的均值和方差；

相关论文（Group Normalization，其中对比了各类归一化方法的统计维度）：https://arxiv.org/pdf/1803.08494.pdf

loss 函数的符号约定为 $L$，下文所说的“梯度”均指 $L$ 对相应变量的偏导数。

## 主流框架实现代码解读

### PyTorch 前向传播实现

``````Tensor instance_norm(
const Tensor& input,
const Tensor& weight/* optional */,
const Tensor& bias/* optional */,
const Tensor& running_mean/* optional */,
const Tensor& running_var/* optional */,
bool use_input_stats,
double momentum,
double eps,
bool cudnn_enabled) {
// ......
std::vector<int64_t> shape =
input.sizes().vec();
int64_t b = input.size(0);
int64_t c = input.size(1);
// shape 从 (b, c, h, w)
// 变为 (1, b*c, h, w)
shape[1] = b * c;
shape[0] = 1;
// repeat_if_defined 的解释见下文
Tensor weight_ =
repeat_if_defined(weight, b);
Tensor bias_ =
repeat_if_defined(bias, b);
Tensor running_mean_ =
repeat_if_defined(running_mean, b);
Tensor running_var_ =
repeat_if_defined(running_var, b);
// 改变输入张量的形状
auto input_reshaped =
input.contiguous().view(shape);
// 计算实际调用的是 batchnorm 的实现
// 所以可以理解为什幺 pytroch
// 前端 InstanceNorm2d 的接口
// 与 BatchNorm2d 的接口一样
auto out = at::batch_norm(
input_reshaped,
weight_, bias_,
running_mean_,
running_var_,
use_input_stats,
momentum,
eps, cudnn_enabled);
// ......
return out.view(input.sizes());
}
``````

`repeat_if_defined` 的实现源码链接如下：

`https://github.com/pytorch/pytorch/blob/fa153184c8f70259337777a1fd1d803c7325f758/aten%2Fsrc%2FATen%2Fnative%2FNormalization.cpp#L27`

``````static inline Tensor repeat_if_defined(
const Tensor& t,
int64_t repeat) {
if (t.defined()) {
// 把 tensor 按第0维度复制 repeat 次
return t.repeat(repeat);
}
return t;
}
``````

### MXNet 反向传播实现

`https://github.com/apache/incubator-mxnet/blob/4a7282f104590023d846f505527fd0d490b65509/src%2Foperator%2Finstance_norm-inl.h#L112`

``````template<typename xpu>
void InstanceNormBackward(
const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
// ......
const InstanceNormParam& param =
nnvm::get<InstanceNormParam>(
attrs.parsed);
Stream<xpu> *s =
ctx.get_stream<xpu>();
// 获取输入张量的形状
mxnet::TShape dshape =
inputs[3].shape_;
// ......
int n = inputs[3].size(0);
int c = inputs[3].size(1);
// rest_dim 就等于上文的 M
int rest_dim =
static_cast<int>(
inputs[3].Size() / n / c);
Shape<2> s2 = Shape2(n * c, rest_dim);
Shape<3> s3 = Shape3(n, c, rest_dim);
// scale 就等于上文的 1/M
const real_t scale =
static_cast<real_t>(1) /
static_cast<real_t>(rest_dim);
// 获取输入张量
Tensor<xpu, 2> data = inputs[3]
.get_with_shape<xpu, 2, real_t>(s2, s);
// 保存输入梯度
Tensor<xpu, 2> gdata = outputs[kData]
.get_with_shape<xpu, 2, real_t>(s2, s);
// 获取参数 gamma
Tensor<xpu, 1> gamma =
inputs[4].get<xpu, 1, real_t>(s);
// 保存参数 gamma 梯度计算结果
Tensor<xpu, 1> ggamma = outputs[kGamma]
.get<xpu, 1, real_t>(s);
// 保存参数 beta 梯度计算结果
Tensor<xpu, 1> gbeta = outputs[kBeta]
.get<xpu, 1, real_t>(s);
// 获取输出梯度
Tensor<xpu, 2> gout = inputs[0]
.get_with_shape<xpu, 2, real_t>(
s2, s);
// 获取前向计算好的均值和方差
Tensor<xpu, 1> var =
inputs[2].FlatTo1D<xpu, real_t>(s);
Tensor<xpu, 1> mean =
inputs[1].FlatTo1D<xpu, real_t>(s);
// 临时空间
Tensor<xpu, 2> workspace = //.....
// 保存均值的梯度
Tensor<xpu, 1> gmean = workspace[0];
// 保存方差的梯度
Tensor<xpu, 1> gvar = workspace[1];
Tensor<xpu, 1> tmp = workspace[2];
// 计算方差的梯度，
// 对应上文输入梯度公式的第三项
// gout 对应输出梯度
gvar = sumall_except_dim<0>(
reshape(repmat(gamma, n),
Shape1(n * c)), data.shape_)) *
mean, data.shape_)) * -0.5f *
var + param.eps, data.shape_),
-1.5f)
);
// 计算均值的梯度，
// 对应上文输入梯度公式的第二项
gmean = sumall_except_dim<0>(
reshape(repmat(gamma, n),
Shape1(n * c)), data.shape_));
gmean *=
var + param.eps);
tmp = scale * sumall_except_dim<0>(
mean, data.shape_)));
tmp *= gvar;
gmean += tmp;
// 计算 beta 的梯度
// 记得s3 = Shape3(n, c, rest_dim)
// 那幺swapaxis<1, 0>(reshape(gout, s3))
// 就表示首先把输出梯度 reshape 成
// (n, c, rest_dim)，接着交换第0和1维度
// (c, n, rest_dim)，最后求除了第0维度
// 之外其他维度的和，
// 也就和 beta 的求导公式对应上了
Assign(gbeta, req[kBeta],
sumall_except_dim<0>(
swapaxis<1, 0>(reshape(gout, s3))));
// 计算 gamma 的梯度
// swapaxis<1, 0> 的作用与上面 beta 一样
Assign(ggamma, req[kGamma],
sumall_except_dim<0>(
swapaxis<1, 0>(
reshape(gout *
data.shape_))
var + param.eps,
data.shape_
)
), s3
)
)
)
);
// 计算输入的梯度，
// 对应上文输入梯度公式三项的相加
Assign(gdata, req[kData],
reshape(repmat(gamma, n),
Shape1(n * c)), data.shape_))
var + param.eps), data.shape_)
* scale * 2.0f
mean, data.shape_))
data.shape_) * scale);
}
``````

## InstanceNorm numpy 实现

``````import numpy as np
import torch
eps = 1e-05
batch = 4
channel = 2
height = 32
width = 32
input = np.random.random(
size=(batch, channel, height, width)).astype(np.float32)
# gamma 初始化为1
# beta 初始化为0，所以忽略了
gamma = np.ones((1, channel, 1, 1),
dtype=np.float32)
# 随机生成输出梯度
gout = np.random.random(
size=(batch, channel, height, width))\
.astype(np.float32)
# 用numpy计算前向的结果
mean_np = np.mean(
input, axis=(2, 3), keepdims=True)
in_sub_mean = input - mean_np
var_np = np.mean(
np.square(in_sub_mean),
axis=(2, 3), keepdims=True)
invar_np = 1.0 / np.sqrt(var_np + eps)
out_np = in_sub_mean * invar_np * gamma
# 用numpy计算输入梯度
scale = 1.0 / (height * width)
# 对应输入梯度公式第三项
gvar =
gout * gamma * in_sub_mean *
-0.5 * np.power(var_np + eps, -1.5)
gvar = np.sum(gvar, axis=(2, 3),
keepdims=True)
# 对应输入梯度公式第二项
gmean = np.sum(
gout * gamma,
axis=(2, 3), keepdims=True)
gmean *= -invar_np
tmp = scale * np.sum(-2.0 * in_sub_mean,
axis=(2, 3), keepdims=True)
gmean += tmp * gvar
# 对应输入梯度公式三项之和
gin_np =
gout * gamma * invar_np
+ gvar * scale * 2.0 * in_sub_mean
+ gmean * scale

# pytorch 的实现
p_input_tensor =
trans = torch.nn.InstanceNorm2d(
channel, affine=True, eps=eps)
p_output_tensor = trans(p_input_tensor)
p_output_tensor.backward(
torch.Tensor(gout))
# 与 pytorch 对比结果
print(np.allclose(out_np,
p_output_tensor.detach().numpy(),
atol=1e-5))
print(np.allclose(gin_np,
atol=1e-5))
# 命令行输出
# True
# True
``````

## 参考资料：

[1] https://medium.com/@drsealks/batch-normalisation-formulas-derivation-253df5b75220

[2] https://kevinzakka.github.io/2016/09/14/batch_normalization/

[3] https://www.zhihu.com/question/68730628

[4] https://arxiv.org/pdf/1607.08022.pdf

[5] https://arxiv.org/pdf/1502.03167v3.pdf

[6] https://arxiv.org/pdf/1803.08494.pdf