Press "Enter" to skip to content

Sampled Softmax,你真的会用了吗?

 

作者 | 夜小白

 

整理 | NewBeeNLP


 

Sampled-softmax简单点来说,就是通过采样,来减少我们训练计算loss时输出层的运算量。从第一篇博客中的不知其然,到后面看到DSSM代码中Sampled softamax的知其然,这篇博客目的是在知其所以然,从Sampled softmax的数学原理思考,为什幺DSSM中的训练代码可以这样写,代码还能怎幺改进。

 

这段时间也一直在思考,如何才能不随波逐流,如何才能成为一名独当一面的算法工程师
,我想对于一个问题的浅尝辄止肯定是远远不够的,不仅要知其然还要知其所以然,光是读懂这几篇论文是不够的,进一步的要理解代码工程实现,更进一步,去理解代码背后的数学原理,为什幺代码这样做一定能保证结果正确或者收敛,了解了这些,我们才能够根据自己的想法去做优化,我想对于现在日益成熟的深度学习,难的可能不是如何实现,而是对于自己的实际场景去调整优化。

 

上面有点扯远了,回归正题,这篇博客主要基于Tensorflow官方对于Sampled softmax文档,建议大家有问题不懂的时候多看官方文档,写的非常通俗易懂,下面我就说说自己对Sampled Softmax数学原理的理解。

 

Tensorflow 官方文档:What is Candidate Sampling[1]

 

什幺是Sampled Softmax

 

1、logits与softmax

 

当我们做分类问题时,假设我们需要分类的类别数为,那幺我们做法通常如下,假设我们的输入为:

 

 

神经网络最后一层输出层「神经元个数为」
,每个神经元输出分别表示
「各个类别的logits

, 这里的
logits

其实代表的就是各个类别「未经归一化的概率分布」
(也就是加起来不为1),网络就是学习出一个映射

 

logits
softmax
softmax

 

根据这个概率分布计算损失函数,如交叉熵损失

 

 

还是采用之前博客中的Query-Doc Softmax作为说明,从logtis
进行softmax
归一化公式如下:

表示我们的输入,表示我们的模型,即是给定情况下,输出类别为的logits

我们注意分母中即为所有文档集合,也就是我们的总类别数

 

这个公式的具体解释可以参考之前的两篇博客,下面分析一下上面这个公式,下面是重点:

当我们类别数非常大时,也就是非常大时,那幺我们分母的计算量就会非常大,因为需要在整个类别全集上求和。比如假设我们有100W个文档,那幺如果我们不做任何处理,「对于每个Query,分母中我们就要计算对这100W个文档的logits,然后求和进行归一化」
,这样的训练速度我们是不能接受的。Sampled Softmax思想就是,「从全部类别集合」
「中采样出一个子集」
,比如100个,然后在子集上计算logits并进行softmax
归一化

logits
logits
K
K

分母其实是一个归一化因子,如果看过PRML同学应该熟悉,有点类似于指数族分布中的partition function
,分母「与类别无关」
,因为分母中对整个类别集合进行了求和,给定输入后,分母归一化因子也就确定了。

从上面分析可以知道,我们的关键词是logits
softmax归一化
logits
本质上就是未归一化的概率,softmax
目的就是计算归一化因子(分母),对logtis
进行归一化,从而得到一个概率分布。问题就在于需要对整个类别集合计算logtis
并求和,当类别集合比较大时(比如上面的Query-Doc预测,以及语言模型训练),计算量会非常大。

 

2、Sampled Softmax

 

Sampled Softmax
的核心思想就在于**Sampled**
,既然类别全集太大,那幺能不能采样一个类别子集,然后在计算在子集上的logtis
然后进行softmax
归一化呢?假设我们类别全集为,输入为,其中就是我们的输入类别标签,那幺我们可以在上随机采样一个子集,并且与我们的输入类别,共同组成候选类别子集

 

我们在训练模型时,只要在这个采样出来的上计算logits
softmax
就可以了,大大减少了计算量,加快训练过程。现在问题是:

*当我们进行采样之后,各个类别logits
应该如何计算,和使用类别全集时的logtis
有什幺对应关系?

Sampled Softmax背后的数学原理

 

从上面可以看出,当我们进行采样后,按理来说logtis
计算方法也需要改变,这样才能最后得到正确的概率分布。前方公式预警!!!!

 

1、数学符号约定

 

表示我们的一个训练样本,为输入模型的特征,为标签,目标类别

 

给定输入,输出类别为的条件概率

给定输入,输出类别为的logtis
,这里其实表示的就是我们的模型

类别全集

 

采样函数,给定输入,采样出类别的概率

 

采样出来的类别子集

 

以上符号如果没有特殊说明,都表示是在类别全集上进行计算

 

2、logits与概率之间的关系

 

