Press "Enter" to skip to content

动态可视化:一步步拆解LSTM和GRU

编者按:关于LSTM,之前我们已经出过不少文章,其中最经典的一篇是chrisolah的《一文详解LSTM网络》,文中使用的可视化图片被大量博文引用,现在已经随处可见。但正如短视频取代纯文字阅读是时代的趋势,在科普文章中,用可视化取代文字,用动态图取代静态图,这也是如今使知识更易于被读者吸收的常规操作。
今天,论智给大家带来的是AI语音助理领域的机器学习工程师Michael Nguyen撰写的一篇LSTM和GRU的动态图解:对于新手,它更直观易懂;对于老手,这些新图绝对值得收藏。

在这篇文章中,我们将从LSTM和GRU背后的知识开始,逐步拆解它们的内部工作机制。如果你想深入了解这两个网络的原理,那么这篇文章就是为你准备的。

问题:短期记忆

如果说RNN有什么缺点,那就是它只能传递短期记忆。当输入序列够长时,RNN是很难把较早的信息传递到较后步骤的,这意味着如果我们准备了一段长文本进行预测,RNN很可能会从一开始就遗漏重要信息。
出现这个问题的原因是在反向传播期间,RNN的梯度可能会消失。我们都知道,网络权重更新依赖梯度计算,RNN的梯度会随着时间的推移逐渐减小,当序列足够长时,梯度值会变得非常小,这时权重无法更新,网络自然会停止学习。

梯度更新规则

根据上图公式:新权重=权重-学习率×梯度。已知学习率是个超参数,当梯度非常小时,权重和新权重几乎相等,这个层就停止学习了。由于这些层都不再学习,RNN就会忘记在较长序列中看到的内容,只能传递短期记忆。

解决方案:LSTM和GRU

LSTM和GRU都是为了解决短期记忆这个问题而创建的。它们都包含一种名为“控制门”的内部机制,可以调节信息流:

这些门能判断序列中的哪些数据是重要的,哪些可以不要,因此,它就可以沿着长序列传递相关信息以进行预测。截至目前,基于RNN的几乎所有实际应用都是通过这两个网络实现的,无论是语音识别、语音合成,还是文本生成,甚至是为视频生成字幕。
在下文中,我们会详细介绍它们背后的具体思路。

人类的记忆

让我们先从一个思维实验开始。双11快到了,假设你想买几袋麦片当早餐,现在正在浏览商品评论。评论区的留言很多,你的阅读目的是判断评论者是好评还是差评:

以上图评论为例,当你一目十行地读过去时,你不太会关注“this”“give”“all”“should”这些词,相反地,大脑会下意识被“amazing”“perfectly balanced breakfast”这些重点词汇吸引。纠结了一晚上,最后你下单了。第二天,你朋友问起你为什么要买这个牌子的麦片,这时你可能连上面这些重点词都忘光了,但你会记得评论者最重要的观点:“will definitely be buying again”(肯定会再光顾)。

就像上图展示的,那些不重要的词仿佛一读完就被我们从脑海中清除了。而这基本就是LSTM和GRU的作用,它们可以学会只保留相关信息进行预测,并忘却不相关的数据。

RNN综述

为了理解LSTM和GRU是怎么做到这一点的,我们先回顾一下它们的原型RNN。下图是RNN的工作原理,输入一个词后,这个词会被转换成机器可读的向量;同理,输入一段文本后,RNN要做的就是按照顺序逐个处理向量序列。

按顺序逐一处理

我们都知道,RNN拥有“记忆”能力。处理向量时,它会把先前的隐藏状态传递给序列的下一步,这个隐藏状态就充当神经网络记忆,它包含网络以前见过的先前数据的信息。

将隐藏状态传递给下一个时间步

那么这个隐藏状态是怎么计算的?让我们看看RNN的第一个cell。如下图所示,首先,它会把输入x和上一步的隐藏状态组合成一个向量,使这个的向量包含当前输入和先前输入的信息;其次,向量经tanh激活,输出新的隐藏状态。

Tanh激活
激活函数Tanh的作用是调节流经网络的值,它能把值始终约束在-1到1之间。

激活函数Tanh

当向量流经神经网络时,由于各种数学运算,它会经历许多次变换。假设每流经一个cell,我们就把值乘以3,如下图所示,这个值很快就会变成天文数字,导致其它值看起来微不足道。

不用Tanh进行调节

而使用了Tanh函数后,如下图所示,神经网络能确保值保持在-1和1之间,从而调节输出。

用Tanh进行调节

以上就是一个最基础的RNN,它的内部构造很简单,但具备从先前信息推断之后将要发生的事的能力。也正是因为简单,它所需的计算资源比LSTM和GRU这两个变体少得多。

LSTM

