GraphSAGE(Graph Sample and Aggregate)
# 1 什么是 GraphSAGE?
GraphSAGE 是由 Hamilton 等人提出的一种 感知邻居特征的方法。它的名字全称是: SAmple and aggreGatE:采样 + 聚合与 GCN 最大的不同在于:
模型 | 聚合邻居 |
---|---|
GCN | 聚合所有邻居(必须用整个图) |
GraphSAGE | 对邻居做采样,支持小批量训练(mini-batch) |
# 2 为什么设计 GraphSAGE?
现实图(如社交网络、商品推荐)可能有百万节点、亿级边,GCN 无法一次加载整个图。
GraphSAGE 的目标:
- ✅ 支持 大规模图的训练
- ✅ 支持 Inductive Learning:学好之后可以预测“没见过的新节点”
- ✅ 邻居采样 + 聚合
# 3 GraphSAGE 的核心机制
每个节点的表示通过如下步骤更新:
第 层表示更新公式:
其中:
- :第 层时的节点特征
- :节点 的邻居集合
- :聚合函数
- :激活函数(如 ReLU)
- :可学习的权重矩阵
也就是:“第 层中,节点 的新表示 ,是它当前层特征 和邻居特征 聚合之后,再经过线性变换和激活函数得到的。”
# 3.1 分解每个部分的含义
# :当前层节点 的特征
- 是维度为 的向量,比如说
- 初始层 () 时,它是节点输入特征(如 one-hot、词向量、属性)
# :节点 的邻居集合
- 就是所有和 相连的节点的编号,例如
# :邻居节点的特征
- 拿到所有邻居在第 层的特征向量
- 比如是三个 3 维向量:
# :把自己也加进去
- 与 GCN 不同,GraphSAGE 明确地把自己的特征也参与聚合
# :聚合函数
- 对这组向量执行某种函数,比如:
- 平均(mean)
- 最大池化(max-pool)
- LSTM
- 输出是一个单独的向量,和输入特征维度一致(或定长)
例如:
# :权重矩阵(可学习)
- 是一个线性变换,大小是 ,例如
- 类似普通神经网络的线性层
Linear(in, out)
- 作用是:将聚合后的特征投影到新空间
# :激活函数
- 通常是
ReLU
、ELU
等 - 用于增加非线性能力
# 全部过程示意
对每个节点 :
- 拿到它的邻居集合
- 取出它和邻居的当前特征
- 对它们进行聚合:得到一个“邻里信息”向量
- 经过线性映射(矩阵乘法)
- 经过激活函数,得到新特征
这就是 GraphSAGE 的一层。
# 3.2 举个具体例子(Mean Aggregator)
假设第 0 层时:
- 节点 特征:
- 邻居是
# Step1: 聚合
# Step2: 线性变换
设 ,则
# Step3: 激活
# 4 采样
GraphSAGE 中的 采样(Sampling)机制 是它区别于传统 GCN 的关键创新之一。下面我将从以下方面 详细解释 GraphSAGE 的采样机制:
# 4.1 为什么要采样?(动机)
在标准 GCN 中,每层都需要访问节点的所有邻居。但图中很多节点的度数很高(如社交网络中的“明星用户”),这会导致:
- 计算量快速爆炸(指数级增长)
- 显存不足(因为需要存所有邻居的表示)
- 无法扩展到大规模图(如百万级节点)
GraphSAGE 的解决方案是:
每层仅从邻居中采样固定数量的节点,用于表示传播和聚合。
# 4.2 采样的基本思路
每一层,GraphSAGE 为每个节点:
- 从其邻居集合 中随机采样固定数量的邻居(如 10 个)
- 这些采样的邻居用于执行聚合操作
- 采样是分层执行的,采样深度 = 网络层数(例如 2 层 GNN,需要采样 2 层邻居)
# 4.3 示例:两层 GraphSAGE 采样示意图
假设:
- 我们有一个两层 GraphSAGE 模型
- 每层采样 2 个邻居
- 节点 是目标节点
采样图如下:
L2邻居 (最远)
/ \
u1 u2
| |
v1(v的邻居) v2(v的邻居)
\ /
v(中心节点)(L0)
2
3
4
5
6
7
我们:
- 第 1 层(最邻近的):从 的邻居中采样 2 个邻居(比如 )
- 第 2 层:对 和 也分别采样 2 个邻居(比如 )
结果:构建出一个以 为根的 局部计算子图,即 v → v1/v2 → u1/u2
# 4.4 分层采样流程
设定采样大小为:
- 第一层:
- 第二层:
- 第 K 层:
对于每个目标节点 ,执行如下采样:
# 目标节点
nodes_L0 = {v}
# 第 1 层:采样邻居,
nodes_L1 = SampleNeighbors(nodes_L0, S_1)
# 第 2 层:采样这些邻居的邻居
nodes_L2 = SampleNeighbors(nodes_L1, S_2)
# ...
2
3
4
5
6
7
8
9
10
SampleNeighbors(nodes_L0, S_1):对集合 nodes_L0
中的每个节点,从它的邻居中随机采样 S_1
个邻居。
最终,我们就构建出一个嵌套的“邻居树”,每个层级的邻居都采样固定数量节点。这就限制了计算复杂度为:
# 4.5 常见采样策略 ——Uniform Random Sampling
最常用的:从邻居中均匀随机选择 个节点。
- 简单高效
- 易于实现
- 对模型鲁棒性影响较小
# 定义一个函数,表示从 node 的邻居中采样最多 S 个邻居。
def sample_neighbors(adj_list, node, S):
# 从邻接表中取出该节点的邻居列表。
neighbors = adj_list[node]
# 如果邻居数量小于等于 S,说明邻居不够采,就直接全部返回。
if len(neighbors) <= S:
return neighbors
# 如果邻居数量大于 S,就从中随机采样 S 个邻居,`random.sample(list, k)`:从 `list` 中随机不放回采样 `k` 个元素
return random.sample(neighbors, S)
2
3
4
5
6
7
8
9
# 5 聚合器 AGG 的常见选择:
聚合器 | 说明 |
---|---|
Mean Aggregator | 聚合邻居平均特征 |
Pooling Aggregator | 每个邻居过 MLP 后取 max |
LSTM Aggregator | 用 LSTM 处理邻居序列(不推荐,顺序无意义) |
GCN-style | 邻居 + 自己一起平均后线性变换 |
# 6 GraphSAGE 模型代码实现
我们现在来 用纯 PyTorch + 手动邻居采样 写一个可运行的 GraphSAGE 模型,用于小图的节点分类任务。
# 6.1 任务目标
我们将实现一个完整的流程:
- 手动构造一个图(邻接字典)
- 实现采样函数(多层邻居采样)
- 实现 GraphSAGE 层(mean aggregator)
- 构建两层 GraphSAGE 网络
- 训练模型进行节点分类
# 6.2 第一步:构造图结构和特征标签
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
# 简单图:6个节点,特征维度=2,标签=2类
X = torch.tensor([
[1.0, 0.0], # 0
[0.0, 1.0], # 1
[1.0, 1.0], # 2
[0.0, 0.0], # 3
[0.5, 0.5], # 4
[1.0, 0.5], # 5
])
Y = torch.tensor([0, 1, 0, 1, 0, 1]) # 节点标签(0/1)
# 邻接字典(无向图 + 自环)
adj = {
0: [0, 1, 2],
1: [1, 0, 3],
2: [2, 0, 4],
3: [3, 1, 5],
4: [4, 2, 5],
5: [5, 3, 4],
}
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# 6.3 第二步:实现多层邻居采样函数
def multi_hop_sampling(seed_nodes, num_neighbors, adj):
layers = [seed_nodes]
for fanout in num_neighbors[::-1]: # 从后往前采样
new_nodes = set()
for node in layers[0]:
neighbors = adj[node]
if len(neighbors) > fanout:
sampled = random.sample(neighbors, fanout)
else:
sampled = neighbors
new_nodes.update(sampled)
# 确保当前层采样出的节点包含上一层节点
new_nodes.update(layers[0])
layers.insert(0, list(new_nodes))
return layers # [l0_nodes, l1_nodes, target_nodes]
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
示例:
# 采样目标节点 = [2]
layers = multi_hop_sampling([2], [2, 2], adj)
print("采样层级节点:", layers) # [2跳邻居, 1跳邻居, 自身]
2
3
# 6.4 第三步:GraphSAGE 层(mean aggregator)
class GraphSAGELayer(nn.Module):
def __init__(self, in_dim, out_dim):
super().__init__()
self.linear = nn.Linear(in_dim * 2, out_dim)
def forward(self, x_self, x_neighbors):
# 聚合邻居
agg = x_neighbors.mean(dim=1) # [B, F]
h = torch.cat([x_self, agg], dim=1) # [B, 2F]
return F.relu(self.linear(h))
2
3
4
5
6
7
8
9
10
# 6.5 第四步:两层 GraphSAGE 模型
class GraphSAGE(nn.Module):
def __init__(self, in_dim, hidden_dim, out_dim):
super().__init__()
self.sage1 = GraphSAGELayer(in_dim, hidden_dim)
self.sage2 = GraphSAGELayer(hidden_dim, out_dim)
def forward(self, x_all, layers):
l0_nodes, l1_nodes, target_nodes = layers
# 取节点特征
x_l0 = x_all[l0_nodes] # 第0层节点特征
x_l1 = x_all[l1_nodes] # 第1层节点特征
# 构建索引映射字典(方便快速定位)
l0_map = {nid: i for i, nid in enumerate(l0_nodes)}
l1_map = {nid: i for i, nid in enumerate(l1_nodes)}
# 第1层:l0聚合为l1
l1_adj = build_batch_adj(l1_nodes, l0_nodes, adj) # [B1, K]
x_l1_neighbors = x_l0[l1_adj] # [B1, K, in_dim]
x_l1_new = self.sage1(x_l1, x_l1_neighbors) # [B1, hidden_dim]
# 第2层:l1聚合为target
target_adj = build_batch_adj(target_nodes, l1_nodes, adj) # [B2, K]
x_target_neighbors = x_l1_new[target_adj] # [B2, K, hidden_dim]
x_target = x_l1_new[[l1_map[nid] for nid in target_nodes]] # [B2, hidden_dim]
x_out = self.sage2(x_target, x_target_neighbors) # [B2, out_dim]
return x_out
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
# 6.5.1 邻居索引构建函数(batch 用)
def build_batch_adj(src_nodes, neighbor_nodes, adj_dict):
# 返回 [len(src), num_neighbors] 的索引,用于 x[neighbor_idx]
neighbor_map = {nid: i for i, nid in enumerate(neighbor_nodes)}
indices = []
for nid in src_nodes:
neighbors = adj_dict[nid]
mapped = [neighbor_map[j] for j in neighbors if j in neighbor_map]
while len(mapped) < 2: # 填充
mapped.append(mapped[0])
indices.append(mapped[:2])
return torch.tensor(indices)
2
3
4
5
6
7
8
9
10
11
# 6.6 第五步:训练模型
model = GraphSAGE(2, 4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()
for epoch in range(50):
model.train()
target_nodes = [2, 3, 5] # 训练的目标节点
layers = multi_hop_sampling(target_nodes, [2, 2], adj)
out = model(X, layers) # [len(target_nodes), 2]
loss = loss_fn(out, Y[target_nodes])
optimizer.zero_grad()
loss.backward()
optimizer.step()
if epoch % 10 == 0:
pred = out.argmax(dim=1)
acc = (pred == Y[target_nodes]).float().mean().item()
print(f"Epoch {epoch}, Loss: {loss.item():.4f}, Acc: {acc:.4f}")
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# 7 GraphSAGE 模型代码详解
# 7.1 multi_hop_sampling(...)函数
我们来详细逐行讲解这段 GraphSAGE 中的多层邻居采样函数:
# 7.1.1 函数定义
def multi_hop_sampling(seed_nodes, num_neighbors, adj):
seed_nodes
:最顶层的“目标节点”,也就是我们要预测/训练的节点(例如[2, 3, 5]
)num_neighbors
:每一层要采样的邻居数量,比如[2, 2]
表示:- 第1层采样2个邻居
- 第2层采样2个邻居
adj
:图的邻接表(字典格式),例如:adj = { 0: [0, 1, 2], 1: [1, 0, 3], ... }
1
2
3
4
5
# 7.1.2 初始化节点层列表
layers = [seed_nodes]
layers
:存储从底往上采样的节点列表。- 最开始只有“顶层”节点(要分类的目标节点)。
例如一开始:
layers = [[2, 3, 5]]
# 7.1.3 反向采样多层邻居(从上往下)
for fanout in num_neighbors[::-1]:
[::-1]
是反转,我们从上往下采样。- 比如
num_neighbors = [2, 2]
,它变成[2, 2]
(本例对称其实一样,但结构上是“从上层往下”)。
# 7.1.4 对当前层每个节点采样邻居
new_nodes = set()
for node in layers[0]:
neighbors = adj[node]
sampled = random.sample(neighbors, min(fanout, len(neighbors)))
new_nodes.update(sampled)
2
3
4
5
逐个解释:
layers[0]
是当前最上层的节点集合- 对每个
node
:- 拿出它的邻居
adj[node]
- 随机从中采样
fanout
个邻居 - 加入集合
new_nodes
例子:
- 拿出它的邻居
当前 layers[0] = [2, 3]
fanout = 2
node 2 的邻居:adj[2] = [2, 0, 4]
采样了 [0, 4]
node 3 的邻居:adj[3] = [3, 1, 5]
采样了 [3, 5]
new_nodes = {0, 4, 3, 5}
2
3
4
5
6
7
8
9
10
# 7.1.5 确保当前层包含上一层节点
new_nodes.update(layers[0])
这是关键操作!
- GraphSAGE 要求每一层必须包含上层节点本身(用于自连接)
- 否则后面做
map[nid] for nid in upper_layer
会KeyError
所以我们手动把上一层节点(如 [2, 3]
)放回到当前采样集合中。
# 7.1.6 插入采样结果作为新一层
layers.insert(0, list(new_nodes))
- 将新一层采样的节点插入到最前面(靠底层)
- 继续下一轮采样
最终 layers
结构就是:
[l0_nodes, l1_nodes, target_nodes]
每一层节点都是编号组成的列表。
# 7.1.7 返回采样结果
return layers
返回的结构是一个 list of node_id lists:
[
[0, 1, 4, 5], # l0_nodes(最底层)
[2, 3, 5], # l1_nodes(上一层)
[2, 3, 5] # target_nodes(最顶层)
]
2
3
4
5
# 7.1.8 总结流程图
假设我们要训练节点 [2, 3]
,每层采样 2 个邻居:
multi_hop_sampling([2,3], [2,2], adj)
↓
target_nodes = [2, 3]
↑ ↑
采样 2 个邻居:1-hop → l1_nodes
↑ ↑
再采样 2 个邻居:2-hop → l0_nodes
2
3
4
5
6
7
得到三层:
layers = [l0_nodes, l1_nodes, target_nodes]
这些节点构成一个局部子图,GraphSAGE 就在这个子图上做多层聚合。
← 图注意力网络(GAT) 计算机期刊→