COMP5313 Week 07 Graph Machine Learning 与 GNN 讲课总结
COMP5313 Week 07 讲课总结:Graph Machine Learning 与 GNN
本讲前半段机器学习 / 深度学习背景,老师明确说是 optional、not in assessment。下面只整理图机器学习和 GNN 的可考重点。
1. 本讲主题
从 Week 07 开始进入 machine learning on graphs。核心思想是:
- 图结构本身是离散的,不方便直接喂给普通机器学习模型。
- 需要把节点、边、子图或整张图映射到连续向量空间。
- 如果把节点映射成向量,就是 node embedding。
- 好的 embedding 应该保留图结构:原图中接近或相似的节点,在 embedding space 中也应该接近。
老师用 karate club 图说明:社区结构在图上明显,在二维 embedding 中也能形成相近的 cluster。
2. 图机器学习的典型任务
老师按预测对象层级讲了几类任务。
Node Classification
预测单个节点的类别。
例子:citation network 中,每个节点是一篇论文,边是引用关系;任务是预测论文主题,比如 CS、physics 等。
Link Prediction
预测两个节点之间是否应该存在边,或未来是否会形成边。
老师提到这和之前 link formation 的 common-neighbour heuristic 有联系,但这里用机器学习方法。更典型的应用是 knowledge graph completion:预测两个实体之间是否存在某种关系。
Community Detection
把图中的节点分成社区。老师说这通常是 unsupervised learning,因为没有 ground-truth label。
Graph Similarity
衡量两张图有多相似。可以先得到图之间的 similarity matrix,再用传统 clustering 方法。
Graph Classification
预测整张图的类别。
例子:分子可以建模为图,任务是预测某个 molecule 是否 toxic。
3. 传统机器学习处理图的流程
传统流程需要 feature engineering:
- 从原始图手工构造 feature matrix。
- 每一行表示一个节点。
- 每一列表示一个 feature。
- feature 可以是 PageRank、degree、clustering coefficient 等。
- 得到 feature matrix 后,用普通 ML 模型做分类或回归。
问题是:一旦构造完 feature matrix,图结构本身通常就被丢掉了。模型只看节点特征,不再直接看边。
4. 图数据为什么难处理
老师强调 graph data 和 image / sequence data 不一样:
- 不同图可以有不同数量的节点。
- 同一张图中,不同节点的 degree 不同。
- 图没有固定 node ordering。
- 图没有固定 reference point。
- CNN / RNN 不能直接套到一般图上。
所以需要专门的 graph neural network。
5. Node Classification 的输入设定
对节点分类任务,输入通常包括:
- 图 \(G=(V,E)\)。
- 邻接矩阵 \(A\)。
- 节点特征矩阵 \(X \in \mathbb{R}^{n \times q}\)。
- 一部分节点有 ground-truth labels,构成 training set。
- 其他节点 label 未知,需要预测。
- 如果有 \(C\) 个类别,label 通常用 one-hot encoding 表示。
这里 \(n\) 是节点数,\(q\) 是每个节点的初始 feature dimension。
6. 两个 naive approaches
方法 1:忽略图结构
只使用节点特征 \(X\),直接套普通 ML 模型。
老师说:如果节点特征本身信息很强,这个方法可能也能表现不错。例如论文文本内容已经足够判断主题。但一般来说,加入 graph structure 会更好。
方法 2:把邻接矩阵和特征拼起来
把 adjacency matrix \(A\) 和 feature matrix \(X\) concatenate,作为模型输入。
问题:
- 输入层大小至少和节点数 \(n\) 相关。
- 大图中 \(n\) 可以是百万级,无法扩展。
- 模型不容易泛化到不同大小的图。
7. GNN 的核心:Neighbourhood Aggregation
GNN 的核心思想是:预测节点 \(u\) 时,不只用 \(u\) 自己的 feature,还用邻居甚至 \(K\)-hop 邻居的信息。
直觉:
- 1 层 GNN 使用 1-hop neighbours。
- 2 层 GNN 使用 2-hop neighbours。
- \(K\) 层 GNN 使用 \(K\)-hop neighbourhood。
计算图可以看成从目标节点出发的 BFS tree。
每一层大致做两件事:
- Aggregate:聚合邻居上一层的 representation。
- Combine:把聚合来的 message 和节点自己的上一层 representation 结合。
8. GCN:Graph Convolutional Network
GCN 是老师讲的第一个基本 GNN。
核心流程:
- 第 0 层表示就是输入特征: [ h_u^{(0)}=x_u ]
- 第 \(k\) 层中,节点 \(u\) 的表示来自邻居和自身上一层表示的加权聚合。
- 通常给每个节点加 self-loop,这样节点自己的上一层表示也参与更新。
- 聚合后做 linear transformation,再做 nonlinear activation。
标准矩阵形式可以写成:
[ H^{(k)}=(D^{-1/2}AD{-1/2}H{(k-1)}W^{(k)}) ]
其中 \(\tilde A=A+I\),表示加 self-loop 后的邻接矩阵。
GCN 为什么 scalable
老师重点解释了它为什么比 naive concatenate 方法好:
- 同一层所有节点共享同一个 weight matrix \(W^{(k)}\)。
- 参数数量不依赖节点总数 \(n\)。
- 每层基本是在边上传播一次,复杂度约为 \(O(|E|)\)。
- 可以泛化到 unseen nodes:新节点加入后,只要构造它的 computation graph,就能用已有参数预测。
9. GraphSAGE
GraphSAGE 是另一个 GNN 架构。区别主要在 aggregation 和 combine 的设计。
老师讲了几种 aggregator:
- mean aggregator:直接取邻居表示平均。
- transform 后再 aggregate:先对邻居表示做变换,再 mean / max。
- LSTM aggregator:随机打乱邻居顺序后用 LSTM 聚合。
GraphSAGE 的 combine 通常是把节点自己的表示和邻居 message concatenate 起来。
10. GIN:Graph Isomorphism Network
GIN 的 aggregation 用 sum。
老师强调:sum 比 mean / max 更 expressive。
直觉:
- mean 可能无法区分“一个红 + 一个蓝”和“两个红 + 两个蓝”这类 multiset。
- max 只看最大值,也会丢掉数量信息。
- sum 能保留更多 multiset 信息,所以表达能力更强。
GIN combine 后会接一个 MLP,老师提到即使把 \(\epsilon\) 设为 0,表现也可以很好。
11. Graph Embedding 与 Graph Classification
前面主要是 node embedding:每个节点一个向量。
有些任务需要整张图一个向量,也就是 graph embedding。
常见做法很简单:
- 先计算每个节点的 embedding。
- 对所有节点 embedding 做 sum 或 mean。
- 得到整张图的 representation。
- 再用普通 ML 模型做 graph classification。
例子:分子图分类,预测 molecule 是否 toxic。
12. 考点重点
- 知道 graph ML 的目标:把图中的节点 / 边 / 子图 / 整图映射到连续向量空间。
- 会区分 node classification、link prediction、community detection、graph similarity、graph classification。
- 知道传统 graph ML 依赖 feature engineering。
- 知道图数据难点:无固定大小、无固定 ordering、degree 不同。
- 知道 naive approaches 的问题:忽略图结构;或输入维度依赖 \(n\),无法扩展。
- 理解 neighbourhood aggregation:\(K\) 层 GNN 看 \(K\)-hop neighbours。
- 知道 GCN 加 self-loop,并通过邻居聚合更新表示。
- 知道 GNN 参数共享:参数数量不依赖节点数,训练大图才可行。
- 知道 GraphSAGE 的 aggregate/combine 思路。
- 知道 GIN 使用 sum aggregator,sum 比 mean/max 更 expressive。
- 知道 graph embedding 通常可以由 node embeddings 做 sum/mean 得到。