使用循环神经网络能让我们创建一个 AI 学习序列数据,比如文本、视频和音频等。机器学习爱好者 Greg Surma 使用循环神经网络中的 LSTM 网络搭建了一个文本预测模型。能够生成 rap 歌词,看完本文后你也可以有 freestyle。
项目代码地址见文末。
内容目录
- 神经网络
- 循环神经网络(RNN)
- 长短期记忆网络(LSTM)
- 文本预测模型
- 结果
如果想使用循环神经网络预测文本,我们需要了解神经网络是什么以及它们的工作原理。下面我们首先说说神经网络和循环神经网络,已经熟悉这部分知识的朋友可以直接跳到模型搭建部分。
神经网络
我们先讲一个例子,假如想搭建一个能学习我们习惯的预测模型。在这个例子中,我们的生活很简单——晴天去打篮球,雨天就在家读书。

我们的目标就是创建出一个模型,可以根据天气输入预测我们的活动。

我们需要获取模型的权重,即上图红色矩阵中的值,对于给定的输入,这些值将产生所需的输出。
为了实现这些目标,我们需要使用能反应我们实际生活行为的形态数据(天气,活动)训练神经网络。假设我们会坚持“晴天打篮球,雨天读书”的生活习惯,那么我们后面输入模型的只有两对值:(晴天,篮球)和(雨天,书籍)。
经过一些训练后,即提供输入-输出对(监督学习),我们的红色权重矩阵大致如下所示。

可以轻松地验证它。

我们的预测模型正确地将晴天与打篮球以及雨天与读书联系起来。
但是,假如我们搬到加利福尼亚州这种全年阳光明媚的地方该怎么办?为了能让身体和智力平衡发展,我们应该改变以前的生活习惯,运动与读书交叉进行。比如,如果我们昨天打篮球,我们今天会去看书,明天打篮球。我们将坚持这种交替活动的习惯。

随着上述生活习惯发生变化,我们需要相应地修改我们的预测模型。经典神经网络这时就帮不到我们了,因为它们没有考虑输入的方程序列。
这就是循环神经网络的用武之地。
和常规神经网络不同,它们适合处理和学习我们这个例子中的序列数据。
循环神经网络(RNN)
现在我们先不考虑天气输入,只将两个值视为输入和输出——打篮球和读籍。那么我简化 RNN 模型如下所示。

定义了神经网络的结构之后,就可以向模型提供训练数据了。假设我们打篮球和读书交替进行,我们的训练数据将如下所示。
经过一些训练后,我们将得到一个近似的权重矩阵。

最后,这样的模型可以正确地预测,如果我们最近打篮球,那么我们将会阅读一些书籍。同样地,如果我们最近阅读书籍,接下来就会去打篮球。
尽管值不同,但它看起来与上一个示例中的神经网络完全相同。之所以这样,是因为在某种意义上它仍然是一个常规神经网络,并且它的所有基本概念都没有改变,因为它反复将前一次迭代的输出作为下一次迭代的输入。
该模型可以可视化为如下所示的结果,其中每个 RNN 单元是一个神经网络。

上面的例子只是非常简单,易于掌握 RNN 概念的插图。他们的主要特点是,由于他们经常性,他们可以坚持信息。
它很像人类,作为人类,我们拥有以前的知识和记忆,可以帮助我们解决问题,虽然通常我们并没有意识到这一点。即使像阅读这样的微不足道的活动也会在这种现象中受益巨大。在我们阅读时,我们不会独立处理每个字母,因为它们中单独每一个都没有任何意义。字母 ‘a’ 并不比字母 ‘b’ 更好或更差。只有当它们是诸如单词,句子等更高层次结构的一部分时,它们才开始有意义。
据剑桥大学研究,人脑有一种奇异的能力:一个单词中的字母排列顺序并不重要,只要第一个字母和最后一个字母的位置没错,其他字母可以完全随机排列(混乱),而你仍然可以轻易地读懂它。这是因为人类大脑并不读取每个字母,而是将单词看成整体。

