图注意力网络(GAT)
# 1 为什么需要图注意力网络(Graph Attention Network, GAT)?
GCN 的邻居聚合是平均的(或者是归一化加权的),例如:
✅ 优点:简单高效
❌ 缺点:所有邻居权重固定,不能区分“重要邻居”和“无关邻居”
# 2 GAT 的核心思想
# 2.1 引入注意力机制(Attention)
对于每对邻接节点 和 ,GAT 学习一个 注意力系数 ,表示节点 对节点 的影响程度。
整体更新公式如下:
其中:
- :可学习的线性变换矩阵
- :通过注意力机制计算
- :非线性激活函数(如 ReLU)
# 2.2 注意力权重 如何计算?
- 首先对输入特征进行线性变换:
- 使用一个小的前馈神经网络(通常是一层 MLP)计算注意力得分:
其中:
- 表示拼接
- 是可学习的向量
- LeakyReLU 是激活函数
- 对每个节点 的邻居 进行 softmax 归一化:
# 2.3 使用多头注意力(multi-head attention)
这增强了表达能力,并提高稳定性。最终聚合结果可能是:
- 拼接式:
- 平均式(用于最后一层):
# 3 GAT 代码实现(单头):
根据这个公式实现代码:
import torch
import torch.nn as nn
import torch.nn.functional as F
class GATLayer(nn.Module):
def __init__(self, in_features, out_features, dropout=0.0, alpha=0.2):
super().__init__()
self.W = nn.Linear(in_features, out_features, bias=False)
self.a = nn.Parameter(torch.empty(size=(2*out_features, 1)))
nn.init.xavier_uniform_(self.a.data, gain=1.414)
self.leakyrelu = nn.LeakyReLU(alpha)
self.dropout = nn.Dropout(dropout)
# X: 输入特征矩阵 [N, F], adj: 邻接矩阵 [N, N]
def forward(self, X, adj):
Wh = self.W(X) # shape [N, F']
N = Wh.size(0)
# 构造所有 Wh_i || Wh_j(拼接),用广播技巧计算 [N, N, 2F']
Wh_i = Wh.unsqueeze(1).expand(-1, N, -1) # [N, N, F']
Wh_j = Wh.unsqueeze(0).expand(N, -1, -1) # [N, N, F']
Wh_cat = torch.cat([Wh_i, Wh_j], dim=-1) # [N, N, 2F']
# 计算 e_{ij}
e = self.leakyrelu(torch.matmul(Wh_cat, self.a).squeeze(-1)) # [N, N]
# 对邻接矩阵中不为1的位置mask为 -inf,防止参与 softmax
attention = e.masked_fill(adj == 0, float('-inf'))
attention = F.softmax(attention, dim=1) # softmax over neighbors j
attention = self.dropout(attention)
# 聚合邻居特征
h_prime = torch.matmul(attention, Wh) # [N, F']
return F.elu(h_prime) # 激活
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
# 3.1 测试
import torch
import torch.nn as nn
import torch.nn.functional as F
# 节点数 N=4,特征维度 Fin=2
X = torch.tensor([
[1.0, 0.0], # 节点0
[0.0, 1.0], # 节点1
[1.0, 1.0], # 节点2
[0.0, 0.0], # 节点3
])
# 邻接矩阵(含自环),A_hat[i][j] = 1 表示 j 是 i 的邻居
A_hat = torch.tensor([
[1, 1, 0, 1], # 节点0与自己、1、3相连
[1, 1, 1, 0], # 节点1与自己、0、2相连
[0, 1, 1, 1], # 节点2与自己、1、3相连
[1, 0, 1, 1], # 节点3与自己、0、2相连
], dtype=torch.float32)
gat = GATLayer(in_features=2, out_features=4) # 2 -> 4维
out = gat(X, A_hat)
print("输出节点表示:")
print(out)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# 4 GAT 代码(单头)详细解析:
逐行详细讲解下面的这段代码,这是 GAT 中 计算注意力权重 的关键部分。
def forward(self, X, adj):
Wh = self.W(X) # shape [N, F']
N = Wh.size(0)
# 构造所有 Wh_i || Wh_j(拼接),用广播技巧计算 [N, N, 2F']
Wh_i = Wh.unsqueeze(1).expand(-1, N, -1) # [N, N, F']
Wh_j = Wh.unsqueeze(0).expand(N, -1, -1) # [N, N, F']
Wh_cat = torch.cat([Wh_i, Wh_j], dim=-1) # [N, N, 2F']
# 计算 e_{ij}
e = self.leakyrelu(torch.matmul(Wh_cat, self.a).squeeze(-1)) # [N, N]
1
2
3
4
5
6
7
8
9
10
11
2
3
4
5
6
7
8
9
10
11
# 4.1 背景复习:我们要计算什么?
在 GAT 中,对于每一对连接节点 和 ,我们要计算:
其中:
- :是节点 经过线性变换后的特征,形状是
- :是将 和 的特征拼接成
- :是一个可学习的向量,形状 ,用于投影到一个标量(注意力分数)
我们希望一次性计算所有 对的 ,得到一个 的矩阵。
# 4.2 输入说明
Wh = self.W(X) # [N, F'],节点特征经过线性变换后的结果
N = Wh.size(0) # 节点个数
1
2
2
# 4.3 目标
构造一个形状为 [N, N, 2F']
的张量 Wh_cat
,它的第 个元素是:
# 4.4 分解解释
# 第一步:
Wh_i = Wh.unsqueeze(1).expand(-1, N, -1) # [N, 1, F'] -> [N, N, F']
1
解释:
Wh.unsqueeze(1)
把 Wh 的维度从[N, F']
扩展成[N, 1, F']
expand(-1, N, -1)
把它复制成[N, N, F']
- 每一行变成了:第 行是 的拷贝,与所有 配对 例如:如果原来是
Wh = [[1, 2], [3, 4], [5, 6]] # N=3
1
那么 Wh_i
会变成:
[
[[1, 2], [1, 2], [1, 2]], # 第0行与(0,0), (0,1), (0,2)配对
[[3, 4], [3, 4], [3, 4]], # 第1行与(1,0), (1,1), (1,2)配对
[[5, 6], [5, 6], [5, 6]]
]
1
2
3
4
5
2
3
4
5
# 第二步:
Wh_j = Wh.unsqueeze(0).expand(N, -1, -1) # [1, N, F'] -> [N, N, F']
1
解释:
Wh.unsqueeze(0)
把它变成[1, N, F']
expand(N, -1, -1)
复制成[N, N, F']
- 每一列变成了:第 列是 的拷贝,与所有 配对 继续上例:
[
[[1, 2], [3, 4], [5, 6]], # 第0列表示 (0,0), (0,1), (0,2)
[[1, 2], [3, 4], [5, 6]], # 第1列表示 (1,0), (1,1), (1,2)
[[1, 2], [3, 4], [5, 6]]
]
1
2
3
4
5
2
3
4
5
# 第三步:
Wh_cat = torch.cat([Wh_i, Wh_j], dim=-1) # [N, N, 2F']
1
解释:
- 把
[Wh_i, Wh_j]
在最后一维拼接起来 - 得到每个 对的拼接向量
所以最终我们得到一个形状为
[N, N, 2F']
的张量Wh_cat
,其中每一对 都包含了节点 和 的特征拼接。
[
[[1, 2, 1, 2], [1, 2, 3, 4], [1, 2, 5, 6]], # 第0行与(0,0), (0,1), (0,2)配对
[[3, 4, 1, 2], [3, 4, 3, 4], [3, 4, 5, 6]], # 第1行与(1,0), (1,1), (1,2)配对
[[5, 6, 1, 2], [5, 6, 3, 4], [5, 6, 5, 6]]
]
1
2
3
4
5
2
3
4
5
# 4.5 最终作用
这个 Wh_cat
会被拿去跟参数 a
做点积:
e = self.leakyrelu(torch.matmul(Wh_cat, self.a).squeeze(-1)) # [N, N]
1
就得到了每一对 的注意力得分 。
# 5 多头GAT 代码实现(带拼接/平均模式):
# 5.1 多头 GAT(Multi-head GAT)的核心思想:
- 多个 GAT heads 并行计算,捕捉不同的“邻接注意力模式”
- 有两种合并方式:
- 拼接(concatenation):用于中间层
- 平均(mean):用于最后输出层
# 5.2 代码实现:多头 GAT 层(带拼接/平均模式)
import torch
import torch.nn as nn
import torch.nn.functional as F
class SingleHeadGATLayer(nn.Module):
def __init__(self, in_features, out_features, dropout=0.0, alpha=0.2):
super().__init__()
self.W = nn.Linear(in_features, out_features, bias=False)
self.a = nn.Parameter(torch.empty(size=(2*out_features, 1)))
nn.init.xavier_uniform_(self.a.data, gain=1.414)
self.leakyrelu = nn.LeakyReLU(alpha)
self.dropout = nn.Dropout(dropout)
def forward(self, X, adj):
Wh = self.W(X) # [N, F']
N = Wh.size(0)
Wh_i = Wh.unsqueeze(1).expand(-1, N, -1)
Wh_j = Wh.unsqueeze(0).expand(N, -1, -1)
Wh_cat = torch.cat([Wh_i, Wh_j], dim=-1) # [N, N, 2F']
e = self.leakyrelu(torch.matmul(Wh_cat, self.a).squeeze(-1)) # [N, N]
attention = e.masked_fill(adj == 0, float('-inf'))
attention = F.softmax(attention, dim=1)
attention = self.dropout(attention)
h_prime = torch.matmul(attention, Wh) # [N, F']
return h_prime # 注意:不加激活,留给外层决定
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
# 5.3 多头封装 GAT 层
class MultiHeadGATLayer(nn.Module):
def __init__(self, in_features, out_features, num_heads=4, dropout=0.0, alpha=0.2, concat=True):
super().__init__()
self.heads = nn.ModuleList([
SingleHeadGATLayer(in_features, out_features, dropout, alpha)
for _ in range(num_heads)
])
self.concat = concat # True: 拼接;False: 平均
def forward(self, X, adj):
out = [head(X, adj) for head in self.heads] # list of [N, F']
if self.concat:
return F.elu(torch.cat(out, dim=1)) # [N, num_heads * F']
else:
return F.elu(torch.mean(torch.stack(out), dim=0)) # [N, F']
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# 5.4 使用示例
# 输入特征:4个节点,每个2维
X = torch.tensor([
[1.0, 0.0],
[0.0, 1.0],
[1.0, 1.0],
[0.0, 0.0],
])
# 邻接矩阵(含自环)
A_hat = torch.tensor([
[1, 1, 0, 1],
[1, 1, 1, 0],
[0, 1, 1, 1],
[1, 0, 1, 1],
], dtype=torch.float32)
# 多头 GAT:4 个头,每个输出维度为 3,拼接成 [N, 12]
gat_multi = MultiHeadGATLayer(in_features=2, out_features=3, num_heads=4, concat=True)
output = gat_multi(X, A_hat)
print("多头 GAT 输出形状:", output.shape)
print(output)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
上次更新: 2025/07/15, 15:17:12