COMP5313 Week 07 Graph Machine Learning 与 GNN 讲课总结

COMP5313 Week 07 讲课总结:Graph Machine Learning 与 GNN

本讲前半段机器学习 / 深度学习背景,老师明确说是 optional、not in assessment。下面只整理图机器学习和 GNN 的可考重点。

1. 本讲主题

从 Week 07 开始进入 machine learning on graphs。核心思想是:

  • 图结构本身是离散的,不方便直接喂给普通机器学习模型。
  • 需要把节点、边、子图或整张图映射到连续向量空间。
  • 如果把节点映射成向量,就是 node embedding
  • 好的 embedding 应该保留图结构:原图中接近或相似的节点,在 embedding space 中也应该接近。

老师用 karate club 图说明:社区结构在图上明显,在二维 embedding 中也能形成相近的 cluster。

2. 图机器学习的典型任务

老师按预测对象层级讲了几类任务。

Node Classification

预测单个节点的类别。

例子:citation network 中,每个节点是一篇论文,边是引用关系;任务是预测论文主题,比如 CS、physics 等。

预测两个节点之间是否应该存在边,或未来是否会形成边。

老师提到这和之前 link formation 的 common-neighbour heuristic 有联系,但这里用机器学习方法。更典型的应用是 knowledge graph completion:预测两个实体之间是否存在某种关系。

Community Detection

把图中的节点分成社区。老师说这通常是 unsupervised learning,因为没有 ground-truth label。

Graph Similarity

衡量两张图有多相似。可以先得到图之间的 similarity matrix,再用传统 clustering 方法。

Graph Classification

预测整张图的类别。

例子:分子可以建模为图,任务是预测某个 molecule 是否 toxic。

3. 传统机器学习处理图的流程

传统流程需要 feature engineering

  1. 从原始图手工构造 feature matrix。
  2. 每一行表示一个节点。
  3. 每一列表示一个 feature。
  4. feature 可以是 PageRank、degree、clustering coefficient 等。
  5. 得到 feature matrix 后,用普通 ML 模型做分类或回归。

问题是:一旦构造完 feature matrix,图结构本身通常就被丢掉了。模型只看节点特征,不再直接看边。

4. 图数据为什么难处理

老师强调 graph data 和 image / sequence data 不一样:

  • 不同图可以有不同数量的节点。
  • 同一张图中,不同节点的 degree 不同。
  • 图没有固定 node ordering。
  • 图没有固定 reference point。
  • CNN / RNN 不能直接套到一般图上。

所以需要专门的 graph neural network。

5. Node Classification 的输入设定

对节点分类任务,输入通常包括:

  • \(G=(V,E)\)
  • 邻接矩阵 \(A\)
  • 节点特征矩阵 \(X \in \mathbb{R}^{n \times q}\)
  • 一部分节点有 ground-truth labels,构成 training set。
  • 其他节点 label 未知,需要预测。
  • 如果有 \(C\) 个类别,label 通常用 one-hot encoding 表示。

这里 \(n\) 是节点数,\(q\) 是每个节点的初始 feature dimension。

6. 两个 naive approaches

方法 1:忽略图结构

只使用节点特征 \(X\),直接套普通 ML 模型。

老师说:如果节点特征本身信息很强,这个方法可能也能表现不错。例如论文文本内容已经足够判断主题。但一般来说,加入 graph structure 会更好。

方法 2:把邻接矩阵和特征拼起来

把 adjacency matrix \(A\) 和 feature matrix \(X\) concatenate,作为模型输入。

问题:

  • 输入层大小至少和节点数 \(n\) 相关。
  • 大图中 \(n\) 可以是百万级,无法扩展。
  • 模型不容易泛化到不同大小的图。

7. GNN 的核心:Neighbourhood Aggregation

GNN 的核心思想是:预测节点 \(u\) 时,不只用 \(u\) 自己的 feature,还用邻居甚至 \(K\)-hop 邻居的信息。

