Xz's blog Xz's blog
首页
时间序列
多模态
合成生物学
其他方向
生活
工具相关
PyTorch
导航站

Xu Zhen

首页
时间序列
多模态
合成生物学
其他方向
生活
工具相关
PyTorch
导航站
  • 时序预测

  • 图学习

    • 图学习基础(GCN)
    • 图注意力网络(GAT)
    • GraphSAGE(Graph Sample and Aggregate)
      • 1 什么是 GraphSAGE?
      • 2 为什么设计 GraphSAGE?
      • 3 GraphSAGE 的核心机制
        • 3.1 分解每个部分的含义
        • $h_i^{(l)}$:当前层节点 $i$ 的特征
        • $\mathcal{N}(i)$:节点 $i$ 的邻居集合
        • $\{ h_j^{(l)} | j \in \mathcal{N}(i) \}$:邻居节点的特征
        • $\{ hi^{(l)} \} \cup \{ hj^{(l)} \}$:把自己也加进去
        • $\text{AGG}^{(l)}(\cdots)$:聚合函数
        • $W^{(l)}$:权重矩阵(可学习)
        • $\sigma$:激活函数
        • 全部过程示意
        • 3.2 举个具体例子(Mean Aggregator)
        • Step1: 聚合
        • Step2: 线性变换
        • Step3: 激活
      • 4 采样
        • 4.1 为什么要采样?(动机)
        • 4.2 采样的基本思路
        • 4.3 示例:两层 GraphSAGE 采样示意图
        • 4.4 分层采样流程
        • 4.5 常见采样策略 ——Uniform Random Sampling
      • 5 聚合器 AGG 的常见选择:
      • 6 GraphSAGE 模型代码实现
        • 6.1 任务目标
        • 6.2 第一步:构造图结构和特征标签
        • 6.3 第二步:实现多层邻居采样函数
        • 6.4 第三步:GraphSAGE 层(mean aggregator)
        • 6.5 第四步:两层 GraphSAGE 模型
        • 6.5.1 邻居索引构建函数(batch 用)
        • 6.6 第五步:训练模型
      • 7 GraphSAGE 模型代码详解
        • 7.1 multihopsampling(...)函数
        • 7.1.1 函数定义
        • 7.1.2 初始化节点层列表
        • 7.1.3 反向采样多层邻居(从上往下)
        • 7.1.4 对当前层每个节点采样邻居
        • 7.1.5 确保当前层包含上一层节点
        • 7.1.6 插入采样结果作为新一层
        • 7.1.7 返回采样结果
        • 7.1.8 总结流程图
  • 其他

  • 其他方向
  • 图学习
xuzhen
2025-07-14
目录

GraphSAGE(Graph Sample and Aggregate)

# 1 什么是 GraphSAGE?

GraphSAGE 是由 Hamilton 等人提出的一种 感知邻居特征的方法。它的名字全称是: SAmple and aggreGatE:采样 + 聚合与 GCN 最大的不同在于:

模型 聚合邻居
GCN 聚合所有邻居(必须用整个图)
GraphSAGE 对邻居做采样,支持小批量训练(mini-batch)

# 2 为什么设计 GraphSAGE?

现实图(如社交网络、商品推荐)可能有百万节点、亿级边,GCN 无法一次加载整个图。

GraphSAGE 的目标:

  • ✅ 支持 大规模图的训练
  • ✅ 支持 Inductive Learning:学好之后可以预测“没见过的新节点”
  • ✅ 邻居采样 + 聚合

# 3 GraphSAGE 的核心机制

每个节点的表示通过如下步骤更新:

第 lll 层表示更新公式:

hi(l+1)=σ(W(l)⋅AGG(l)({hi(l)}∪{hj(l),j∈N(i)}))h_i^{(l+1)} = \sigma \left( W^{(l)} \cdot \text{AGG}^{(l)}\left( \{ h_i^{(l)} \} \cup \{ h_j^{(l)}, \, j \in \mathcal{N}(i) \} \right) \right) hi(l+1)​=σ(W(l)⋅AGG(l)({hi(l)​}∪{hj(l)​,j∈N(i)}))

