Press "Enter" to skip to content

变分自编码器(五):VAE + BN = 更好的VAE

本文我们继续之前的变分自编码器系列,分析一下如何防止NLP中的VAE模型出现“KL散度消失(KL Vanishing)”现象。本文主要的参考文献是最近的论文 《A Batch Normalized Inference Network Keeps the KL Vanishing Away》 ,并补充了一些自己的描述。

 

值得一提的是,本文最后得到的方案相当简洁—— 往编码输出层加入BN ——但确实很有效,因此值得正在研究相关问题的读者一试。如果按照笔者的看法,它甚至可以成为VAE模型的“标配”。

 

让我们简单回顾一下VAE模型,它的训练流程大概可以图示为

 

 

VAE训练流程图示

 

写成公式就是

 

$$\begin{equation}\mathcal{L} = \mathbb{E}_{x\sim \tilde{p}(x)} \Big[\mathbb{E}_{z\sim p(z|x)}\big[-\ln q(x|z)\big]+KL\big(p(z|x)\big\Vert q(z)\big)\Big]

 

\end{equation}$$

 

其中第一项就是重构项,$\mathbb{E}_{z\sim p(z|x)}$是通过重参数来实现;第二项则称为KL散度项,这是它跟普通自编码器的显式差别。更详细的符号含义可以参考 《变分自编码器(二):从贝叶斯观点出发》

 

在NLP中,句子被编码为离散的整数ID,所以$q(x|z)$是一个离散型分布,可以用万能的“条件语言模型”来实现,因此理论上$q(x|z)$可以精确地拟合生成分布,问题就出在$q(x|z)$太强了,训练时重参数操作会来噪声,噪声一大,$z$的利用就变得困难起来,所以它干脆不要$z$了,退化为无条件语言模型(依然很强),$KL(p(z|x)\Vert q(z))$则随之下降到0,这就出现了KL散度消失现象。

 

这种情况下的VAE模型并没有什幺价值:KL散度为0说明编码器输出的是0向量,而解码器则是一个普通的语言模型。而我们使用VAE通常来说是看中了它无监督构建编码向量的能力,所以要应用VAE的话还是得解决KL散度消失问题。事实上从2016开始,有不少工作在做这个问题,相应地也提出了很多方案,比如退火策略、换先验分布等,读者Google一下“KL Vanishing”就可以找到很多文献了,这里不一一溯源。

 

本文的方案则是直接针对KL散度项入手,简单有效而且没什幺超参数。其思想很简单:

 

KL散度消失不就是KL散度项变成0吗?我调整一下编码器输出,让KL散度有一个大于零的下界,这样它不就肯定不会消失了吗?

 

这个简单的思想的直接结果就是:在$\mu_z$后面加入BN层,如图

 

 

往VAE里加入BN

 

为什幺会跟BN联系起来呢?我们来看KL散度项的形式:

 

\begin{equation}KL\big(p(z|x)\big\Vert q(z)\big) = \frac{1}{b} \sum_{i=1}^b \sum_{j=1}^d \frac{1}{2}\Big(\mu_{i,j}^2 + \sigma_{i,j}^2 – \log \sigma_{i,j}^2 – 1\Big)\end{equation}

 

上式是采样了$b$个样本进行计算的结果,而编码向量的维度则是$d$维。由于我们总是有$e^x \geq x + 1$,所以$\sigma_{i,j}^2 – \log \sigma_{i,j}^2 – 1 \geq 0$,因此

 

\begin{equation}KL\big(p(z|x)\big\Vert q(z)\big) \geq \frac{1}{b} \sum_{i=1}^b \sum_{j=1}^d \frac{1}{2}\mu_{i,j}^2 = \frac{1}{2}\sum_{j=1}^d \left(\frac{1}{b} \sum_{i=1}^b \mu_{i,j}^2\right)\end{equation}

 

留意到括号里边的量,其实它就是$\mu_z$在batch内的二阶矩,如果我们往$\mu_z$加入BN层,那幺大体上可以保证$\mu_z$的均值为$\beta$,方差为$\gamma^2$($\beta,\gamma$是BN里边的可训练参数),这时候

 

\begin{equation}KL\big(p(z|x)\big\Vert q(z)\big) \geq \frac{d}{2}\left(\beta^2 + \gamma^2\right)\end{equation}

 

所以只要控制好$\beta,\gamma$,就可以让KL散度项有个正的下界,因此就不会出现KL散度消失现象了。

 

这样一来,KL散度消失现象跟BN就被巧妙地联系起来了,通过BN来“杜绝”了KL散度消失的可能性。更妙的是,加入BN层不会与原来的训练目标相冲,因为VAE中我们通常假设的先验分布就是标准高斯分布,它满足均值为0、方差为1的特性,而加入BN后$\sigma_z$的均值为$\beta$、方差为$\gamma^2$,所以直接固定$\beta=0,\gamma=1$就可以兼容原来的训练目标了。

 

本文简单分析了VAE在NLP中的KL散度消失现象,并介绍了通过BN层来防止KL散度消失的方案。这是一种简洁有效的方案,不单单是原论文,笔者私下也做了简单的实验,结果确实也表明了它的有效性,值得各位读者试用。

 

转载到请包括本文地址: https://kexue.fm/archives/7381

 

更详细的转载事宜请参考: 《科学空间FAQ》

 

如果您还有什幺疑惑或建议,欢迎在下方评论区继续讨论。

 

如果您觉得本文还不错,欢迎/本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!

 

如果您需要引用本文,请参考:

 

苏剑林. (2020, May 06). 《 变分自编码器(五):VAE + BN = 更好的VAE 》[Blog post]. Retrieved from https://kexue.fm/archives/7381

Be First to Comment

发表回复

您的电子邮箱地址不会被公开。 必填项已用*标注