Xz's blog Xz's blog
首页
时间序列
多模态
合成生物学
其他方向
生活
工具相关
PyTorch
导航站

Xu Zhen

首页
时间序列
多模态
合成生物学
其他方向
生活
工具相关
PyTorch
导航站
  • 时序预测

  • 图学习

    • 图学习基础(GCN)
    • 图注意力网络(GAT)
      • 1 为什么需要图注意力网络(Graph Attention Network, GAT)?
      • 2 GAT 的核心思想
        • 2.1 引入注意力机制(Attention)
        • 2.2 注意力权重 $\alpha_{ij}$ 如何计算?
        • 2.3 使用多头注意力(multi-head attention)
      • 3 GAT 代码实现(单头):
        • 3.1 测试
      • 4 GAT 代码(单头)详细解析:
        • 4.1 背景复习:我们要计算什么?
        • 4.2 输入说明
        • 4.3 目标
        • 4.4 分解解释
        • 第一步:
        • 第二步:
        • 第三步:
        • 4.5 最终作用
      • 5 多头GAT 代码实现(带拼接/平均模式):
        • 5.1 多头 GAT(Multi-head GAT)的核心思想:
        • 5.2 代码实现:多头 GAT 层(带拼接/平均模式)
        • 5.3 多头封装 GAT 层
        • 5.4 使用示例
    • GraphSAGE(Graph Sample and Aggregate)
  • 其他

  • 其他方向
  • 图学习
xuzhen
2025-07-14
目录

图注意力网络(GAT)

# 1 为什么需要图注意力网络(Graph Attention Network, GAT)?

GCN 的邻居聚合是平均的(或者是归一化加权的),例如:

hi(l+1)=∑j∈N(i)1didjWhj(l)h_i^{(l+1)} = \sum_{j \in \mathcal{N}(i)} \frac{1}{\sqrt{d_i d_j}} W h_j^{(l)} hi(l+1)​=j∈N(i)∑​di​dj​​1​Whj(l)​

✅ 优点:简单高效
❌ 缺点:所有邻居权重固定,不能区分“重要邻居”和“无关邻居”

# 2 GAT 的核心思想

# 2.1 引入注意力机制(Attention)

对于每对邻接节点 iii 和 jjj,GAT 学习一个 注意力系数 αij\alpha_{ij}αij​,表示节点 jjj 对节点 iii 的影响程度。

整体更新公式如下:

hi(l+1)=σ(∑j∈N(i)αijWhj(l))h_i^{(l+1)} = \sigma\left( \sum_{j \in \mathcal{N}(i)} \alpha_{ij} W h_j^{(l)} \right) hi(l+1)​=σ​j∈N(i)∑​αij​Whj(l)​​

其中:

  • WWW:可学习的线性变换矩阵
  • αij\alpha_{ij}αij​:通过注意力机制计算
  • σ\sigmaσ:非线性激活函数(如 ReLU)

# 2.2 注意力权重 αij\alpha_{ij}αij​ 如何计算?

  1. 首先对输入特征进行线性变换:

zi=Whiz_i = W h_i zi​=Whi​

  1. 使用一个小的前馈神经网络(通常是一层 MLP)计算注意力得分:

eij=LeakyReLU(a⊤[zi∣∣zj])e_{ij} = \text{LeakyReLU}(a^\top [z_i || z_j]) eij​=LeakyReLU(a⊤[zi​∣∣zj​])

其中:

  • [zi∣∣zj][z_i || z_j][zi​∣∣zj​] 表示拼接
  • aaa 是可学习的向量
  • LeakyReLU 是激活函数
  1. 对每个节点 iii 的邻居 jjj 进行 softmax 归一化:

αij=exp⁡(eij)∑k∈N(i)exp⁡(eik)\alpha_{ij} = \frac{ \exp(e_{ij}) }{ \sum_{k \in \mathcal{N}(i)} \exp(e_{ik}) } αij​=∑k∈N(i)​exp(eik​)exp(eij​)​

# 2.3 使用多头注意力(multi-head attention)

这增强了表达能力,并提高稳定性。最终聚合结果可能是:

  • 拼接式:

hi(l+1)=∥k=1K∑j∈N(i)αij(k)W(k)hj h_i^{(l+1)} = \mathbin\Vert_{k=1}^K \sum_{j \in \mathcal{N}(i)} \alpha_{ij}^{(k)} W^{(k)} h_j hi(l+1)​=∥k=1K​j∈N(i)∑​αij(k)​W(k)hj​

  • 平均式(用于最后一层):