其中:

  • hi(l)h_i^{(l)}hi(l)​:第 lll 层时的节点特征
  • N(i)\mathcal{N}(i)N(i):节点 iii 的邻居集合
  • AGG\text{AGG}AGG:聚合函数
  • σ\sigmaσ:激活函数(如 ReLU)
  • W(l)W^{(l)}W(l):可学习的权重矩阵

也就是:“第 l+1l+1l+1 层中,节点 iii 的新表示 hi(l+1)h_i^{(l+1)}hi(l+1)​,是它当前层特征 hi(l)h_i^{(l)}hi(l)​ 和邻居特征 hj(l)h_j^{(l)}hj(l)​ 聚合之后,再经过线性变换和激活函数得到的。”

# 3.1 分解每个部分的含义

# hi(l)h_i^{(l)}hi(l)​:当前层节点 iii 的特征
  • 是维度为 F(l)F^{(l)}F(l) 的向量,比如说 [0.2,1.3,−0.7][0.2, 1.3, -0.7][0.2,1.3,−0.7]
  • 初始层 (l=0l=0l=0) 时,它是节点输入特征(如 one-hot、词向量、属性)
# N(i)\mathcal{N}(i)N(i):节点 iii 的邻居集合
  • 就是所有和 iii 相连的节点的编号,例如 N(2)={0,3,4}\mathcal{N}(2) = \{0, 3, 4\}N(2)={0,3,4}
# {hj(l)∣j∈N(i)}\{ h_j^{(l)} | j \in \mathcal{N}(i) \}{hj(l)​∣j∈N(i)}:邻居节点的特征
  • 拿到所有邻居在第 lll 层的特征向量
  • 比如是三个 3 维向量:[a1,a2,a3][a_1, a_2, a_3][a1​,a2​,a3​]
# {hi(l)}∪{hj(l)}\{ h_i^{(l)} \} \cup \{ h_j^{(l)} \}{hi(l)​}∪{hj(l)​}:把自己也加进去
  • 与 GCN 不同,GraphSAGE 明确地把自己的特征也参与聚合
# AGG(l)(⋯)\text{AGG}^{(l)}(\cdots)AGG(l)(⋯):聚合函数
  • 对这组向量执行某种函数,比如:
    • 平均(mean)
    • 最大池化(max-pool)
    • LSTM
  • 输出是一个单独的向量,和输入特征维度一致(或定长)

例如:

AGG(l)({[1,0],[0,1],[1,1]})=[0.667,0.667](mean)\text{AGG}^{(l)}\left( \{ [1, 0], [0, 1], [1, 1] \} \right) = [0.667, 0.667] \quad \text{(mean)} AGG(l)({[1,0],[0,1],[1,1]})=[0.667,0.667](mean)

# W(l)W^{(l)}W(l):权重矩阵(可学习)
  • 是一个线性变换,大小是 [F(l)→F(l+1)][F^{(l)} \to F^{(l+1)}][F(l)→F(l+1)],例如 2×42 \times 42×4
  • 类似普通神经网络的线性层 Linear(in, out)
  • 作用是:将聚合后的特征投影到新空间
# σ\sigmaσ:激活函数
  • 通常是 ReLU、ELU 等
  • 用于增加非线性能力
# 全部过程示意

对每个节点 iii:

  1. 拿到它的邻居集合 N(i)\mathcal{N}(i)N(i)
  2. 取出它和邻居的当前特征
  3. 对它们进行聚合:得到一个“邻里信息”向量
  4. 经过线性映射(矩阵乘法)
  5. 经过激活函数,得到新特征 hi(l+1)h_i^{(l+1)}hi(l+1)​