其中表示与类别无关的常数,其实就是softmax
计算出来的分母。推导也很简单:

 

两边同时取,可以得到

 

最后将移项则可以得到上式。即logits
可以写成“”这种形式。为什幺要推导出这个关系呢,且听后面分解~

 

3、采样出类别子集的概率表示

 

这里推导也很简单,当时概率为,否则为这里假设每次采样都是「独立同分布(iid)」
,所以我们把每个类别概率乘起来就可以了

 

4、计算采样后类别子集上的概率分布表示

 

重点来了!前面都是铺垫,我们最终的目的是计算「给定输入」
「,在采样后的类别子集」
「概率分布表示」
,也就是进一步,由于在2中,「logits与概率之间的关系」
,我们已经得到,所以我们就可以得到采样后logits
的正确表示形式啦~,我们假设为采样子集和我们目标类别的并集

 

那幺在给定类别子集,输入条件下,输入类别的概率计算推导如下,首先使用贝叶斯公式:

 

上面的推导就是简单的贝叶斯公式。我们分析一下推导结果:

 

这个就是在类别全集情况下,给定输入,输出类别为的条件概率

 

这个概率就是给定类别,输入情况下,采样出类别子集的概率,这个计算方式已经在3中,「采样出类别子集」
「的概率表示」
,推导出来如下

 

这其实是个和输出类别无关的常量,可以视为const

 

综上,下面计算结果如下:

 

其中为与类别无关的常数,我们对上式两边取,则有:

 

结果已经跃然纸上,是我们自己选取的采样函数,通过这个式子我们已经得到了采样后类别子集 !和类别全集上概率分布的关系

 

5、采样后类别子集上的logits
和原始logits
关系

 

终于要到最后一步了,我们已经知道了采样后类别子集和类别全集上概率分布的关系,这时我们只需要利用2中的结论,「logits与概率之间的关系」
,就可以得出采样后类别子集上的logits
和原始logits
关系,推导如下:

 

带入上面推导出来的公式:

 

其中与类别无关的常数项都可以合并,则有:

 

大功告成!上面的公式就是我们进行采样后的logtis
与原始logits
关系,具体的用法如下:

 

通过对类别进行采样,得到一个类别子集

模型对采样类别子集中的类别分别计算logits
(这样就不用在类别全集计算logits
了),这里得到的其实是
对于计算出来的,减去,就得到了我们采样后子集的logits

使用作为softmax
输入,计算概率分布以及loss进行梯度下降

DSSM Sampled Softmax 分析

 

从上面分析可以得到:

 

我们选取不同的采样函数,那幺结果也会不同,比如Tensorflow中有如下采样方式:

tf.nn.log_uniform_candidate_sampler
,按照 log-uniform (Zipfian) 分布采样。

tf.nn.learned_unigram_candidate_sampler
按照训练数据中类别出现分布进行采样。具体实现方式:1)初始化一个 [0, range_max] 的数组, 数组元素初始为1; 2) 在训练过程中碰到一个类别,就将相应数组元素加 1;3) 每次按照数组归一化得到的概率进行采样。

上述采样方式都和输入相关,而如果我们选择随机采样,那幺选择每个类别的概率都相等,也就是说对于每个类别来说都一样,可以看做一个常数,并到后面常数项中,所以有:

 

而上面分析过,logits
加上或者减去一个常数,对softmax
结果并没有影响,所以可以用
「原始logits
代替采样后的logits

。所以DSSM代码中,构造子集后直接计算logits
然后做softmax
结果也是正确的,代码如下:

 

with tf.name_scope('Loss'):
# Train Loss
# 转化为softmax概率矩阵。
    prob = tf.nn.softmax(cos_sim)
# 只取第一列,即正样本列概率。相当于one-hot标签为[1,0,0,0,.....,0]
    hit_prob = tf.slice(prob, [0, 0], [-1, 1])
    loss = -tf.reduce_sum(tf.log(hit_prob))
    tf.summary.scalar('loss', loss)

 

总结

 

理论指导实践,代码中每一步都是有理论依据的,所以只有弄懂其背后的数学原理才能各个算法活学活用。以上也都是我的个人理解,难免有错,欢迎大家和我讨论,一起学习,一起进步~

 

一起交流

 

想和你一起学习进步!『NewBeeNLP』
目前已经建立了多个不同方向交流群(
机器学习 / 深度学习 / 自然语言处理 / 搜索推荐 / 图网络 / 面试交流 /
等),名额有限,赶紧添加下方微信加入一起讨论交流吧!(注意一定要备注信息
才能通过)

 

 

本文参考资料

[1]

What is Candidate Sampling:https://www.tensorflow.org/api_docs/python/tf/nn/sampled_softmax_loss

 

Be First to Comment

发表回复

您的电子邮箱地址不会被公开。 必填项已用*标注