深蓝学院-GNN课程笔记06-图神经网络的可扩展性
这个系列是深蓝学院图深度学习理论与实践的同步课程笔记。课程内容涉及图深度学习的基础理论与应用,本节主要介绍图神经网络的可扩展性。
可扩展性
现实生活中的图神经网络往往会面临可拓展性上的问题。目前大规模的社交网络已经达到了上亿的用户级别,而直接在这种超大规模的图上使用GNN是非常困难的。
以GCN和节点分类任务为例,我们可以把单层的图神经网络表示为:
\[f_\text{GCN} (A, F; \Theta) = \hat{A} F \Theta\]假设图神经网络一共有\(L\)层,则训练损失函数可以表示为:
\[L_\text{train} = \sum_{v_i \in V} l(f_\text{GCN} (A, F; \Theta)_i, y_i)\] \[F^{(l)} = \hat{A} F^{(l-1)} \Theta^{(l-1)}\]如果我们直接在全图上进行训练,存储所有层的特征所需要的空间复杂度为\(O(L \cdot \vert V \vert \cdot d)\),其中\(d\)为每个节点的特征维数。显然对于超大规模的图光是存储特征的空间复杂度已经是难以接受的,更不用说进行神经网络前向和反向传播所需要的计算复杂度了。
因此要解决超大规模图上的训练问题一种直观的想法是采用批量训练的方式,每次只在图上选择一部分节点来进行训练。但这种方式也存在邻域爆炸的问题:假设图上每个节点的度平均值为\(\text{deg}\),第\(L\)层的节点需要与它相邻的\(\text{deg}\)个节点的信息;再往下递推,第\(L-1\)层需要\(\text{deg}^2\)个节点的信息…每一层的前向计算过程可以表示为:
\[F_i^{(l)} = \sum_{v_j \in \hat{N}(v_i)} \hat{A}_{i,j} F_j^{(l-1)} \Theta^{(l-1)}\]因此使用批量训练的方式所需要的空间复杂度为\(O(\text{deg}^L \cdot d)\),对于超大规模的图而言仍然是不可接受的。
逐点采样法
为了解决邻域爆炸的问题人们设计了各种邻域采样方法,其中最基本的方法是逐点采样法。在逐点采样法中每次采样时都只从邻域中采样出\(k\)个节点,然后用这\(k\)个节点来估计使用所有邻域情况下的节点特征:
\[F_i^{(l)} = \frac{1}{\vert \hat{N} (v_i) \vert} \sum_{v_j \in \hat{N}(v_i)} \hat{A}_{i,j} F_j^{(l-1)} \Theta^{(l-1)}\]显然使用逐点采样法的空间复杂度为\(O(k^L \cdot d)\),当\(k\)取一个比较小的值时能够有效地控制模型的内存需求。但逐点采样的缺陷也与\(k\)有关,如果想要获得比较好的估计结果就必须增大\(k\)的值,当\(k\)接近节点的度时逐点采样的空间复杂度仍是非常巨大的。
逐层采样法
在逐层采样法中每一层从全图上采样出\(k\)个节点,然后在前向传播时只考虑相邻两层的\(2k\)个节点的连接关系进行计算:
逐层采样法的空间复杂度为\(O(kL \cdot d)\),与层数线性相关因此极大地降低了训练所需的复杂度。显然逐层采样的核心在于如何在每一层中对节点进行采样,在FastGCN中根据最优分布的形式提出了每个节点采样的概率为:
\[q(u) = \frac{\Vert \hat{A}(:, u) \Vert}{\sum_{u' \in V} \Vert \hat{A}(:, u') \Vert}\]此时每一层节点采样的概率是相同的,仅与邻接矩阵有关。使用逐层采样法时需要注意的是可能会出现相邻层的节点没有连接的情况,此时无法更新这些没有相连节点的特征。
子图采样法
除了逐层采样外子图采样法也是一种常用的图采样方法。在子图采样法中我们只对节点进行一次采样,然后在这个子图上对图神经网络进行训练。因此子图采样法可以避免逐层采样中相邻层节点没有连接的问题,同时训练也更高效。
显然子图采样的核心是如何从全图上选择若干个节点作为子图。在Cluster-GCN中使用了对全图进行聚类的方式来获得子图,然后每次选择一个子图进行训练。但这种做法存在一些问题:比如说聚类后的节点往往具有相似的特征和标签,这对模型训练是不利的;同时不同子图间训练时的梯度方程往往比较大,这会影响训练的稳定性和收敛速度。为了克服这些缺陷,Cluster-GCN在聚类时会首先生成较多的小规模子图,然后挑选出几个小的子图并融合成一个稍大的子图来进行训练,这样使得训练的过程更加稳定。
除了在原图上进行聚类外,在GraphSAINT中选择直接对节点进行采样来生成子图。这样在每一轮都会生成一个随机子图,然后在这个随机子图上训练模型。
生成子图可以有不同的策略,比如对节点进行采样生成子图、对边进行采样生成子图、或者通过随机游走来生成子图等。