这就是 GraphSAGE 的一层。

# 3.2 举个具体例子(Mean Aggregator)

假设第 0 层时:

  • 节点 iii 特征:hi(0)=[1.0,2.0]h_i^{(0)} = [1.0, 2.0]hi(0)​=[1.0,2.0]
  • 邻居是 j=1,2j = 1, 2j=1,2
    • h1(0)=[0.0,1.0]h_1^{(0)} = [0.0, 1.0]h1(0)​=[0.0,1.0]
    • h2(0)=[1.0,1.0]h_2^{(0)} = [1.0, 1.0]h2(0)​=[1.0,1.0]
# Step1: 聚合

AGG=mean({hi,h1,h2})=mean({[1,2],[0,1],[1,1]})=[0.667,1.33]\text{AGG} = \text{mean}(\{ h_i, h_1, h_2 \}) = \text{mean}(\{ [1,2], [0,1], [1,1] \}) = [0.667, 1.33] AGG=mean({hi​,h1​,h2​})=mean({[1,2],[0,1],[1,1]})=[0.667,1.33]

# Step2: 线性变换

设 W(0)=[1001]W^{(0)} = \begin{bmatrix} 1 & 0 \\ 0 & 1 \end{bmatrix}W(0)=[10​01​],则

W⋅AGG=[0.667,1.33]W \cdot \text{AGG} = [0.667, 1.33] W⋅AGG=[0.667,1.33]

# Step3: 激活

hi(1)=ReLU([0.667,1.33])=[0.667,1.33]h_i^{(1)} = \text{ReLU}([0.667, 1.33]) = [0.667, 1.33] hi(1)​=ReLU([0.667,1.33])=[0.667,1.33]

# 4 采样

GraphSAGE 中的 采样(Sampling)机制 是它区别于传统 GCN 的关键创新之一。下面我将从以下方面 详细解释 GraphSAGE 的采样机制:

# 4.1 为什么要采样?(动机)

在标准 GCN 中,每层都需要访问节点的所有邻居。但图中很多节点的度数很高(如社交网络中的“明星用户”),这会导致:

  • 计算量快速爆炸(指数级增长)
  • 显存不足(因为需要存所有邻居的表示)
  • 无法扩展到大规模图(如百万级节点)

GraphSAGE 的解决方案是:

每层仅从邻居中采样固定数量的节点,用于表示传播和聚合。

# 4.2 采样的基本思路

每一层,GraphSAGE 为每个节点:

  • 从其邻居集合 N(v)\mathcal{N}(v)N(v) 中随机采样固定数量的邻居(如 10 个)
  • 这些采样的邻居用于执行聚合操作
  • 采样是分层执行的,采样深度 = 网络层数(例如 2 层 GNN,需要采样 2 层邻居)

# 4.3 示例:两层 GraphSAGE 采样示意图

假设:

  • 我们有一个两层 GraphSAGE 模型
  • 每层采样 2 个邻居
  • 节点 vvv 是目标节点

采样图如下:

        L2邻居        (最远)
       /   \
   u1        u2
   |          |
v1(v的邻居)  v2(v的邻居)
   \         /
      v(中心节点)(L0)
1
2
3
4
5
6
7

我们:

  • 第 1 层(最邻近的):从 vvv 的邻居中采样 2 个邻居(比如 v1,v2v1, v2v1,v2)
  • 第 2 层:对 v1v1v1 和 v2v2v2 也分别采样 2 个邻居(比如 u1,u2u1, u2u1,u2)

结果:构建出一个以 vvv 为根的 局部计算子图,即 v → v1/v2 → u1/u2

# 4.4 分层采样流程

设定采样大小为:

  • 第一层:S1S_1S1​
  • 第二层:S2S_2S2​
  • 第 K 层:SKS_KSK​

对于每个目标节点 vvv,执行如下采样:

# 目标节点
nodes_L0 = {v}

