Press "Enter" to skip to content

Quant进阶:用『最少』的数学,学『最全』的图神经网络

作者 | Rishabh Anand       编译 | QIML编辑部

前言

多年来,图深度学习(GDL)的发展步伐加快了。现实生活中许多网状结构的问题使的GDL成为一个通用的工具。该领域在社交媒体、药物发现、芯片植入、预测、生物信息学等方面显示出了很大的潜力。在本篇文章中, 我们将使用最少的数学知识为大家介绍图深度学习(或图神经网络)。

图深度学习背后的思想是通过节点 和边来 学习图的结构和空间特征,这些节点和边代表不同的主体体及主体间的交互。

往期 图神经网络 文章推荐:

 

点击标题阅读

 

因子挖掘:基于图神经网络与公司主营(附代码)

 

深度动态因子模型、图神经网络

 

基于图神经网络、图谱型数据的收益预测模型(附代码)

 

白鹭女掌门张晨樱:打造反脆弱的量化多策略盈利武器

 

如何表示图?

 

在学习图神经网络之前,我们先弄清楚计算机如何表示一幅图。图 是一种包含一系列节点 和边 的数据结构。如果两个节点 和 相连,则 ,否则 。节点之间的连接关系可以用一个矩阵表示,这个矩阵称为邻接矩阵(Adjacency Matrix):

 

 

我们假设上文所说的图中的边是没有权重也没有方向,并且假设是同质图,也就是说图中的节点或边都是同类型的,反过来如果节点或边有不同的类型,则称之为异质图。

 

图与常规数据的不同之处在于,图具有神经网络必须遵守的结构;如果不好好利用它就太浪费了。这是一个社交媒体图的例子,节点是用户,边是他们的交互(如关注/喜欢/转发)。

 

 

图(Graph)与图像(Images)的关系

 

图像本身就是一个图,它是一种特殊的变体,称为“网格图”,其中一个节点的出边数对所有内部节点和角节点都是恒定的。在图像网格图中存在一些一致的结构,允许对其执行简单的卷积类操作

 

如下左图,一幅图像可以被认为是一个特殊的图,其中每个像素都是一个节点,并通过虚拟的边缘与周围的其他像素相连接。 当然,现实情况下以这样的视角看图像是不现实的,因为这意味着有一个非常庞大的图。以CIFAR-10的图像为例,的图像将有3072个节点和1984条边。更别说像ImageNet里的图像了。

 

 

然而,正如你所观察到的,图并不是那幺完美。不同的节点有不同的度(连接到其他节点的数量),并且分布在各个地方。图中没有固定的结构,但是这种不固定的结构为图增加了价值。因此,在这个图上学习的任何神经网络在学习节点(和边)之间的空间关系时必须尊重这个结构。

 

理解图神经网络

 

单个图神经网络(GNN)层对图中的每个节点上有以下操作:

 

1、消息传递(Message Passing)

 

2、聚合(Aggregation)

 

3、更新(Update)

 

这些操作组合在一起,形成了通过图深度学习的构建模块。GDL的创新主要涉及对这三个步骤的改变。

 

节点(Node)

 

请记住,节点代表一个实体或对象,就像社交媒体中的用户。因此,这个节点具有所表示实体的一系列属性特征。这些节点属性构成了节点的特征(即“节点特征”或“节点嵌入”)。通常,这些特征可以用向量 表示。这个向量要幺是一个隐含的表征(embedding),要幺是这个结点具体的属性特征。

 

例如,在社交媒体图中,用户节点具有可以用数字表示的年龄、性别、政治倾向、关系状态等属性。

 

同样,在分子图中,原子节点可能具有化学性质,如对水、力、能量等的亲和力,这些性质也可以用数字表示。

 

这些节点的特征可以作为GNN的输入。正式的,每个节点 的特征用 表示,并且用 表示标签:

 

 

边(Edges)

 