RNN 会努力通过其 RNN 架构实现这种能力,该架构能让模型在连续学习步骤期间保持信息一致。
我们来考虑一个想预测句子中缺失单词的文本预测模型。
That which does not kill us makes us (…).
但凡不能杀死你的,最终都会使你(…)。
我想我们大多数人都知道这里的答案是“stronger”(更强大)。
我们怎么知道的?
可能是因为它是有史以来最受欢迎的名言之一,我们在书籍、电影、电视节目上经常见到,记住它了。
幸运的是,对于我们的文本预测模型,会向它提供一个包含这个短语的数据集,它最近出现过,所以模型也能够使用它的记忆来获得有效的解决方案。

我们前面能在模型中找到相关信息,是因为它相对接近,我们不必看得太远。
我们来考虑一下我们的文本预测模型的另一个例子。这一次,我们用达尔文的全部作品为模型提供数据,然后让模型预测给定短语中的缺失单词。
On the Origin of (…)
论(……)的起源
这个谜题比前一个要略微难些,但是那些熟悉达尔文作品的人可能知道缺少的词是“物种”而整个短语是他最著名的关于进化论的一本书的名字。

不幸的是,我们的模型无法正确预测。这个短语并不经常出现在书中,而且它在书名中出现,很难记忆。
对于人类读者来说,如果一条短语出现在书名中,很明显它一定是本书的关键部分,即便它只出现在书名中,和循环神经网络相比,我们也能记住它。
尽管 RNN 无法处理这种长期依赖关系,但还有其他方法可以成功地管理它们。
其中之一是长短期记忆网络(LSTM),我们在本项目会简要分析以及使用它。
长短期记忆(LSTM)
虽然 RNN 的记忆深度有限,但 LSTM 会学习记住什么以及忘记什么。
这能让 LSTM 到达并利用超出 RNN 范围的记忆,但由于感知到它们的重要性,LSTM 能够将其记住。
我们深入了解一下细节!
RNN 网络通常结构简单,具有数据流动的重复模块。标准 RNN 通常由简单的 tanh 层构成。

另一方面,LSTM 包含更复杂的单元。

一开始看起来可能看起来很复杂,但它可以简化为以下易于理解的图表。

如你所见,LSTM 单元有 3 个门。
- 输入门(input gate)决定给定信息是否值得记忆。
- 忘记门(forget gate)决定给定信息是否仍然值得记住。如果没有,则将其删除。
- 输出门(output gate)决定给定信息在给定步骤是否相关并且应该充分使用。
不过这个神秘的决策过程实际上是如何运作的?
每个门都是一个具有相关权重的层。在每一步,它接受一个输入并执行一个 sigmoid 函数,该函数返回 0-1 范围内的值,其中 0 表示不允许任何通过,1 表示让一切通过。
之后,每个层的值将通过反向传播机制进行更新。它允许门随着时间的推移学习哪些信息是重要的,哪些不是。
现在我们已经认识了 RNN 以及 LSTM 背后的高级概念,接下来实现我们的项目,创建一个文本预测模型。
数据
由于我们要实现一个字符级模型,所以会把数据集的所有行拆分为字符列表。之后,我们将检测唯一字符和相应的频率分数。频率得分越低,关联的 char(数据集中)越受欢迎。
然后我们将创建一个词汇表,这是一个以下形状的字典。
{unique_char:frequency_score}
{u'\n': 11, u'\r': 12, u'!': 58, u' ': 0, u'#': 80, u'"': 28, u'$': 77, u"'": 22, u'&': 69, u')': 42,
u'(': 41, u'+': 72, u'*': 79, u'-': 44, u',': 19, u'/': 78, u'.': 43, u'1': 63, u'0': 67, u'3': 64,
u'2': 62, u'5': 70, u'4': 71, u'7': 81, u'6': 74, u'9': 73, u'8': 82, u';': 75, u':': 40, u'?': 38,
u'A': 29, u'C': 39, u'B': 36, u'E': 61, u'D': 51, u'G': 47, u'F': 59, u'I': 23, u'H': 34, u'K': 46,
u'J': 55, u'M': 45, u'L': 48, u'O': 50, u'N': 52, u'Q': 83, u'P': 57, u'S': 35, u'R': 56, u'U': 65,
u'T': 31, u'W': 30, u'V': 54, u'Y': 37, u'X': 76, u'[': 32, u'Z': 68, u']': 33, u'a': 4, u'c': 20,
u'b': 25, u'e': 1, u'd': 14, u'g': 17, u'f': 24, u'i': 6, u'h': 7, u'k': 21, u'j': 49, u'm': 16,
u'l': 10, u'o': 2, u'n': 5, u'q': 66, u'p': 26, u's': 8, u'r': 9, u'u': 13, u't': 3, u'w': 18,
u'v': 27, u'y': 15, u'x': 60, u'z': 53}
下一步,将创建我们的工作张量。可能最容易想到的方法就是获取我们的输入数据集并用其频率分数替换每个字符。
为了帮你更好地理解,我们拿侃爷的 rap 歌词做个例子,为歌词数据集的初始短语创建一个张量。
输入
“Kanye, can I talk to you for a minute?
张量
[28 46 4 5 15 1 19 0 20 4 5 0 23 0 3 4 10 21 0 3 2 0 15 2
13 0 24 2 9 0 4 0 16 6 5 13 3 1 38]
现在我们已经定义了工作张量,接着把它分成几个批次。模型的目标是按照给定输入产生期望输出的方式进行优化。因此我们需要提供输入批次和目标批次。
为简单起见,我们假设序列长度为 5。根据这一点,我们可以将第一个输入批次处理设置为工作张量的 5 个初始元素。
由于我们的模型是字符级的,所以我们的目标是预测前一个字符列表的下一个字符。创建满足此项要求的目标批次的最简单方法,就是按照和输入批次相比 a+1 的方式对张量下标。