hi(l+1)=1K∑k=1K∑j∈N(i)αij(k)W(k)hj h_i^{(l+1)} = \frac{1}{K} \sum_{k=1}^K \sum_{j \in \mathcal{N}(i)} \alpha_{ij}^{(k)} W^{(k)} h_j hi(l+1)​=K1​k=1∑K​j∈N(i)∑​αij(k)​W(k)hj​

# 3 GAT 代码实现(单头):

根据这个公式实现代码:

αij=softmaxj(LeakyReLU(aT[Whi∣∣Whj]))\alpha_{ij} = \text{softmax}_j\left( \text{LeakyReLU}(a^T [Wh_i || Wh_j]) \right) αij​=softmaxj​(LeakyReLU(aT[Whi​∣∣Whj​]))

hi(l+1)=∑j∈N(i)αij⋅Whjh_i^{(l+1)} = \sum_{j \in \mathcal{N}(i)} \alpha_{ij} \cdot Wh_j hi(l+1)​=j∈N(i)∑​αij​⋅Whj​

import torch
import torch.nn as nn
import torch.nn.functional as F
class GATLayer(nn.Module):
    def __init__(self, in_features, out_features, dropout=0.0, alpha=0.2):
        super().__init__()
        self.W = nn.Linear(in_features, out_features, bias=False)
        self.a = nn.Parameter(torch.empty(size=(2*out_features, 1)))
        nn.init.xavier_uniform_(self.a.data, gain=1.414)
        self.leakyrelu = nn.LeakyReLU(alpha)
        self.dropout = nn.Dropout(dropout)
	# X: 输入特征矩阵 [N, F], adj: 邻接矩阵 [N, N]
    def forward(self, X, adj):
        Wh = self.W(X)  # shape [N, F']
        N = Wh.size(0)

        # 构造所有 Wh_i || Wh_j(拼接),用广播技巧计算 [N, N, 2F']
        Wh_i = Wh.unsqueeze(1).expand(-1, N, -1)  # [N, N, F']
        Wh_j = Wh.unsqueeze(0).expand(N, -1, -1)  # [N, N, F']
        Wh_cat = torch.cat([Wh_i, Wh_j], dim=-1)  # [N, N, 2F']

        # 计算 e_{ij}
        e = self.leakyrelu(torch.matmul(Wh_cat, self.a).squeeze(-1))  # [N, N]

        # 对邻接矩阵中不为1的位置mask为 -inf,防止参与 softmax
        attention = e.masked_fill(adj == 0, float('-inf'))
        attention = F.softmax(attention, dim=1)  # softmax over neighbors j
        attention = self.dropout(attention)

        # 聚合邻居特征
        h_prime = torch.matmul(attention, Wh)  # [N, F']
        return F.elu(h_prime)  # 激活
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32

# 3.1 测试

import torch
import torch.nn as nn
import torch.nn.functional as F

# 节点数 N=4,特征维度 Fin=2
X = torch.tensor([
    [1.0, 0.0],  # 节点0
    [0.0, 1.0],  # 节点1
    [1.0, 1.0],  # 节点2
    [0.0, 0.0],  # 节点3
])

# 邻接矩阵(含自环),A_hat[i][j] = 1 表示 j 是 i 的邻居
A_hat = torch.tensor([
    [1, 1, 0, 1],  # 节点0与自己、1、3相连
    [1, 1, 1, 0],  # 节点1与自己、0、2相连
    [0, 1, 1, 1],  # 节点2与自己、1、3相连
    [1, 0, 1, 1],  # 节点3与自己、0、2相连
], dtype=torch.float32)

gat = GATLayer(in_features=2, out_features=4)  # 2 -> 4维
out = gat(X, A_hat)
print("输出节点表示:")
print(out)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25

# 4 GAT 代码(单头)详细解析:

逐行详细讲解下面的这段代码,这是 GAT 中 计算注意力权重 eije_{ij}eij​ 的关键部分。

def forward(self, X, adj):
	Wh = self.W(X)  # shape [N, F']
	N = Wh.size(0)

	# 构造所有 Wh_i || Wh_j(拼接),用广播技巧计算 [N, N, 2F']
	Wh_i = Wh.unsqueeze(1).expand(-1, N, -1)  # [N, N, F']
	Wh_j = Wh.unsqueeze(0).expand(N, -1, -1)  # [N, N, F']
	Wh_cat = torch.cat([Wh_i, Wh_j], dim=-1)  # [N, N, 2F']

	# 计算 e_{ij}
	e = self.leakyrelu(torch.matmul(Wh_cat, self.a).squeeze(-1))  # [N, N]