直觉:

  • 1 层 GNN 使用 1-hop neighbours。
  • 2 层 GNN 使用 2-hop neighbours。
  • \(K\) 层 GNN 使用 \(K\)-hop neighbourhood。

计算图可以看成从目标节点出发的 BFS tree。

每一层大致做两件事:

  1. Aggregate:聚合邻居上一层的 representation。
  2. Combine:把聚合来的 message 和节点自己的上一层 representation 结合。

8. GCN:Graph Convolutional Network

GCN 是老师讲的第一个基本 GNN。

核心流程:

  • 第 0 层表示就是输入特征: [ h_u^{(0)}=x_u ]
  • \(k\) 层中,节点 \(u\) 的表示来自邻居和自身上一层表示的加权聚合。
  • 通常给每个节点加 self-loop,这样节点自己的上一层表示也参与更新。
  • 聚合后做 linear transformation,再做 nonlinear activation。

标准矩阵形式可以写成:

[ H^{(k)}=(D^{-1/2}AD{-1/2}H{(k-1)}W^{(k)}) ]

其中 \(\tilde A=A+I\),表示加 self-loop 后的邻接矩阵。

GCN 为什么 scalable

老师重点解释了它为什么比 naive concatenate 方法好:

  • 同一层所有节点共享同一个 weight matrix \(W^{(k)}\)
  • 参数数量不依赖节点总数 \(n\)
  • 每层基本是在边上传播一次,复杂度约为 \(O(|E|)\)
  • 可以泛化到 unseen nodes:新节点加入后,只要构造它的 computation graph,就能用已有参数预测。

9. GraphSAGE

GraphSAGE 是另一个 GNN 架构。区别主要在 aggregation 和 combine 的设计。

老师讲了几种 aggregator:

  • mean aggregator:直接取邻居表示平均。
  • transform 后再 aggregate:先对邻居表示做变换,再 mean / max。
  • LSTM aggregator:随机打乱邻居顺序后用 LSTM 聚合。

GraphSAGE 的 combine 通常是把节点自己的表示和邻居 message concatenate 起来。

10. GIN:Graph Isomorphism Network

GIN 的 aggregation 用 sum

老师强调:sum 比 mean / max 更 expressive。

直觉:

  • mean 可能无法区分“一个红 + 一个蓝”和“两个红 + 两个蓝”这类 multiset。
  • max 只看最大值,也会丢掉数量信息。
  • sum 能保留更多 multiset 信息,所以表达能力更强。

GIN combine 后会接一个 MLP,老师提到即使把 \(\epsilon\) 设为 0,表现也可以很好。

11. Graph Embedding 与 Graph Classification

前面主要是 node embedding:每个节点一个向量。

有些任务需要整张图一个向量,也就是 graph embedding

常见做法很简单:

  1. 先计算每个节点的 embedding。
  2. 对所有节点 embedding 做 sum 或 mean。
  3. 得到整张图的 representation。
  4. 再用普通 ML 模型做 graph classification。

例子:分子图分类,预测 molecule 是否 toxic。

12. 考点重点

  • 知道 graph ML 的目标:把图中的节点 / 边 / 子图 / 整图映射到连续向量空间。
  • 会区分 node classification、link prediction、community detection、graph similarity、graph classification。
  • 知道传统 graph ML 依赖 feature engineering。
  • 知道图数据难点:无固定大小、无固定 ordering、degree 不同。
  • 知道 naive approaches 的问题:忽略图结构;或输入维度依赖 \(n\),无法扩展。
  • 理解 neighbourhood aggregation:\(K\) 层 GNN 看 \(K\)-hop neighbours。
  • 知道 GCN 加 self-loop,并通过邻居聚合更新表示。
  • 知道 GNN 参数共享:参数数量不依赖节点数,训练大图才可行。
  • 知道 GraphSAGE 的 aggregate/combine 思路。
  • 知道 GIN 使用 sum aggregator,sum 比 mean/max 更 expressive。
  • 知道 graph embedding 通常可以由 node embeddings 做 sum/mean 得到。