我们为所提供的参数和张量长度生成尽可能多的输入/目标对。所有这些对创建一个训练周期。
self.vocabulary = dict(zip(self.chars, range(len(self.chars))))
self.tensor = np.array(list(map(self.vocabulary.get, data)))
self.batches_size = int(self.tensor.size / (self.batch_size * self.sequence_length))
self.tensor = self.tensor[:self.batches_size * self.batch_size * self.sequence_length]
inputs = self.tensor
targets = np.copy(self.tensor)
targets[:-1] = inputs[1:]
targets[-1] = inputs[0]
self.input_batches = np.split(inputs.reshape(self.batch_size, -1), self.batches_size, 1)
self.target_batches = np.split(targets.reshape(self.batch_size, -1), self.batches_size, 1)
定义了数据层之后,将其输入模型。
算法
我们先从一些基本的数据准备/初始化开始。
def rnn():
data_provider = DataProvider(data_dir, BATCH_SIZE, SEQUENCE_LENGTH)
model = RNNModel(data_provider.vocabulary_size, batch_size=BATCH_SIZE, sequence_length=SEQUENCE_LENGTH, hidden_layer_size=HIDDEN_LAYER_SIZE, cells_size=CELLS_SIZE)
with tf.Session() as sess:
summaries = tf.summary.merge_all()
writer = tf.summary.FileWriter(tensorboard_dir)
writer.add_graph(sess.graph)
sess.run(tf.global_variables_initializer())
epoch = 0
temp_losses = []
smooth_losses = []
随后,我们创建一个无限循环的训练时期。
while True:
sess.run(tf.assign(model.learning_rate, LEARNING_RATE * (DECAY_RATE ** epoch)))
data_provider.reset_batch_pointer()
state = sess.run(model.initial_state)
在每个周期内,我们遍历每个批次,使用输入/目标对提供我们的模型,最后保持当前状态和损失。
for batch in range(data_provider.batches_size):
inputs, targets = data_provider.next_batch()
feed = {model.input_data: inputs, model.targets: targets}
for index, (c, h) in enumerate(model.initial_state):
feed[c] = state[index].c
feed[h] = state[index].h
iteration = epoch * data_provider.batches_size + batch
summary, loss, state, _ = sess.run([summaries, model.cost, model.final_state, model.train_op], feed)
最后收集统计信息,示例文本和输出日志。
writer.add_summary(summary, iteration)
temp_losses.append(loss)
if iteration % SAMPLING_FREQUENCY == 0:
sample_text(sess, data_provider, iteration)
if iteration % LOGGING_FREQUENCY == 0:
smooth_loss = np.mean(temp_losses)
smooth_losses.append(smooth_loss)
temp_losses = []
plot(smooth_losses, "iterations (thousands)", "loss")
print('{{"metric": "iteration", "value": {}}}'.format(iteration))
print('{{"metric": "epoch", "value": {}}}'.format(epoch))
print('{{"metric": "loss", "value": {}}}'.format(smooth_loss))
epoch += 1
结果
终于,我们可以完成主要目标了——生成 rap 歌词。建议把你喜欢的说唱歌手的歌词收集起来,浓缩为一个文件。
cat song1.txt song2.txt > text_predictor/data/<dataset>/input.txt
然后运行
python text_predictor.py <dataset>
包含 rap 歌词以及训练图的输出文件会自动生成在数据集的目录中。
应该能得到像下面这样的结果。
使用下面的参数生成“侃爷”的歌词预测结果。
Selected dataset: kanye
Batch size: 32
Sequence length: 25
Learning rate: 0.01
Decay rate: 0.97
Hidden layer size: 256
Cells size: 2
Tensor size: 330400
Batch size: 32
Sequence length: 25
Batches size: 413

