Press "Enter" to skip to content

用NTK科学地理解神经网络

本站内容均来自兴趣收集,如不慎侵害的您的相关权益,请留言告知,我们将尽快删除.谢谢.

最近看了一些NTK的文章写一些笔记和自己的想法。关于NTK的理论文章知乎上已经有一些写得很好的了[0],这篇笔记就记一些我关心的应用场景。

 

在这篇笔记当中,我们试图用NTK理解下面几个方面: training, generalization, extrapolation

 

以及我比较关心的两个问题:

 

 

    1. 对于神经网络来说,什幺样的东西是容易学的? 在平常的炼丹当中,我们经常会发现有些数据集很容易,模型训练一会儿就能泛化地很好,有些数据集就很难,那幺如何从数学上刻画简单或者困难?

 

    1. 一些特殊的网络结构会带来怎样的优势? 网络结构总是最近顶会的一个热点,从MLP到CNN到transformer又回归了不一样的MLP,那幺不同的结构会带来怎样的不同?

 

 

如果有啥没有说对的欢迎指出来,因为能力有限,也没有做过NTK相关工作,所以如果有哪里错了请一定指出来免得误导大家。

 

Part I: NTK简介

 

NTK最开始的引入是用来描述无限宽深度神经网络在梯度下降训练过程中演化的核。NTK的作者将神经网络的研究与核方法理论的工具联系起来,探讨神经网络在无限宽极限下的动力学行为。NTK在无限宽极限下趋于一个确定的核(Kernel),而且在梯度下降的训练过程中保持不变,因此无限宽网络的输出层结果的动力学可以用一个常微分方程来表示 [0.1]。

 

我们用 来表示梯度下降的时刻(epoch), 来表示在t时刻参数的值, 表示神经网络的输出, 表示loss function。对于连续时间的梯度下降,我们有

 

神经网络参数的变化可以用下面这个微分方程给出

 

 

对于样本 神经网络的output的变化轨迹,我们有

 

 

对于整体的样本,神经网络的变化轨迹可以写成

 

 

从(1)当中来看, 就是一个kernel嘛,当网络宽度趋于无穷的时候, ,这就是我们以前的kernel methods。其中

 

 

再其中 就是神经网络初始化的分布比如高斯分布。简单地想想,在kernel methods当中,kernel负责提feature,SVM这些负责分类;而在神经网络当中前面的神经网络负责提feature,后面的网络负责分类。有了这样的一个近似,我们就可以借用原来kernel methods里面的一些结果了!

 

Part II: Training, Generalization, Extrapolation

 

Training

 

和gradient flow就差了个 。而且我们知道当 是正定的时候, 始终是一个descent direction。那幺如果我们 是正定的,并且最小的特征值和0之间有个gap,神经网络的收敛速率就和gradient descent一样。对于l2 loss function,我们有

 

 

解这个方程,得

 

 

所以在l2 loss的时候是线性收敛的。正经的结论嘛,需要finite step size和width,可以看这篇文章[0.2]。

 

Generalization

 

有了NTK之后我们可以套用以前对kernel的结论来描述神经网络的generalization [1]。

对于两层MLP和多项式函数,我们有

Extrapolation

 

这个是比较有意思的一个东西,因为在传统的机器学习当中,我们认为数据是i.i.d.分布的,很少有人去探究,最近这个话题是比较火的[3] [4.1] [4.2]。简单来说呢,就是研究training distribution之外神经网络的行为。比如呢,我的训练数据都在norm为1的球上,测试数据的norm都比较大,诸如这种情况。对于NTK with l2 loss function来说,有了一个新的数据,我们的预测结果是

 

 

对于两层激活函数是 的MLP,我们有

 

 

因为这个仅仅是一篇知乎,我们考虑简单一点的情况,训练数据都在 的球上。对于 ,我们可以积分得到

 

 

对于norm很大的测试数据( ),并且 我们可以得到

 

 

再带回 就变成了一个线性函数。所以呀,我们知道两层ReLU的MLP在训练数据外面长得像一个线性函数,画图呢就是这个样子。

这点在文章[3]当中详细的推导与说明,结论长这样

在一些应用当中,数据长得是周期函数的样子,所以我们需要extrapolate一个周期函数,从上面的分析当中,我们知道ReLU只能extrapolate线性函数。所以我们要用一些周期的激活函数[5.1][5.2]。

 

我也积了一下其它激活函数的两层MLP,对于二次激活函数和cos激活函数,NTK长这个样子

 

 

所以嘛,我们猜二次激活函数能够完美extrapolate二次函数,cos并不能完美extrapolate,实验一下extrapolate error大概是这个样子的[3]

两个问题

 

什幺是容易学的?

 

首先我们定义一下什幺叫做“容易”:就是收敛地快并且generalize地好。

 