# 第 1 层:采样邻居,
nodes_L1 = SampleNeighbors(nodes_L0, S_1)

# 第 2 层:采样这些邻居的邻居
nodes_L2 = SampleNeighbors(nodes_L1, S_2)

# ...
1
2
3
4
5
6
7
8
9
10

SampleNeighbors(nodes_L0, S_1):对集合 nodes_L0 中的每个节点,从它的邻居中随机采样 S_1 个邻居。

最终,我们就构建出一个嵌套的“邻居树”,每个层级的邻居都采样固定数量节点。这就限制了计算复杂度为:

O(SK×SK−1×⋯×S1)O(S_K \times S_{K-1} \times \cdots \times S_1) O(SK​×SK−1​×⋯×S1​)

# 4.5 常见采样策略 ——Uniform Random Sampling

最常用的:从邻居中均匀随机选择 SSS 个节点。

  • 简单高效
  • 易于实现
  • 对模型鲁棒性影响较小
# 定义一个函数,表示从 node 的邻居中采样最多 S 个邻居。
def sample_neighbors(adj_list, node, S):
	# 从邻接表中取出该节点的邻居列表。
    neighbors = adj_list[node]
    # 如果邻居数量小于等于 S,说明邻居不够采,就直接全部返回。
    if len(neighbors) <= S:
        return neighbors
    # 如果邻居数量大于 S,就从中随机采样 S 个邻居,`random.sample(list, k)`:从 `list` 中随机不放回采样 `k` 个元素
    return random.sample(neighbors, S)
1
2
3
4
5
6
7
8
9

# 5 聚合器 AGG 的常见选择:

聚合器 说明
Mean Aggregator 聚合邻居平均特征
Pooling Aggregator 每个邻居过 MLP 后取 max
LSTM Aggregator 用 LSTM 处理邻居序列(不推荐,顺序无意义)
GCN-style 邻居 + 自己一起平均后线性变换

# 6 GraphSAGE 模型代码实现

我们现在来 用纯 PyTorch + 手动邻居采样 写一个可运行的 GraphSAGE 模型,用于小图的节点分类任务。

# 6.1 任务目标

我们将实现一个完整的流程:

  1. 手动构造一个图(邻接字典)
  2. 实现采样函数(多层邻居采样)
  3. 实现 GraphSAGE 层(mean aggregator)
  4. 构建两层 GraphSAGE 网络
  5. 训练模型进行节点分类

# 6.2 第一步:构造图结构和特征标签

import torch
import torch.nn as nn
import torch.nn.functional as F
import random

# 简单图:6个节点,特征维度=2,标签=2类
X = torch.tensor([
    [1.0, 0.0],  # 0
    [0.0, 1.0],  # 1
    [1.0, 1.0],  # 2
    [0.0, 0.0],  # 3
    [0.5, 0.5],  # 4
    [1.0, 0.5],  # 5
])

Y = torch.tensor([0, 1, 0, 1, 0, 1])  # 节点标签(0/1)