“侃爷”的输入数据集包含来自以下专辑的歌词:“The College Dropout”,“808s&Heartbreak”,“Yeezus”,“Late Registration”,“Graduation”,“My Beautiful Dark Twisted Fantasy”,“Watch The Throne”和“The Life of Pablo”。
迭代:0
9hu71JQ)eA”oqwrAAUwG5Wv7rvM60[*$Y!:1v*8tbkB+k 8IGn)QWv8NR.Spi3BtK[VteRer1GQ,it”kD?XVel3lNuN+G//rI’ Sl?ssm
NbH # Yk2uY”fmSVFah(B]uYZv+2]nsMX(qX9s+Rn+YAM.y/2 Hp9a,ZQOu,dM3.;im$Jca4E6(HS’D
[itYYQG#(gahU(gGoFYi)ucubL3 #iU32 8rdwIG7HJYSpDG*j,5
不出所料,模型最开始时什么都不知道,只生成了一些随机样本。
迭代:1000
Am our 200 shought 2 and but
One we -fuckister do fresh smandles
Juco pick with to sont party agmagle
Then I no meant he don’t ganiscimes mad is so cametie want
What
Mama sumin’ find Abortsimes, man
可以看到模型学会了怎样组成一些单词以及构建文本。每一行都有合适的长度,并以大写字母开头。
迭代:3000
Moss for a kice the mowing?
[Verse 1]
I play this better your pictures at here friends
Ever sip head
High all I wouldn’t really what they made thirise
And clap much
文本预测模型学会了如何组成具有特定结构的数据集,而且相比之前的迭代结果,单词错误更少。
迭代:25000
Through the sky and I did the pain is what what I’m so smart
Call extry lane
Make man flywing yet then you a represent
And paper more day, they just doing with her
This that fast of vision
虽然生成的文本仍然说不通,但是可以看到模型学会了如何生成没有错误的句子。
迭代:207000
Right here, I was mailing for mine, where
Uh, that’s that crank music nigga
That real bad red dues, now you do
Hey, hey, hey, hey
Don’t say you will, if I tried to find your name!
I’m finna talk, we Jine?
Ok, ooh oh!
Bam ba-ah-man, crack music nigga
(We look through to crew and I’m just not taking out
Turned out, this diamondstepragrag crazy
Tell my life war
[Hook]
Breakfies little lights, a South Pac-shirt
Track music nigga
(La la la la la la la lah, la la la la la lah)
经过 20 万次学习迭代后,我们的“侃爷”AI 已经学会怎么唱歌了。
虽然 AI 生成的歌词可能没有传达任何有意义的(至少对于人类而言)消息,但我们可以清楚地看到模型学会了如何正确地模仿所提供的输入数据集的样式。鉴于我们的 RNN,LSTM 模型是从零开始学习一切,最初不懂任何词汇和句子,更不用说英语了,所以 AI 的创作还是非常惊人的!此外,训练 AI 的“侃爷”歌词数据集很小,如果使用更大的数据集,其结果会更好。
项目代码地址:
https://github.com/gsurma/text_predictor
Be First to Comment