$$KL(\tilde{p}(x)p(z|x)\Vert q(z)q(x|z))=\iint \tilde{p}(x)p(z|x)\log \frac{\tilde{p}(x)p(z|x)}{q(x|z)q(z)} dzdx$$

$$\mathbb{E}_{x\sim \tilde{p}(x)} \Big[\mathbb{E}_{z\sim p(z|x)}\big[-\log q(x|z)\big]+KL\big(p(z|x)\big\Vert q(z)\big)\Big]\label{eq:vae}$$

1、引入了均值和方差的概念，加入了重参数操作；
2、加入了KL散度作为额外的损失函数。

$$x \quad \to \quad z \quad \to \quad y$$

$$\iint p(z|x)\tilde{p}(x)\log \frac{p(z|x)}{p(z)}dxdz$$

$$p(z) = \int p(z|x)\tilde{p}(x)dx$$

$$-\iint p(z|x)\tilde{p}(x)\log p(y|z)dxdz$$

$$-\iint p(z|x)\tilde{p}(x)\log p(y|z)dxdz + \lambda \iint p(z|x)\tilde{p}(x)\max\left(\log \frac{p(z|x)}{p(z)} - \beta, 0\right)dxdz$$

$$-\iint p(z|x)\tilde{p}(x)\log p(y|z)dxdz + \lambda \iint p(z|x)\tilde{p}(x)\log \frac{p(z|x)}{p(z)}dxdz$$

$$\begin{aligned}&\iint p(z|x)\tilde{p}(x)\log \frac{p(z|x)}{p(z)}dxdz\\
=&\iint p(z|x)\tilde{p}(x)\log \frac{p(z|x)}{q(z)}\frac{q(z)}{p(z)}dxdz\\
=&\iint p(z|x)\tilde{p}(x)\log \frac{p(z|x)}{q(z)}dxdz + \iint p(z|x)\tilde{p}(x)\log \frac{q(z)}{p(z)}dxdz\\
=&\iint p(z|x)\tilde{p}(x)\log \frac{p(z|x)}{q(z)}dxdz + \int p(z)\log \frac{q(z)}{p(z)}dz\\
=&\iint p(z|x)\tilde{p}(x)\log \frac{p(z|x)}{q(z)}dxdz - \int p(z)\log \frac{p(z)}{q(z)}dz\\
=&\int \tilde{p}(x) KL\big(p(z|x)\big\Vert q(z)\big) dx - KL\big(p(z)\big\Vert q(z)\big)\\
\leq&\int \tilde{p}(x) KL\big(p(z|x)\big\Vert q(z)\big) dx\end{aligned}$$

$$-\iint p(z|x)\tilde{p}(x)\log p(y|z)dxdz + \lambda \int\tilde{p}(x) KL\big(p(z|x)\big\Vert q(z)\big) dx$$

$$\mathbb{E}_{x\sim \tilde{p}(x)} \Big[\mathbb{E}_{z\sim p(z|x)}\big[-\log q(y|z)\big]+\lambda\cdot KL\big(p(z|x)\big\Vert q(z)\big)\Big]\label{eq:vib}$$

1、引入了均值和方差的概念，加入了重参数操作；
2、加入了KL散度作为额外的损失函数。

from keras.layers import Layer
import keras.backend as K

class VIB(Layer):
    """Variational information bottleneck (VIB) layer.

    Takes the pair ``[z_mean, z_log_var]`` as input, registers the
    ``lamb``-weighted KL divergence between N(z_mean, exp(z_log_var)) and the
    standard normal N(0, 1) as an extra layer loss, and returns a
    reparameterized sample ``z_mean + exp(z_log_var / 2) * u``.

    The noise ``u`` is drawn from a standard normal in the training phase and
    replaced by 0 otherwise, so outside training the layer returns ``z_mean``.

    Args:
        lamb: weight applied to the KL term before it is added to the losses.
        **kwargs: forwarded unchanged to ``Layer.__init__``.
    """

    def __init__(self, lamb, **kwargs):
        self.lamb = lamb
        super(VIB, self).__init__(**kwargs)

    def call(self, inputs):
        """Add the KL loss and return the (possibly noisy) bottleneck code.

        Args:
            inputs: a two-element sequence ``(z_mean, z_log_var)``.

        Returns:
            ``z_mean + exp(z_log_var / 2) * u`` with ``u`` as described in
            the class docstring.
        """
        z_mean, z_log_var = inputs
        # Closed-form KL(N(mu, sigma^2) || N(0, 1)): averaged over axis 0
        # (presumably the batch axis -- confirm against the caller), then
        # summed over the remaining entries.
        kl_loss = -0.5 * K.sum(
            K.mean(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), 0)
        )
        # Without this call the KL term is computed but never reaches the
        # training objective, and `lamb` has no effect.
        self.add_loss(self.lamb * kl_loss)
        u = K.random_normal(shape=K.shape(z_mean))
        # Use the sampled noise only in the training phase; 0 elsewhere.
        u = K.in_train_phase(u, 0.)
        return z_mean + K.exp(z_log_var / 2) * u

    def compute_output_shape(self, input_shape):
        """Return the shape of the first input (``z_mean``)."""
        return input_shape[0]

https://github.com/bojone/vib/blob/master/cnn_imdb_vib.py

$$z_{\alpha} = \alpha z_1 + (1 - \alpha) z_2,\quad 0 \leq \alpha \leq 1$$

$$z_1 + z_2 \sim\mathcal{N}(\mu_1+\mu_2, \sigma_1^2+\sigma_2^2)$$

$$\alpha z_1 + \beta z_2 \sim \mathcal{N}(0, \alpha^2+\beta^2)$$

$$z_{\theta}=z_1\cdot\cos\theta + z_2\cdot\sin\theta$$