Press "Enter" to skip to content

图马尔可夫网络:融合统计关系学习与图神经网络

本站内容均来自兴趣收集,如不慎侵害的您的相关权益,请留言告知,我们将尽快删除.谢谢.

|石壮威

 

学校|南开大学硕士

 

研究方向|机器学习、图神经网络

论文标题:

 

GMNN: Graph Markov Neural Networks

 

收录会议:

 

ICML 2019

 

论文地址:

 

https:// arxiv.org/abs/1905.0621 4

 

代码地址:

 

https:// github.com/DeepGraphLea rning/GMNN

 

[1]研究了图上的半监督节点分类问题。在此前的文献中,基于统计关系学习(例如马尔科夫随机场)和图神经网络(例如图卷积网络)的方法都已被广泛应用于这类问题。统计关系学习方法通过对象标签的依赖关系建模条件随机场,而图神经网络则以端到端训练的形式,提升了图学习的效率。在本文中,作者提出图马尔可夫神经网络(Graph Markov Neural Networks ,GMNN)。GMNN以条件随机场建模对象标签的联合分布,用变分EM算法进行有效训练。在E-step中,一个GNN学习用于拟合标签后验分布的表示向量。在M-step中,另一个GNN用于建模标签依赖关系。实验结果表明,GMNN取得了优越的结果。

 

相关工作

 

考虑半监督学习中的一个图 ,其中V是节点的集合,E是节点之间边的集合, 是所有节点特征的集合。已知一部分标签 ,L∈V,我们的任务是预测剩下未知的标签 ,U = V \ L。

 

统计关系学习(statistical relationship learning,SRL)方法以如下方式计算标签的联合概率分布:

ψ是边上的势函数,一般是人工定义的特征函数的线性组合。

 

这种情况下,预测未知标签任务被看做是推断问题,我们还要去计算位置标签的后验分布 ,[2]是一种典型的基于高斯马尔可夫随机场与标签传播的方法。然而由于标签的复杂结构关系,后验十分难求。

 

与SRL相比,GNN忽略掉标签的依赖关系,只关注于节点的特征表示。由于GNN将标签之间视为独立,那幺此情况下标签的联合分布表示为:

通过聚合节点特征预测标签:

GMNN

 

GMNN利用CRF通过对象属性(节点特征)来建模标签之间的联合分布: ,使用伪似然变分 EM算法进行优化。其中,E-step中使用一个GNN来学习节点的特征表示以预测标签属性,M-step中使用另一个GNN来建模标签之间的依赖关系。如图1所示。

作者沿用CRF的预测模型: ,其中 是模型参数,我们要做的是优化这个参数来求已知标签的最大似然: 。由于存在大量的未知标签,直接最大化对数似然很困难,因此我们采用变分推断的方法,用变分分布 近似 ,来最大化对数似然的证据下界(ELBO):

(3)式可以通过变分EM算法[3][4]来优化。在M-step,这等价于优化(4)式。然而,直接优化(4)式是很困难的,因为这是对整个条件随机场进行优化,需要计算 的配分函数(partition function),即(1)式中的分母 。基于 的独立性,我们可以将(4)式转为优化(5)式。

其中NB(n)是节点n的邻居。 (5)式被称为伪似然函数(pseudolikelihood function)。在似然函数(4)式中,某节点的标签与图上的其他所有节点有关;在伪似然函数(5)式中,某节点的标签只与其邻域节点有关;此时,通过最大化伪似然函数求取节点标签,就只需要聚合邻域的信息。

 

(5)式的意义是,聚合邻域的标签信息和特征信息,通过最大化伪似然函数求取节点标签。因为GNN是一个聚合邻域信息并进行消息传递的过程,所以 可以通过一个GNN实现。

接下来讨论 ,由于其独立性,故由平均场理论有:

同理, 可以通过一个GNN实现。

最大化似然函数:

(8)式证明见附录,参考文献[4]中也给出了一个类似的式子的证明过程。在(8)式中,用采样代替求期望:

(10)式中, 是一个进行特征传播的GNN,学习一个从特征到标签的映射, 是一个进行标签传播的GNN,学习一个从已标注节点标签到未标注节点标签的映射。为对GMNN进行训练,我们首先预训练 :用全体节点的特征作为输入,将已标注节点标签作为监督信息,为全体节点学习“伪标签”。优化目标:

接着,将生成的“伪标签”输入 ,训练目标是使得其生成的标签与“伪标签”尽量接近,这就是(5)式的意义。根据(8)(9)式可将(5)式简化为:

最后,将节点特征再次输入 ,训练目标是使得其生成的标签与 生成的标签尽量接近,并将此时 输出的标签作为预测结果。训练目标:

所以:

伪代码如下:

实验与应用

 

GMNN除了被应用于半监督的节点分类问题外,还可以被应用于无监督学习问题和链路预测问题。

 

在无监督学习中,由于没有标签的节点,因此我们改为预测每个节点的邻居节点是哪些。这种“将邻域作为标签”的方法在此前的无监督学习算法(例如DeepWalk[5])中得到广泛应用。

 

在链路预测问题中,使用对偶图(dual graph)[6]将链路预测问题转换为节点分类问题。对偶图的示意图如下:

在半监督节点分类问题上的实验(使用Cora, Citeseer, Pubmed三个节点分类数据集):

在无监督学习问题上的实验:

在链路预测问题上的实验:

在few-shot learning问题上的实验:对于每个数据集,随机抽取每个类下的5个标记节点作为训练数据。GMNN显着优于GCN和GAT。这种改进甚至比半监督学习的情况(即每个类使用20个标记节点进行训练)更大。这一观察结果证明了GMNN的有效性,即使在标记对象非常有限的情况下。

参考文献

 

[1] Meng Qu, Yoshua Bengio, and Jian Tang. GMNN: Graph Markov Neural Networks. In ICML, 2019.

 

[2] Jingdong Wang, Fei Wang, Changshui Zhang, Helen C Shen, and Long Quan. Linear neighborhood propagation and its applications. IEEE Transactions on Pattern Analysis and Machine Intelligence, 31(9):1600–1615, 2009.

 

[3] R. M. Neal and G. E. Hinton. A view of the em algorithm that justifies incremental, sparse, and other variants. In Learning in graphical models, pp. 355–368. Springer, 1998.

 

[4] D. M. Blei, A. Kucukelbir and J.D. McAuliffe. : A Review for Statisticians. Journal of the American Statistical Association, 112(518):859-877, 2017.

 

[5] B, Perozzi, R. Al-Rfou, and S. Skiena, Deepwalk: Online learning of social representations. In KDD, 2014.

 

[6] B. Taskar, M. Wong, P. Abbeel and D. Koller. Link prediction in relational data. In NeurIPS, 2004.

 

Be First to Comment

发表评论

您的电子邮箱地址不会被公开。 必填项已用*标注