首先我们来看收敛快的问题,从公式(1)当中我们知道如果 i) loss的landscape好 ii) 的最小特征值比较大的话收敛地快。对于i),感觉大约对应的是smooth label或者smooth loss;对于ii),感觉对应的就是各种各样normalization [6.1] [6.2]。

 

从公式(2)当中我们可以知道如果label 和 的大的特征向量平行的话收敛地快。在MLP里面, 大的特征值比如数据的主成分 ,在CNN里面patch的主成分(深度学习击败了之前各种PCA方法 [7.1][7.2]),在GNN里面就是 的主成分(all we have is low-pass filters [7.3][7.4])。同时,y的低次函数分量会和 的最大特征向量平行,所以会出现先学简单的函数再学困难的函数这种implicit bias(如下图[5.1])。如何利用这种特性呢比如就是knowledge distillation:通过神经网络生成的soft label会bias向 的主成分,如果这个时候呢 的主成分又恰好是signal的部分,我们既可以denoise又可以收敛地更快 [8.1]。如果不幸是noise部分我们就会挂掉orz [8.2]。

泛化能力呢可以从Theorem 4.1当中看出来,label 和 的大的特征向量平行泛化能力比较好,这个就牵扯到神经网络的结构了。

 

一些特殊的网络结构会带来怎样的优势?

 

我们首先接着如果神经网络的inductive bias多了呢,需要学的东西就更简单了,泛化能力也会变强。比如GNN里面包含了动态规划的结构,所以学习动态规划就很轻松[9],如下面两张图所示。

那幺其它方面呢?比如表达能力和收敛性?

 

收敛性嘛,对于适合这种网络结构的label,收敛性是大大变好了的,比如两层deep set的NTK就是

 

 

对于合适的y( ),对应的最小特征值相对于MLP是增加了,所以收敛地更快。

 

表达能力方面嘛,是变菜了的,我们都知道MLP是universal approximator,而GNN只能和WL test一样了[10.1],在NTK当中的表现就是 变得不可逆。然后就有很多工作通过对于每一个node加random feature来增加表达能力[10.2][10.3] (通过对 加noise使它变得可逆)。但是random feature的大的特征向量和 我们想要学的label align地并不好,所以根据上面的理论,这部分不能被anoynmous GNN approximate的部分它收敛地很慢,而且泛化能力比较差。实验嘛emmmm在构造的数据集上收敛很慢(下图1[10.3]),在真实数据集上几乎看不出random feature的优势(下图2[10.4])

所以嘛emmm 结构其实是不可或缺的,在从MLP到CNN再到“MLP”的浪潮当中,最近的MLP还是保留了很多特殊的结构 [11]。

 

参考文献

 

[0]深度学习理论研究之路

 

[0.1] 黄伟:深度学习理论之Neural Tangent Kernel第一讲:介绍和文献总结

 

[0.2] Gradient Descent Provably Optimizes Over-parameterized Neural Networks

 

[1] Graph Neural Tangent Kernel: Fusing Graph Neural Networks with Graph Kernels

 

[2] Fine-Grained Analysis of Optimization and Generalization for Overparameterized Two-Layer Neural Networks

 

[3] How Neural Networks Extrapolate: From Feedforward to Graph Neural Networks

 

[4.1] Invariant Risk Minimization

 

[4.2] Out-of-Distribution Generalization via Risk Extrapolation (REx)

 

[5.1] Neural Networks Fail to Learn Periodic Functions and How to Fix It

 

[5.2] Implicit Neural Representations with Periodic Activation Functions

 

[6.1] Optimization Theory for ReLU Neural Networks Trained with Normalization Layers

 

[6.2] FedBN: Federated Learning on Non-IID Features via Local Batch Normalization

 

[7.1] PCANet: A Simple Deep Learning Baseline for Image Classification?

 

[7.2] Online Dictionary Learning for Sparse Coding

 

[7.3] Revisiting Graph Neural Networks: All We Have is Low-Pass Filters

 

[7.4] Optimization of Graph Neural Networks: Implicit Acceleration by Skip Connections and More Depth (section 4.3)

 

[8.1] Distillation ≈ Early Stopping? Harvesting Dark Knowledge Utilizing Anisotropic Information Retrieval For Overparameterized Neural Network

 

[8.2] Noisy Labels Can Induce Good Representations

 

[9] What Can Neural Networks Reason About?

 

[10.1] How Powerful are Graph Neural Networks?

 

[10.2] Random Features Strengthen Graph Neural Network s

 

[10.3] The Surprising Power of Graph Neural Networks with Random Node Initialization

 

[10.4] A Framework For Differentiable Discovery Of Graph Algorithms

 

[11] MLP-Mixer:一个比ViT更简洁的纯MLP架构

Be First to Comment

发表评论

您的电子邮箱地址不会被公开。 必填项已用*标注