从整体上看,LSTM具有和RNN类似的流程:一边向前传递,一边处理传递信息的数据。它的不同之处在于cell内的操作:它们允许LSTM保留或忘记信息。

LSTM的cell

核心概念
LSTM的核心概念是cell的状态和各种控制门。其中前者是一个包含多个值的向量,它就像神经网络中的“高速公路”,穿行在序列链中一直传递相关信息——我们也可以把它看作是神经网络的“记忆”。从理论上来说,cell状态可以在序列的整个处理过程中携带相关信息,它摆脱了RNN短期记忆的问题,即便是较早期的信息,也能被用于较后期的时间步。
而当cell状态在被不断传递时,每个cell都有3个不同的门,它们是不同的神经网络,主要负责把需要的信息保留到cell中,并移除无用信息。
Sigmoid
每个门都包含sigmoid激活,它和Tanh的主要区别是取值范围在0到1之间,而不是-1到1。这个特点有助于在cell中更新、去除数据,因为任何数字乘以0都是0(遗忘),任何数字乘以1都等于它本身(保留)。由于值域是0到1,神经网络也能计算、比较哪些数据更重要,哪些更不重要。

激活函数Sigmoid

遗忘门
首先,我们来看3个门中的遗忘门。这个门决定应该丢弃哪些信息。当来自先前隐藏状态的信息和来自当前输入的信息进入cell时,它们经sigmoid函数激活,向量的各个值介于0-1之间。越接近0意味着越容易被忘记,越接近1则越容易被保留。

遗忘门的操作

输入门
输入门是我们要看的第二个门,它是更新cell状态的重要步骤。如下图所示,首先,我们把先前隐藏状态和当前输入传递给sigmoid函数,由它计算出哪些值更重要(接近1),哪些值不重要(接近0)。其次,同一时间,我们也把原隐藏状态和当前输入传递给tanh函数,由它把向量的值推到-1和1之间,防止神经网络数值过大。最后,我们再把tanh的输出与sigmoid的输出相乘,由后者决定对于保持tanh的输出,原隐藏状态和当前输入中的哪些信息是重要的,哪些是不重要的。

输入门的操作

cell状态
到现在为止,我们就可以更新cell状态了。首先,将先前隐藏状态和遗忘门输出的向量进行点乘,这时因为越不重要的值越接近0,原隐藏状态中越不重要的信息也会接近0,更容易被丢弃。之后,利用这个新的输出,我们再把它和输入门的输出点乘,把当前输入中的新信息放进cell状态中,最后的输出就是更新后的cell状态。

计算cell状态

输出门
最后是输出门,它决定了下一个隐藏状态应该是什么。细心的读者可能已经发现了,隐藏状态和cell状态不同,它包含有关先前输入的信息,神经网络的预测结果也正是基于它。如下图所示,首先,我们将先前隐藏状态和当前输入传递给sigmoid函数,其次,我们再更新后的cell状态传递给tanh函数。最后,将这两个激活函数的输出相乘,得到可以转移到下一时间步的新隐藏状态。

总而言之,遗忘门决定的是和先前步骤有关的重要信息,输入门决定的是要从当前步骤中添加哪些重要信息,而输出门决定的是下一个隐藏状态是什么。
代码演示
对于更喜欢读代码的读者,下面是一个Python伪代码示例:

python伪代码
  1. 首先,把先前隐藏状态prev_ht和当前输入input合并成combine
  2. 其次,把combine输入遗忘层,决定哪些不相关数据需要被剔除
  3. 第三,用combine创建候选层,其中包含能被添加进cell状态的可能值
  4. 第四,把combine输入输入层,决定把候选层中哪些信息添加进cell状态
  5. 第五,更新当前cell状态
  6. 第六,把combine输入输出层,计算输出
  7. 最后,把输出和当前cell状态进行点乘,得到更新后的隐藏状态

如上所述,LSTM网络的控制流程不过是几个张量操作和一个for循环而已。

GRU

现在我们已经知道LSTM背后的工作原理了,接下来就简单看一下GRU。GRU是新一代的RNN,它和LSTM很像,区别是它摆脱了cell状态,直接用隐藏状态传递信息。GRU只有两个门:重置门和更新门。

GRU

更新门
更新门的作用类似LSTM的遗忘门和输入门,它决定要丢弃的信息和要新添加的信息。
重置门
重置门的作用是决定要丢弃多少先前信息。
相比LSTM,GRU的张量操作更少,所以速度也更快。但它们之间并没有明确的孰优孰劣,只有适不适合。

小结

以上就是LSTM和GRU的动态图解。总而言之,它们都是为了解决RNN短期记忆的问题而创建的,现在已经被用于各种最先进的深度学习应用,如语音识别、语音合成和自然语言理解等。感谢你的阅读!

Be First to Comment

发表回复

您的电子邮箱地址不会被公开。 必填项已用*标注