边也可以有属性或特征 。例如,原子之间的化学键,我们可以把下面的分子想象成一个图形,其中原子是节点,键是边。虽然原子节点本身有各自的特征向量,但边可以有不同的边特征,以编码不同类型的键(单键、双键、三键)。

 

 

现在我们知道了如何表示图中的节点和边,下面让我们从一个简单的图开始,图中有一些节点(具有节点特征)和边。

 

 

消息传递(Message Passing)

 

GNN以其学习结构信息的能力而闻名。通常情况下,具有相似特征或属性的节点是相互连接的(在社交媒体中是这样的)。GNN利用这一特点,学习特定的节点如何以及为什幺会相互连接,而有些节点则不会。为此,GNN会观察节点的邻域。

 

定义节点邻域为与节点连接的一系列节点,表示为

 

 

GNN可以通过查看邻近的节点来了解节点的很多信息。为了在源节点和它的邻居之间共享信息,GNN使用 消息传递 的机制。

 

对于GNN层,Message Passing定义为获取邻居的节点特征,对其进行转换,并将其“传递”给源节点的过程。对图中的所有节点并行地重复这个过程。

 

如下图,节点6的邻域 。我们用 分别表示三个邻节点的特征,并用函数 表示消息转换的过程。那幺,一条“消息”就是从一个邻节点经函数 转换过来的特征。

 

 

聚合(Aggregation)

 

节点节点6有了从邻节点 传来的消息 ,我们需要把这些消息聚合(Aggregation)。有许多方法可以对消息进行聚合,比如:

 

我们用函数 表示聚合函数,最终聚合后的消息如下:

 

更新(Update)

 

使用这些聚合的消息,GNN层现在必须更新源节点 的特性。在这个更新步骤里,节点 不仅应该知道它自己,还应该知道它的邻居。这个过程可以通过获取节点 的特征向量并将其与聚合的消息相结合。

 

比如使用加法对自身的特征和邻节点传递过来的聚合后的消息进行结合:

 

其中是激活函数(比如ReLU,ELU,Tanh),和都是一个简单的MLP层,用于改变向量的维度。

 

当然也可以简单的讲节点本身的特质与消息进行拼接:

 

更通用的,我们用函数表示对节点本身的特质与消息进行转换的过程:

 

注意,上式中为原节点的特征,经过一层GNN的前向传播计算后的结果用表示。如果有多层GNN,则使用表示经第层GNN计算后的结果,也就是说=。

 

所有以上三个过程整合到一起 ,可以用下式表示(使用加法进行消息的聚合):

 

如果边也有特征,用表示边的特征,那幺在第层,我们可以使用一个简单的MLP层来更这一层边的表征:

 

邻接矩阵(Adjacency Matrices)

 

目前为止,我们都是考虑图中某个节点的前向传播过程。如果有邻接矩阵我们就可以在整个图上进行前向传播计算。

 

在传统神经网络中,对于一个样本一个前向传播是:

 

其中,。如果需要对所有样本同时进行前向传播,可以用矩阵的形式:

 

在邻接矩阵中,每一行表示与节点 连接的所有节点,其中如果表示相连,如果表示无连接。比如,表示节点2与节点1,3和4相连。所以当与相乘时,第2和5列就会忽略。

 

 

 

 

 

所以基于矩阵,对图上所有节点进行前向计算时可以使用邻接矩阵:

 

但是邻居矩阵数学公式:中并没有考虑节点自身的信息,所以可以给邻接矩阵加上一个单位矩阵:

 

叠加GNN层

 

上文我们介绍了一层GNN的结构,我们可以叠加多个GNN层构建一个图神经网络模型:

 

1、第一层的输入是节点的原始特征,输出为隐藏状态其中为第一层表征的维度。

 

2、然后作为第二层的输入,输出。

 

3、多层之后,在第层的输出为

 

以上作为模型的超参数,需要我们自行设定。

 

 

有了最后一层的输出数学公式:,我们可以做很多事情:

 

