GCN 是一种在图中结合拓扑结构和顶点属性信息学习顶点的 embedding 表示的方法。然而 GCN 要求在一个确定的图中去学习顶点的 embedding,无法直接泛化到在训练过程没有出现过的顶点,即属于一种直推式 ( transductive ) 的学习。
本文介绍的 GraphSAGE 则是一种能够利用顶点的属性信息高效产生未知顶点 embedding 的一种归纳式 ( inductive ) 学习的框架。 其核心思想是通过学习一个对邻居顶点进行聚合表示的函数来产生目标顶点的 embedding 向量。
▌ GraphSAGE 算法原理
GraphSAGE 是 Graph SAmple and aggreGatE 的缩写,其运行流程如上图所示,可以分为三个步骤:
1. 对图中每个顶点邻居顶点进行采样
2. 根据聚合函数聚合邻居顶点蕴含的信息
3. 得到图中各顶点的向量表示供下游任务使用
▌ 采样邻居顶点
出于对计算效率的考虑,对每个顶点采样一定数量的邻居顶点作为待聚合信息的顶点。设采样数量为 k,若顶点邻居数少于 k,则采用有放回的抽样方法,直到采样出 k 个顶点。若顶点邻居数大于 k,则采用无放回的抽样。
当然,若不考虑计算效率,我们完全可以对每个顶点利用其所有的邻居顶点进行信息聚合,这样是信息无损的。
▌ 生成向量的伪代码
这里 K 是网络的层数,也代表着每个顶点能够聚合的邻接点的跳数,如 K=2 的时候每个顶点可以最多根据其2跳邻接点的信息学习其自身的 embedding 表示。
在每一层的循环 k 中,对每个顶点 v,首先使用 v 的邻接点的 k-1 层的 embedding 表示 来产生其邻居顶点的第 k 层聚合表示
, 之后将
和顶点 v 的第 k-1 层表示
进行拼接,经过一个非线性变换产生顶点 v 的第 k 层 embedding 表示
。
▌ 聚合函数的选取
由于在图中顶点的邻居是天然无序的,所以我们希望构造出的聚合函数是对称的(即改变输入的顺序,函数的输出结果不变),同时具有较高的表达能力。
MEAN aggregator
上式对应于伪代码中的第4-5行,直接产生顶点的向量表示,而不是邻居顶点的向量表示。mean aggregator 将目标顶点和邻居顶点的第 k-1 层向量拼接起来,然后对向量的每个维度进行求均值的操作,将得到的结果做一次非线性变换产生目标顶点的第 k 层表示向量。
Pooling aggregator
Pooling aggregator 先对目标顶点的邻接点表示向量进行一次非线性变换,之后进行一次 pooling 操作 ( maxpooling or meanpooling ),将得到结果与目标顶点的表示向量拼接,最后再经过一次非线性变换得到目标顶点的第 k 层表示向量。
LSTM aggregator
LSTM 相比简单的求平均操作具有更强的表达能力,然而由于 LSTM 函数不是关于输入对称的,所以在使用时需要对顶点的邻居进行一次乱序操作。
▌ 参数的学习
在定义好聚合函数之后,接下来就是对函数中的参数进行学习。文章分别介绍了无监督学习和监督学习两种方式。
无监督学习形式
基于图的损失函数希望临近的顶点具有相似的向量表示,同时让分离的顶点的表示尽可能区分。目标函数如下:
其中 v 是通过固定长度的随机游走出现在u附近的顶点, 是负采样的概率分布, Q 是负样本的数量。
与 DeepWalk 不同的是,这里的顶点表示向量是通过聚合顶点的邻接点特征产生的,而不是简单的进行一个 embedding lookup 操作得到。
监督学习形式
监督学习形式根据任务的不同直接设置目标函数即可,如最常用的节点分类任务使用交叉熵损失函数。
▌ GraphSAGE 的实现
这里以 MEAN aggregator 简单讲下聚合函数的实现:
features, node, neighbours = inputs node_feat = tf.nn.embedding_lookup(features, node) neigh_feat = tf.nn.embedding_lookup(features, neighbours) concat_feat = tf.concat([neigh_feat, node_feat], axis=1) concat_mean = tf.reduce_mean(concat_feat,axis=1,keep_dims=False) output = tf.matmul(concat_mean, self.neigh_weights) if self.use_bias: output += self.bias if self.activation: output = self.activation(output)
对于第 k 层的aggregator, features
为第 k-1 层所有顶点的向量表示矩阵, node
和 neighbours
分别为第 k 层采样得到的顶点集合及其对应的邻接点集合。
首先通过 embedding_lookup
操作获取得到顶点和邻接点的第 k-1 层的向量表示。然后通过 concat
将他们拼接成一个
(batch_size,1+neighbour_size,embeding_size)
的张量,使用 reduce_mean
对每个维度求均值得到一个
(batch_size,embedding_size)
的张量。
最后经过一次非线性变换得到 output
,即所有顶点的第 k 层的表示向量。
下面是完整的 GraphSAGE 方法的代码
def GraphSAGE(feature_dim, neighbor_num, n_hidden, n_classes, use_bias=True, activation=tf.nn.relu, aggregator_type='mean', dropout_rate=0.0, l2_reg=0): features = Input(shape=(feature_dim,)) node_input = Input(shape=(1,), dtype=tf.int32) neighbor_input = [Input(shape=(l,),dtype=tf.int32) for l in neighbor_num] if aggregator_type == 'mean': aggregator = MeanAggregator else: aggregator = PoolingAggregator h = features for i in range(0, len(neighbor_num)): if i > 0: feature_dim = n_hidden if i == len(neighbor_num) - 1: activation = tf.nn.softmax n_hidden = n_classes h = aggregator(units=n_hidden, input_dim=feature_dim, activation=activation, l2_reg=l2_reg, use_bias=use_bias, dropout_rate=dropout_rate, neigh_max=neighbor_num[i])( [h, node_input,neighbor_input[i]])# output = h input_list = [features, node_input] + neighbor_input model = Model(input_list, outputs=output) return model
其中 feature_dim
表示顶点属性特征向量的维度, neighbor_num
是一个 list
表示每一层抽样的邻居顶点的数量, n_hidden
为聚合函数内部非线性变换时的参数矩阵的维度, n_classes
表示预测的类别的数量, aggregator_type
为使用的聚合函数的类别。
▌ GraphSAGE 应用
本例中的训练,评测和可视化的完整代码在下面的 git 仓库中:
https://github.com/shenweichen/GraphNeuralNetwork
这里我们使用引文网络数据集 Cora 进行测试,Cora 数据集包含2708个顶点, 5429条边,每个顶点包含1433个特征,共有7个类别。
按照论文的设置,从每个类别中选取20个共140个顶点作为训练,500个顶点作为验证集合,1000个顶点作为测试集。采样时第1层采样10个邻居,第2层采样25个邻居。
节点分类任务结果
通过多次运行准确率在0.80-0.82之间。
节点向量可视化
▌ 参考资料
Hamilton W, Ying Z, Leskovec J. Inductive representation learning on large graphs[C]//Advances in Neural Information Processing Systems. 2017: 1024-1034.
https://papers.nips.cc/paper/6703-inductive-representation-learning-on-large-graphs.pdf
PS: 淘系技术部商业机器智能算法团队持续招聘中,方向包括机器学习、NLP、视觉算法、3D 建模、端智能等,期待各路大牛加入(社招/校招都可以)~简历请发邮箱:
[email protected],(*^▽^*)
嘉宾介绍
沈伟臣,阿里巴巴算法工程师,硕士毕业于浙江大学计算机学院。对机器学习,强化学习技术及其在推荐系统领域内的应用具有浓厚兴趣。
Be First to Comment