# 邻接字典(无向图 + 自环)
adj = {
    0: [0, 1, 2],
    1: [1, 0, 3],
    2: [2, 0, 4],
    3: [3, 1, 5],
    4: [4, 2, 5],
    5: [5, 3, 4],
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26

# 6.3 第二步:实现多层邻居采样函数

def multi_hop_sampling(seed_nodes, num_neighbors, adj):
    layers = [seed_nodes]
    for fanout in num_neighbors[::-1]:  # 从后往前采样
        new_nodes = set()
        for node in layers[0]:
            neighbors = adj[node]
            if len(neighbors) > fanout:
                sampled = random.sample(neighbors, fanout)
            else:
                sampled = neighbors
            new_nodes.update(sampled)
        # 确保当前层采样出的节点包含上一层节点
        new_nodes.update(layers[0])
        layers.insert(0, list(new_nodes))
    return layers  # [l0_nodes, l1_nodes, target_nodes]

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16

示例:

# 采样目标节点 = [2]
layers = multi_hop_sampling([2], [2, 2], adj)
print("采样层级节点:", layers)  # [2跳邻居, 1跳邻居, 自身]
1
2
3

# 6.4 第三步:GraphSAGE 层(mean aggregator)

class GraphSAGELayer(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.linear = nn.Linear(in_dim * 2, out_dim)

    def forward(self, x_self, x_neighbors):
        # 聚合邻居
        agg = x_neighbors.mean(dim=1)  # [B, F]
        h = torch.cat([x_self, agg], dim=1)  # [B, 2F]
        return F.relu(self.linear(h))
1
2
3
4
5
6
7
8
9
10

# 6.5 第四步:两层 GraphSAGE 模型

class GraphSAGE(nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        self.sage1 = GraphSAGELayer(in_dim, hidden_dim)
        self.sage2 = GraphSAGELayer(hidden_dim, out_dim)

    def forward(self, x_all, layers):
        l0_nodes, l1_nodes, target_nodes = layers

        # 取节点特征
        x_l0 = x_all[l0_nodes]         # 第0层节点特征
        x_l1 = x_all[l1_nodes]         # 第1层节点特征

        # 构建索引映射字典(方便快速定位)
        l0_map = {nid: i for i, nid in enumerate(l0_nodes)}
        l1_map = {nid: i for i, nid in enumerate(l1_nodes)}

        # 第1层:l0聚合为l1
        l1_adj = build_batch_adj(l1_nodes, l0_nodes, adj)            # [B1, K]
        x_l1_neighbors = x_l0[l1_adj]                                # [B1, K, in_dim]
        x_l1_new = self.sage1(x_l1, x_l1_neighbors)                  # [B1, hidden_dim]

        # 第2层:l1聚合为target
        target_adj = build_batch_adj(target_nodes, l1_nodes, adj)    # [B2, K]
        x_target_neighbors = x_l1_new[target_adj]                    # [B2, K, hidden_dim]
        x_target = x_l1_new[[l1_map[nid] for nid in target_nodes]]   # [B2, hidden_dim]
        x_out = self.sage2(x_target, x_target_neighbors)             # [B2, out_dim]

        return x_out
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29

# 6.5.1 邻居索引构建函数(batch 用)

def build_batch_adj(src_nodes, neighbor_nodes, adj_dict):
    # 返回 [len(src), num_neighbors] 的索引,用于 x[neighbor_idx]
    neighbor_map = {nid: i for i, nid in enumerate(neighbor_nodes)}
    indices = []
    for nid in src_nodes:
        neighbors = adj_dict[nid]
        mapped = [neighbor_map[j] for j in neighbors if j in neighbor_map]
        while len(mapped) < 2:  # 填充
            mapped.append(mapped[0])
        indices.append(mapped[:2])
    return torch.tensor(indices)
1
2
3
4
5
6
7
8
9
10
11

# 6.6 第五步:训练模型

model = GraphSAGE(2, 4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

for epoch in range(50):
    model.train()

    target_nodes = [2, 3, 5]  # 训练的目标节点
    layers = multi_hop_sampling(target_nodes, [2, 2], adj)

    out = model(X, layers)  # [len(target_nodes), 2]
    loss = loss_fn(out, Y[target_nodes])

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch % 10 == 0:
        pred = out.argmax(dim=1)
        acc = (pred == Y[target_nodes]).float().mean().item()
        print(f"Epoch {epoch}, Loss: {loss.item():.4f}, Acc: {acc:.4f}")
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21

# 7 GraphSAGE 模型代码详解

# 7.1 multi_hop_sampling(...)函数

我们来详细逐行讲解这段 GraphSAGE 中的多层邻居采样函数:

# 7.1.1 函数定义

def multi_hop_sampling(seed_nodes, num_neighbors, adj):
1
  • seed_nodes:最顶层的“目标节点”,也就是我们要预测/训练的节点(例如 [2, 3, 5])
  • num_neighbors:每一层要采样的邻居数量,比如 [2, 2] 表示:
    • 第1层采样2个邻居
    • 第2层采样2个邻居
  • adj:图的邻接表(字典格式),例如:
    adj = {
        0: [0, 1, 2],
        1: [1, 0, 3],
        ...
    }
    
    1
    2
    3
    4
    5

# 7.1.2 初始化节点层列表

layers = [seed_nodes]
1
  • layers:存储从底往上采样的节点列表。
  • 最开始只有“顶层”节点(要分类的目标节点)。

例如一开始:

layers = [[2, 3, 5]]
1

# 7.1.3 反向采样多层邻居(从上往下)

    for fanout in num_neighbors[::-1]:
1
  • [::-1] 是反转,我们从上往下采样。
  • 比如 num_neighbors = [2, 2],它变成 [2, 2](本例对称其实一样,但结构上是“从上层往下”)。

# 7.1.4 对当前层每个节点采样邻居

        new_nodes = set()
        for node in layers[0]:
            neighbors = adj[node]
            sampled = random.sample(neighbors, min(fanout, len(neighbors)))
            new_nodes.update(sampled)
1
2
3
4
5

逐个解释:

  • layers[0] 是当前最上层的节点集合
  • 对每个 node:
    • 拿出它的邻居 adj[node]
    • 随机从中采样 fanout 个邻居
    • 加入集合 new_nodes 例子:
当前 layers[0] = [2, 3]
fanout = 2

node 2 的邻居:adj[2] = [2, 0, 4]
采样了 [0, 4]

node 3 的邻居:adj[3] = [3, 1, 5]
采样了 [3, 5]

new_nodes = {0, 4, 3, 5}
1
2
3
4
5
6
7
8
9
10

# 7.1.5 确保当前层包含上一层节点

        new_nodes.update(layers[0])
1

这是关键操作!

  • GraphSAGE 要求每一层必须包含上层节点本身(用于自连接)
  • 否则后面做 map[nid] for nid in upper_layer 会 KeyError

所以我们手动把上一层节点(如 [2, 3])放回到当前采样集合中。

# 7.1.6 插入采样结果作为新一层

        layers.insert(0, list(new_nodes))
1
  • 将新一层采样的节点插入到最前面(靠底层)
  • 继续下一轮采样

最终 layers 结构就是:

[l0_nodes, l1_nodes, target_nodes]
1

每一层节点都是编号组成的列表。

# 7.1.7 返回采样结果

    return layers
1

返回的结构是一个 list of node_id lists:

[
    [0, 1, 4, 5],  # l0_nodes(最底层)
    [2, 3, 5],     # l1_nodes(上一层)
    [2, 3, 5]      # target_nodes(最顶层)
]
1
2
3
4
5

# 7.1.8 总结流程图

假设我们要训练节点 [2, 3],每层采样 2 个邻居:

multi_hop_sampling([2,3], [2,2], adj)
              ↓
target_nodes = [2, 3]
    ↑         ↑
采样 2 个邻居:1-hop → l1_nodes
    ↑         ↑
再采样 2 个邻居:2-hop → l0_nodes
1
2
3
4
5
6
7

得到三层:

layers = [l0_nodes, l1_nodes, target_nodes]
1

这些节点构成一个局部子图,GraphSAGE 就在这个子图上做多层聚合。

#Graph
上次更新: 2025/07/14, 21:11:30

← 图注意力网络(GAT) 计算机期刊→

最近更新
01
Slice切片
07-26
02
引用与借用
07-26
03
所有权
07-26
更多文章>
Theme by Vdoing | Copyright © 2025-2025 Xu Zhen | 鲁ICP备2025169719号
  • 跟随系统
  • 浅色模式
  • 深色模式
  • 阅读模式