1
2
3
4
5
6
7
8
9
10
11

# 4.1 背景复习:我们要计算什么?

在 GAT 中,对于每一对连接节点 iii 和 jjj,我们要计算:

eij=LeakyReLU(a⊤[Whi∥Whj])e_{ij} = \text{LeakyReLU} \left( a^\top [Wh_i \, \Vert \, Wh_j] \right) eij​=LeakyReLU(a⊤[Whi​∥Whj​])

其中:

  • WhiWh_iWhi​:是节点 iii 经过线性变换后的特征,形状是 [F′][F'][F′]
  • [Whi∥Whj][Wh_i \Vert Wh_j][Whi​∥Whj​]:是将 iii 和 jjj 的特征拼接成 [2F′][2F'][2F′]
  • a⊤a^\topa⊤:是一个可学习的向量,形状 [2F′][2F'][2F′],用于投影到一个标量(注意力分数)

我们希望一次性计算所有 (i,j)(i, j)(i,j) 对的 eije_{ij}eij​,得到一个 [N,N][N, N][N,N] 的矩阵。

# 4.2 输入说明

Wh = self.W(X)  # [N, F'],节点特征经过线性变换后的结果
N = Wh.size(0)  # 节点个数
1
2

# 4.3 目标

构造一个形状为 [N, N, 2F'] 的张量 Wh_cat,它的第 (i,j)(i, j)(i,j) 个元素是:

[Whi∥Whj](拼接后的2F维向量)[Wh_i \Vert Wh_j] \quad \text{(拼接后的2F维向量)} [Whi​∥Whj​](拼接后的2F维向量)

# 4.4 分解解释

# 第一步:
Wh_i = Wh.unsqueeze(1).expand(-1, N, -1)  # [N, 1, F'] -> [N, N, F']
1

解释:

  • Wh.unsqueeze(1) 把 Wh 的维度从 [N, F'] 扩展成 [N, 1, F']
  • expand(-1, N, -1) 把它复制成 [N, N, F']
  • 每一行变成了:第 iii 行是 WhiWh_iWhi​ 的拷贝,与所有 jjj 配对 例如:如果原来是
Wh = [[1, 2], [3, 4], [5, 6]]  # N=3
1

那么 Wh_i 会变成:

[
  [[1, 2], [1, 2], [1, 2]],  # 第0行与(0,0), (0,1), (0,2)配对
  [[3, 4], [3, 4], [3, 4]],  # 第1行与(1,0), (1,1), (1,2)配对
  [[5, 6], [5, 6], [5, 6]]
]
1
2
3
4
5
# 第二步:
Wh_j = Wh.unsqueeze(0).expand(N, -1, -1)  # [1, N, F'] -> [N, N, F']
1

解释:

  • Wh.unsqueeze(0) 把它变成 [1, N, F']
  • expand(N, -1, -1) 复制成 [N, N, F']
  • 每一列变成了:第 jjj 列是 WhjWh_jWhj​ 的拷贝,与所有 iii 配对 继续上例:
[
  [[1, 2], [3, 4], [5, 6]],  # 第0列表示 (0,0), (0,1), (0,2)
  [[1, 2], [3, 4], [5, 6]],  # 第1列表示 (1,0), (1,1), (1,2)
  [[1, 2], [3, 4], [5, 6]]
]
1
2
3
4
5
# 第三步:
Wh_cat = torch.cat([Wh_i, Wh_j], dim=-1)  # [N, N, 2F']
1

解释:

  • 把 [Wh_i, Wh_j] 在最后一维拼接起来
  • 得到每个 (i,j)(i,j)(i,j) 对的拼接向量 [Whi∥Whj][Wh_i \Vert Wh_j][Whi​∥Whj​] 所以最终我们得到一个形状为 [N, N, 2F'] 的张量 Wh_cat,其中每一对 (i,j)(i,j)(i,j) 都包含了节点 iii 和 jjj 的特征拼接。
[
  [[1, 2, 1, 2], [1, 2, 3, 4], [1, 2, 5, 6]],  # 第0行与(0,0), (0,1), (0,2)配对
  [[3, 4, 1, 2], [3, 4, 3, 4], [3, 4, 5, 6]],  # 第1行与(1,0), (1,1), (1,2)配对
  [[5, 6, 1, 2], [5, 6, 3, 4], [5, 6, 5, 6]]
]
1
2
3
4
5

# 4.5 最终作用

这个 Wh_cat 会被拿去跟参数 a 做点积:

e = self.leakyrelu(torch.matmul(Wh_cat, self.a).squeeze(-1))  # [N, N]
1

就得到了每一对 (i,j)(i, j)(i,j) 的注意力得分 eije_{ij}eij​。

# 5 多头GAT 代码实现(带拼接/平均模式):

# 5.1 多头 GAT(Multi-head GAT)的核心思想:

  • 多个 GAT heads 并行计算,捕捉不同的“邻接注意力模式”
  • 有两种合并方式:
    1. 拼接(concatenation):用于中间层

hiout=∥k=1Khi(k) h_i^{\text{out}} = \big\Vert_{k=1}^{K} h_i^{(k)} hiout​=​k=1K​hi(k)​

  1. 平均(mean):用于最后输出层

hiout=1K∑k=1Khi(k) h_i^{\text{out}} = \frac{1}{K} \sum_{k=1}^{K} h_i^{(k)} hiout​=K1​k=1∑K​hi(k)​

# 5.2 代码实现:多头 GAT 层(带拼接/平均模式)

import torch
import torch.nn as nn
import torch.nn.functional as F

class SingleHeadGATLayer(nn.Module):
    def __init__(self, in_features, out_features, dropout=0.0, alpha=0.2):
        super().__init__()
        self.W = nn.Linear(in_features, out_features, bias=False)
        self.a = nn.Parameter(torch.empty(size=(2*out_features, 1)))
        nn.init.xavier_uniform_(self.a.data, gain=1.414)
        self.leakyrelu = nn.LeakyReLU(alpha)
        self.dropout = nn.Dropout(dropout)

    def forward(self, X, adj):
        Wh = self.W(X)  # [N, F']
        N = Wh.size(0)

        Wh_i = Wh.unsqueeze(1).expand(-1, N, -1)
        Wh_j = Wh.unsqueeze(0).expand(N, -1, -1)
        Wh_cat = torch.cat([Wh_i, Wh_j], dim=-1)  # [N, N, 2F']

        e = self.leakyrelu(torch.matmul(Wh_cat, self.a).squeeze(-1))  # [N, N]
        attention = e.masked_fill(adj == 0, float('-inf'))
        attention = F.softmax(attention, dim=1)
        attention = self.dropout(attention)

        h_prime = torch.matmul(attention, Wh)  # [N, F']
        return h_prime  # 注意:不加激活,留给外层决定
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28

# 5.3 多头封装 GAT 层

class MultiHeadGATLayer(nn.Module):
    def __init__(self, in_features, out_features, num_heads=4, dropout=0.0, alpha=0.2, concat=True):
        super().__init__()
        self.heads = nn.ModuleList([
            SingleHeadGATLayer(in_features, out_features, dropout, alpha)
            for _ in range(num_heads)
        ])
        self.concat = concat  # True: 拼接;False: 平均

    def forward(self, X, adj):
        out = [head(X, adj) for head in self.heads]  # list of [N, F']
        if self.concat:
            return F.elu(torch.cat(out, dim=1))  # [N, num_heads * F']
        else:
            return F.elu(torch.mean(torch.stack(out), dim=0))  # [N, F']
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15

# 5.4 使用示例

# 输入特征:4个节点,每个2维
X = torch.tensor([
    [1.0, 0.0],
    [0.0, 1.0],
    [1.0, 1.0],
    [0.0, 0.0],
])

# 邻接矩阵(含自环)
A_hat = torch.tensor([
    [1, 1, 0, 1],
    [1, 1, 1, 0],
    [0, 1, 1, 1],
    [1, 0, 1, 1],
], dtype=torch.float32)

# 多头 GAT:4 个头,每个输出维度为 3,拼接成 [N, 12]
gat_multi = MultiHeadGATLayer(in_features=2, out_features=3, num_heads=4, concat=True)
output = gat_multi(X, A_hat)

print("多头 GAT 输出形状:", output.shape)
print(output)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
#Graph
上次更新: 2025/07/15, 15:17:12

← 图学习基础(GCN) GraphSAGE(Graph Sample and Aggregate)→

最近更新
01
Slice切片
07-26
02
引用与借用
07-26
03
所有权
07-26
更多文章>
Theme by Vdoing | Copyright © 2025-2025 Xu Zhen | 鲁ICP备2025169719号
  • 跟随系统
  • 浅色模式
  • 深色模式
  • 阅读模式