Press "Enter" to skip to content

论文浅尝 – ICML2020 | 跨域对齐的图最优运输算法

本站内容均来自兴趣收集,如不慎侵害的您的相关权益,请留言告知,我们将尽快删除.谢谢.

陈卓,浙江大学在读博士,主要研究方向为图神经网络和知识图谱表示学习。

 

 

论文链接: https://arxiv.org/pdf/2006.14744

 

代码: https://github.com/LiqunChen0606/Graph-Optimal-Transport

 

发表会议:ICML 2020

 

动机

 

该论文的出发点基于前人工作的局限,认为当前存在的跨域对齐方法主要是采用各种先进的注意力机制来模拟软对齐,但是传统的注意力机制是由特定下游任务的loss进行监督和引导的,而没有明确考虑对齐本身的的训练信号。并且,往往学习到的注意力矩阵会比较稠密,缺乏可解释性。所以作者提出了图最优运输算法这样一个新的框架,通过把最优运输应用在图匹配上来处理跨域问题。

 

同时,这个算法与现有的神经网络模型具有很好的兼容性,可以直接作为drop-in正则化项加入到原来的模型中。通过这样一个通用的正则化系数,在两个域对齐程度低的的pair上施加更多的惩罚,这对于机器翻译,图像注释,以及图像-文本跨模态检索等需要匹配的场景,效果提升是比较make sense的。

 

最后,这个论文的很大一个亮点在于通用性,作者在5个task上对于不同的模型做了相关实验,使用了GOT方法后全部取得了效果提升。后面大部分篇幅也用在实验上。

 

背景设定

 

