Press "Enter" to skip to content

Linformer: 线性复杂度的Attention

本站内容均来自兴趣收集,如不慎侵害的您的相关权益,请留言告知,我们将尽快删除.谢谢.

Overall

 

跟Longformer一样,Linformer也是为了解决Transformer中的Attention部分随着序列长度而有N^2复杂度的问题。

 

论文标题很exciting,但是实际做法却很简洁直接,就是在Attention计算的时候K和V部分加了一个线性映射映射到低维空间。当低维空间的大小是固定的时候,就达到了线性复杂度。

 

与简单直接的做法不同,论文中花了很大的篇幅去对映射到低维空间的做法做了证明。

 

观察

 

在Wiki103和IMDB两个数据集上,在Roberta-large预训练好的模型上计算出Attention矩阵。然后做奇异值分解,然后从下图左两图中可以看到,通过奇异值的累积,可以看到,前128维的奇异值累计值已经占了到了0.9左右。

 

 

而在右图中可以看到,越高层,128个奇异值累积值就越高。在第11层,128个奇异值累积起来达到了0.96。

 

因而说明了,虽然Attention的计算结果是一个 N x N 的矩阵,但其实一个低秩矩阵比如 N x 128 可能就已经足够存储attention的所有信息。

 

定理一

 

首先回顾一下Attention的计算,如下图所示。Transformer中的Attention都是多头的,对于第i个头来说,计算如下。

 

 

注意,上面的公式表达跟我们在之前文章中写的略有不同,这里Q,K,V成了原始的embedding,W Q , W K 和W V 是转换矩阵。

 

因此,论文提出了一个定理,如下图所示。数字符号比较繁杂,我用汉语再翻译一下,就是对于任意的Q,K,V和W Q , W K 和W V ,存在一个低秩矩阵P,使得对于VW V 中的任何一个列向量w,满足下面这个式子。更具体一点,就是用低秩矩阵对w做转换,其损失相对于用原始矩阵,被控制在一个可以接受的范围内,此时低秩矩阵的秩是log(n)。

 

 

证明我就不解释了,我们主要关注的是这个idea以及idea所产生的效果。对数学感兴趣可以直接去翻论文。

 

其实这里我有一个疑问,如果低秩矩阵的秩是logn,那幺这个算法的复杂度应该是nlog(n)而不是线性?

 

有了这个方法之后,其实一个直接的手段就是使用SVD对矩阵做近似,这样复杂度就可以变成O(nk),k为采用的低秩矩阵的秩。

 

但是runtime这样做,还需要每次先对大矩阵做SVD,不划算。

 

 

训练时分解矩阵

 

根据上面所说,在inference的时候去做SVD更费事,所以需要在训练时做好。而做的方式就是在key和value上再各自加入一个线性变换。如下图所示:

 

 

上图中的右上部分还画出了不同的k,inference时间和序列长度的关系。可以看到,不管k是多少,Linformer的曲线都是平的。

 

公式如下,E是K上的转换,F是V上的转换。

 

 

定理二

 

针对上面的做法,论文又提出了定理二,对k的下界进行了理论上的限定。论证部分大家感兴趣可以去看原始论文。

 

 

技巧

 

上面模型部分添加了两个线性转换层。在这两个层上,其实还有很多技巧:

 

参数共享,论文提出了三种共享方式:

 

A. E, F在每一层上的各个头之间共享

 

B. 在A的基础上,E,F相等。

 

C. 在B的基础上,每一层相等。

 

不统一的映射维度,即对于不同的head和层次,映射的维度可以不同。当然,这会影响参数的共享,不同维度的映射参数不能再共享。

 

广义映射: 除了线性映射之外,还可以是其他的方式,比如pooling,卷积等。

 

实验

 

对MLM任务的训练结果如下,从a和b图可以看到,k越大效果越好,但它们和标准的transformer其实差别不大。

 

 

在下游任务上结果如下,也是和标准transformer类似的效果。

 

 

而在内存和速度上的提升,则在下图,左图是速度提升,右图是内存提升。可以看到,序列长度越长,k越小,提升越大。

 

 

总结与思考

 

这篇论文是一个观察法做优化的绝好案例,从对attention的SVD分解到映射层的添加水到渠成。但标题原因还是导致论文有些 言过其实 。主要是因为:

 

方案在长度比较长的时候才能显现为例,而在原始的bert上,长度512,此时如果k=128, 那幺相当于内存占用量由512 * 512 变成128 * 512。

 

另一方面, Linformer在长度比较长的时候会更加有效,但论文却只做了性能和内存的比较,没有做在较长序列的情况下,Linformer在下游任务上的优势实验,虽然Roberta做不了baseline,但起码可以和Reformer,longformer比较。

 

证明部分有些奇怪,没有见到明确的线性的证明。

 

或许是我数学水平有限,k=5log(nd) / (ε^2 – ε^3) 我理解不是线性。(有理解不同的可以私信我)

 

对于序列较短的加速需求而言,还是MobileBert更靠谱一些。

 

参考

 

[1]. Wang, Sinong, et al. “: Self-Attention with Linear Complexity.” arXiv preprint arXiv:2006.04768 (2020).

Be First to Comment

发表评论

电子邮件地址不会被公开。 必填项已用*标注