1、我们可以在的维度计算的和,得到一个维度的向量作为整个图的表征,可以使用这个表征对图进行分类。(图分类任务)

 

 

2、我们也可以将传入一个Graph Auytoencodr进行图结构的清洗与重建。

 

 

3、我们还可以把每个节点的表征传入一个分类器,进行节点分类。

 

 

4、把两个节点的表征一起传入一个 MLP 层,去判断两个节点之间是否存在连接。

 

 

常见的图神经网络

 

Graph Convolution Network(GCN)

 

详细介绍参考以下论文:

 

https://arxiv.org/abs/1609.02907

 

GCN中,对于邻节点的消息,只是一个简单的加总,再加上一个非线性激活的过程:

 

但是,在论文中,作者还对邻接矩阵根据每个节点的度进行了标准化,即:

 

最后:

 

 

Graph Attention Network(GAT)

 

详细介绍参考以下论文:

 

https://arxiv.org/abs/1710.10903

 

不同的邻节点的重要性应该是不一样的,GAT考虑到了这一点并引入了自注意力机制。

 

关于自注意力机制可以参考以下论文:

 

https://arxiv.org/abs/1706.03762

 

每条边的权重,使用以下方法进行计算:

 

在GAT中,消息的传递可以看作是一个加权平均的过程,而权重是通过以上自注意力机制并进行Softmax权重归一后计算得出:

 

 

GraphSAGE

 

详细介绍参考以下论文:

 

https://arxiv.org/abs/1706.02216

 

GraphSAGE的全称是Graph SAmple and AggreGatE,主要用于非常大的密集的图计算。该模型在节点的邻域上引入了学习聚合器。与考虑邻居中的所有节点的传统GAT或GCNs不同,GraphSAGE统一地对邻居进行采样,并在它们上使用学习到的聚合器。

 

假设模型一共有层,每一层。假设我们要 聚合K次 ,则需要有K个 聚合函数(aggregator) ,可以认为是N层。每一次聚合,都是把上一层得到的各个node的特征聚合一次,在假设该node自己在上一层的特征,得到该层的特征。如此反复聚合K次,得到该node最后的特征。最下面一层的node特征就是输入的node features。

 

 

Temporal Graph Network

 

详细介绍参考:

 

https://arxiv.org/abs/2006.10637

 

到目前为止所描述的网络都是在静态图上工作的。现实生活中的大多数情况都是在动态图上工作的,其中节点和边在一段时间内被添加、删除或更新。时态图网络(TGN)研究连续时间动态图(CTDG),它可以表示为按时间顺序排序的事件列表。

 

参考论文中将事件分为两种类型:节点级事件和交互事件。节点级事件涉及一个孤立的节点(例如用户更新他们的个人简介),而交互事件涉及两个可能连接或不连接的节点(例如:用户a转发/关注用户B)。

 

TGN通过以下组件提供了一种模块化的CTDG处理方法:

 

1、消息传递函数: 消息在独立节点或交互节点之间传递(对于任何类型的事件)。

 

2、消息聚合函数: 使用GAT的聚合,通过多个时间步来确定“邻居节点”,而不是在给定的时间步上查看“邻居节点”。

 

3、更新: 这个模块根据在一段时间内发生的交互来更新节点的内存。

 

4、时间嵌入: 一种表示节点的方法,同时也捕捉了时间的本质。

 

5、链接预测: 通过某种神经网络对事件中涉及的节点的时间嵌入进行反馈,计算边的概率(即该边缘是否会在未来发生?)当然,在训练过程中,我们知道边的存在所以边标签是1。我们需要训练基于sigmoid的网络来预测这一点。

 

总结

 

当处理具有类似网络结构的问题时,Graph Deep Learning是一个很好的工具。它们很容易理解和实现,当前流行的库包括“PyTorch geometry”、“Spektral”、“Deep Graph Library”,以及最近发布的“TensorFlow-gnn”。

Be First to Comment

发表回复

您的电子邮箱地址不会被公开。 必填项已用*标注