这里的跨域对齐可能与跨知识图谱数据库对齐的不太一样。因为本文所指的跨域是特指跨模态的。对于两个不同的domain Dx和Dy, 分别考虑其中的一个数据集如X tilde(和Y tilde,其中每一个entity都可以由一个特征向量表示。n和m代表该domain下数据集中的entity数量。

 

 

文中所讨论的范围主要集中于涉及图像和文本的任务,因此此处的实体可以对应于图像中的对象或句子中的单词。图像可以=表示为一组检测到的对象,每个对象都与一个特征向量相关联,而一个句子则可以被一串word embedding表示。在通用场景下,一个深度神经网络fθ会被设计接收以上的X tilde和Y tilde并用来生成当前语境下的数据表示X和Y。这里的fθ可以是很多模型,θ是模型参数。最后监督信号l将会被用来进行参数θ的学习。训练目标可以简化为这个loss函数公式,其中小l是监督信号,特定任务下loss函数的选择和小l都不同。

 

 

然后是动态图谱构建。文章的设定中,每个域的节点集合需要构建一个图,其中每一个节点都是entity,由一个特征向量表示,并且通过计算成对节点的相似性来选择是否添加边。这里手工定义的超参数t作为一个是否添加边的相似度阈值。论文中选取的t等于0.1。

 

动态这个词在这里体现在图上边之间联系的动态性,因为在模型的训练过程中节点的向量表示X和Y都是会因为参数的迭代更新不断变化的,所以边的有无,或者说图构建,都是一个持续变化且逐渐趋于稳定的过程。

 

经过以上步骤把每个domain中的entities表示成一个图,跨域对齐的任务自然被转换成了一个graph matching 的问题。

 

 

该论文的主要idea来源在于,提出了GOT图最优运输算法,将两种类型的最优传输距离应用到graph matching上,使得该通用框架在许多任务上达到了更好的效果。其中采用的两种最优运输距离,分别是针对节点匹配的Wasserstein distance以及针对边匹配的Gromov-Wasserstein distance。后面用WD和GWD进行缩写描述。

 

大家可能不太了解这个距离的概念,但是我相信KL散度可能大部分人都有所耳闻。这是衡量两个分布之间距离的一个指标。上面提到的WD和GWD实际上也是。其中WD就起源于最优运输问题。也叫推土机距离,这个名字非常形象。

 

其把概率分布想象成一堆石头,如何移动一堆石头,通过最小的累积移动距离把它堆成另外一个目标形状,这就是最优运输所关心的问题。先看上图的下半部分:我们可以把两个分布之间的距离看作是最小需要的累积移动距离,它形象解释了如何将离散分布P转换成离散分布Q的过程。而这个过程中第一步p1移动2个方块到p2,然后p2移动2个到p3,然后p3移动一个方块到p4。最后总共的移动数目5就是这两个分布之间的最优运输距离。

 

相对于KL散度等分布评价指标来说,WD具有很明显的优点,比如可以度量离散分布之间的距离,且满足对称性,然后能够很好地反映概率分布的几何特性。几何特性的意思就比如这个p1到p3,并不是直接跳过去,而是必须经过中间的p2。这使得分布之间的空间距离也被考虑进去。

 

然后我们再来看图片上方文章中这个最优运输的公式:其中µ和v是来自不同域中两个离散的分布,π(µ , v)在这里这里表示的意思是所有的联合分布γ (x, y )的集合。上面这个公式的含义是对于每一个可能的联合分布γ,从中采样x,y属于γ,要注意这里的x,y属于两个不同域。然后计算x y之间的距离。后半截代表对于该次采样的联合分布γ下样本对距离的期望。最后在前面加上下界符号,整个公式的含义也就变成了,所有可能的联合分布下这个期望值所能取到的最小值,换句话说,就是两个分布的最短距离,这也就是最终希望得到的 WD 。

 

但是上面这种形式难以求解,所以进一步的,公式可以通过一些转换,化简成为下面那个形式。其中c(xi,yi)的含义也是两个向量的距离,x,y分别来自两个分布,或者说两个domain。这里把找到最优联合分布gama的这样一个问题,转换成为了找到最优传输矩阵T的问题,显然就相对直观一些了。矩阵T具有天然normalize特性,根据后面这个限定,可以得出其中所有元素加起来都是1. 在这里,其中任意一个元素Tij,代表向量ui移动到向量vj所需要的最小代价,也就是两个向量之间的 WD。后面实验部分会有直观展示。

 

最终,不仅得到了两个域总体的距离,还得到了代表了两个域内entity之间的相关系数的副产物T矩阵。

 

 

如果前面的WD理解了,那这里的GWD也就很好理解了。可以看到右图,我们刚才计算的WD是两个域之间的距离,计算域内距离的方法是前面提到的相似性度量。作者希望那剩下的每个域之间的节点pair,或者说边的距离,使用GWD进行度量。和前面的定义形式基本一样,其中L函数是计算边距离的cost function。而这里的T则成为了一个对齐不同图中边的传输方案。

 

然后作者梳理了WD和GWD分别的优势,GWD可以捕捉边的相似性但是无法直接应用到图对齐,因为只考虑边的相似性话,boy和girl这样一个pair的相似性居然和football,basketball pair是一样的。但他们的语义完全不同,所以就说不通。另一方面来说,WD虽然匹配不同图中的节点,但是又不能捕捉边的相似性。这样不同节点表示的重复entity又会被当做一样并且忽略其周围的关系。就比如There is a red book on the blue desk这个句子,并且给定了一个图,里面不同位置的书有不同的颜色。如果无法理解关系,就无法知道这个句子里面的某本书对应的图里面的哪本。

 

所以很自然的,作者提出了下图这样一个结合方案。

 

 

其中最优传输矩阵T是被共享的,因为他同时结合了节点信息和边的信息。最终GOT的公式如上,很自然的,可以转换到右边的cost function。下图就是计算这样一个距离的流程。原始特征x tilde和y tilde同时输入,经过特定的模型主体输出x y同时计算域内的 cost matrix和跨域的cost matrix。然后计算出通用的传输方案T,作用到GWD和WD的计算流程上,最后得到一个GOT的融合距离。这个融合距离,最终会作为一个drop in正则化参数,在反向传播过程中用来监督各种任务训练中的跨域对齐程度,并且更新模型参数θ。

 

 

该论文的亮点在于通用性,也就是在多个不同的,多模态任务中均有效果,而他主要的修改,就是在原有下游任务的基础上,增加了一个任务无关的,衡量跨域图谱对齐程度的正则化项。所以就算是前面的最优传输没有理解也没关系,不影响后面的阅读和整体思路理解,因为这两个数学上距离的概念都不是作者提出来的,他只是将其做了一个融合改进并且作为通用的方法应用到了模型中。

 

实验:

 

首先是视觉语言的多模态任务。其一是图像-文本跨模态检索。这个任务定义是,当给定一个模态(比如图像)的查询时,它的目标是从数据库中以另一个模态(比如句子)检索最相似的样本。这里的关键挑战是如何通过理解跨模式数据的内容,和度量其语义相似性来匹配跨模式数据。早期的方法采用全局表示来表达整个图像和句子进行一个匹配却忽略了局部细节。这些方法在只包含单个对象的简单场景中工作得好,对于涉及复杂自然场景的真实的情况并不令人满意。

 

 

Scan这个方法通过注意力机制把句子中的词和 图像中的不同区域被识别出来的物体映射到一个共同的 embedding space 来预测整张图和一个句子之间的相似性。GOT在这里用来衡量句子和图片这两个域中graph的对齐程度。

 

然后作者证明单独使用的情况下WD比GWD能够取得更好的效果,而结合起来,作为GOT使用,可以达到最好的效果。并且可以看到,最优传输方案这个矩阵T,可视化出来之后,具有更强的解释性和更少的模糊性,就是他每一个对应关系都很清晰不像注意力矩阵那幺模糊和密集。

 

任务2是VQA,使用的是双线性注意力模型BAN。后面的数字可以看做是生成的注意力图数量,可以理解为模型的复杂程度。可以看出  GOT对于简单模型的提升效果的程度是好于复杂模型。(个人感觉VQA上跨域对齐的好处有待商讨)

 

 

然后对于文本生成的任务来说,提到了图像注释。

 

 

以及机器翻译任务,效果都是比原有的要好。

 

 

最后是段落摘要任务,总之就是在一个或多个方法的baseline基础上加入GOT,效果都有提升。

 

 

文章最后以机器翻译实验为基础进行了消融实验,探讨了T矩阵是否在WD和GWD中共享的影响,证明共享T具有更好的效果。同时也测试了超参数λ 取值的影响,最后发现在λ=0.8的时候效果最好,也就是说WD在这个过程中占了更重要的比例。

 

 

总结

 

该文章的主要出发点是跨域对齐在多模态任务中具有很重要的地位,当然也不局限多模态。从结果可以看出域内和域间的关系在对齐任务上都很重要。同时,作者也提到,这是一个可以广泛应用到许多跨模态任务的通用框架,作者这篇文章的重要性很大一部分也体现在他的通用性上,核心idea是加了一个基于对齐程度的正则化项这样一个trick。

 

最后我们可以看出来,kg里面,特别是小的场景kg中,每一个结点和边都是非常重要的,都有其存在的道理,提升模型效果可以考虑从加强其中语义区分和语义的捕捉入手。

 

Be First to Comment

发表回复

您的电子邮箱地址不会被公开。 必填项已用*标注