Python实现Graph Neural Network(GNN)中的消息传递聚合函数定制

Python实现GNN中的消息传递聚合函数定制

大家好,今天我们来深入探讨图神经网络(GNN)中消息传递聚合函数的定制。GNN的核心在于通过节点间的信息传递来学习节点和图的表示。而消息传递的聚合阶段,是将邻居节点的信息汇总的关键步骤。理解和定制这个过程,能让我们更好地控制GNN的行为,使其适应各种复杂的图结构和学习任务。

GNN的消息传递范式

首先,让我们简单回顾一下GNN的消息传递范式。一个典型的消息传递过程包含三个主要步骤:

  1. 消息传递(Message Passing): 每个节点根据其邻居节点的特征生成消息。
  2. 消息聚合(Aggregation): 每个节点收集并聚合来自其邻居节点的消息。
  3. 节点更新(Node Update): 每个节点利用聚合后的消息更新自身的表示。

这三个步骤可以迭代多次,使得节点能够逐步感知到更远距离的节点信息。今天我们的重点是消息聚合这一步,探讨如何通过Python定制聚合函数,实现更灵活的消息处理。

常见的聚合函数及其局限性

在标准的GNN库(如PyTorch Geometric, DGL)中,通常提供了一些预定义的聚合函数,例如:

  • Sum (Summation): 将所有邻居节点的消息求和。
  • Mean (Average): 计算所有邻居节点消息的平均值。
  • Max (Maximum): 选择邻居节点消息中的最大值。
  • Min (Minimum): 选择邻居节点消息中的最小值。

这些聚合函数在很多情况下都表现良好,但它们也存在一些局限性:

  • 对所有邻居节点同等对待: 它们没有考虑到不同邻居节点的重要性可能不同。
  • 无法捕捉复杂的节点关系: 简单地求和、平均或取最大/最小值可能丢失邻居节点之间的复杂关系。
  • 可能对异常值敏感: Sum、Mean等聚合函数容易受到异常值的影响。

为了克服这些局限性,我们需要定制自己的聚合函数。

使用Python定制聚合函数

下面我们以PyTorch Geometric为例,演示如何使用Python定制GNN中的消息传递聚合函数。

首先,我们需要了解PyTorch Geometric中消息传递的机制。PyTorch Geometric提供了一个基类torch_geometric.nn.MessagePassing,我们可以通过继承这个类来定义自己的GNN层。

1. 定义GNN层

import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

class CustomGNNLayer(MessagePassing):
    def __init__(self, in_channels, out_channels, aggregator_type='sum'):
        super(CustomGNNLayer, self).__init__(aggr=aggregator_type, node_dim=-1)  # "Add" aggregation (Step 5).
        self.lin = torch.nn.Linear(in_channels, out_channels)  # Transformation of node feature matrix.

    def forward(self, x, edge_index):
        # x: [N, in_channels]
        # edge_index: [2, E]

        # Step 1: Add self-loops
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Step 2: Transform node feature matrix.
        x = self.lin(x)

        # Step 3-5: Start propagating messages.
        return self.propagate(edge_index, x=x)

    def message(self, x_i, x_j, edge_index, num_nodes):
        # x_i: [E, out_channels] Source node features
        # x_j: [E, out_channels] Target node features
        # edge_index: [2, E]
        # num_nodes: int
        # Step 3: Message Passing
        return x_j

    def aggregate(self, inputs, index, ptr = None, dim_size = None):
        # inputs: [E, out_channels]
        # index: [E]
        # ptr: [N+1] (optional)
        # dim_size: int (optional)
        # Step 4: Aggregate messages. This will be replaced with custom aggregation
        # For demonstration, let's keep the default aggregation
        return super().aggregate(inputs, index, ptr, dim_size)

在这个例子中,CustomGNNLayer继承了MessagePassing类。forward函数负责执行消息传递的整个过程。message函数定义了如何从邻居节点生成消息。而aggregate函数则是我们定制聚合函数的地方。

2. 定制聚合函数

PyTorch Geometric的MessagePassing类提供了aggregate函数的默认实现,它会根据__init__函数中指定的aggr参数来选择聚合函数(例如’sum’, ‘mean’, ‘max’)。

要定制聚合函数,我们需要重写aggregate函数。aggregate函数的输入包括:

  • inputs: 所有边的消息。
  • index: 每个消息的目标节点索引。
  • ptr (可选): 用于批处理图数据的指针。
  • dim_size (可选): 节点数量。

下面是一些定制聚合函数的示例:

示例 1: 基于节点度的加权平均

这个例子中,我们根据目标节点的度来加权平均邻居节点的消息。度数高的节点的消息权重更高。

import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

class DegreeWeightedAvgGNNLayer(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(DegreeWeightedAvgGNNLayer, self).__init__(aggr='add', node_dim=-1)  # We will manually perform averaging
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        x = self.lin(x)
        row, col = edge_index[0], edge_index[1]
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        return self.propagate(edge_index, x=x, norm=norm)

    def message(self, x_j, norm):
        # x_j: [E, out_channels]
        # norm: [E]
        return norm.view(-1, 1) * x_j

    def update(self, aggr_out):
        return aggr_out

    # The following aggregate function is implicitly used due to aggr = 'add' in __init__
    # def aggregate(self, inputs, index, ptr = None, dim_size = None):
    #     return super().aggregate(inputs, index, ptr, dim_size)

在这个例子中,我们使用 norm 来表示边的权重,该权重由源节点和目标节点的度数决定。message 函数将节点特征乘以权重 norm,然后在 aggregate 步骤中,aggr='add' 会将所有加权的消息相加。最后,update 函数返回聚合结果。

示例 2: 基于注意力机制的聚合

这个例子中,我们使用注意力机制来动态地计算邻居节点的重要性。

import torch
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops

class AttentionAggregationGNNLayer(MessagePassing):
    def __init__(self, in_channels, out_channels, heads=1):
        super(AttentionAggregationGNNLayer, self).__init__(aggr='add', node_dim=-1)
        self.lin = torch.nn.Linear(in_channels, out_channels)
        self.attn_l = torch.nn.Linear(out_channels, heads)
        self.attn_r = torch.nn.Linear(out_channels, heads)
        self.heads = heads

    def forward(self, x, edge_index):
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        x = self.lin(x)
        return self.propagate(edge_index, x=x)

    def message(self, x_i, x_j, edge_index):
        # Calculate attention coefficients
        attn_i = self.attn_l(x_i)
        attn_j = self.attn_r(x_j)
        attn = attn_i + attn_j  # [E, heads]
        attn = F.leaky_relu(attn, negative_slope=0.2)  # Applying LeakyReLU
        attn = torch.exp(attn)  # Applying exponential to ensure positivity

        # Normalize attention coefficients
        row, col = edge_index[0], edge_index[1]
        deg = degree(row, x.size(0), dtype=x_i.dtype) # Sum attention values for each target node
        attn_sum = torch.zeros_like(deg)
        attn_sum = attn_sum.scatter_add_(0, row, attn) # Summing attention values for each target node.
        attn_sum[attn_sum == 0] = 1e-16 # avoid division by zero
        attn_norm = attn / attn_sum[row].unsqueeze(1)

        # Apply attention to the message
        return x_j * attn_norm  # [E, out_channels]

    def update(self, aggr_out):
        return aggr_out

在这个例子中,message 函数计算了每条边的注意力系数,并将其应用于消息。aggregate 函数(在这里是默认的 add)会将所有加权的消息相加。注意,我们使用了 F.leaky_relu 来保证注意力系数非负,并使用 torch.exp 来进一步增强差异。

示例 3: 使用LSTM进行序列聚合

如果邻居节点具有某种顺序关系,我们可以使用LSTM来聚合它们的消息。

import torch
import torch.nn as nn
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops

class LSTMAggregationGNNLayer(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(LSTMAggregationGNNLayer, self).__init__(aggr=None, node_dim=-1) # Disable default aggregation
        self.lin = nn.Linear(in_channels, out_channels)
        self.lstm = nn.LSTM(out_channels, out_channels, batch_first=True)

    def forward(self, x, edge_index):
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        x = self.lin(x)
        return self.propagate(edge_index, x=x, edge_index=edge_index)

    def message(self, x_j):
        return x_j

    def aggregate(self, inputs, index, ptr = None, dim_size = None):
        # inputs: [E, out_channels]
        # index: [E]
        # ptr: [N+1] (optional)
        # dim_size: int (optional)
        # Step 4: Aggregate messages using LSTM

        # Sort messages by target node index
        unique_indices, counts = torch.unique(index, return_counts=True)
        output = []
        for i in range(len(unique_indices)):
            node_idx = unique_indices[i]
            num_neighbors = counts[i]
            node_messages = inputs[index == node_idx].unsqueeze(0) # [1, num_neighbors, out_channels]
            lstm_out, _ = self.lstm(node_messages) # [1, num_neighbors, out_channels]
            output.append(lstm_out[:, -1, :])  # Take the last hidden state as the aggregated message

        # Pad with zeros if some nodes have no neighbors
        if len(output) < dim_size:
            padding_size = dim_size - len(output)
            padding = torch.zeros(padding_size, inputs.size(-1)).to(inputs.device)
            output.extend([padding])

        output = torch.cat(output, dim=0)
        return output

    def update(self, aggr_out):
        return aggr_out

在这个例子中,我们首先禁用了默认的聚合函数 (aggr=None)。然后在 aggregate 函数中,我们将每个节点的所有邻居节点的消息收集起来,并使用LSTM处理它们。LSTM的最后一个隐状态被用作聚合后的消息。 需要注意的是,这个实现假设每个节点的邻居节点具有某种顺序关系。

3. 使用定制的GNN层

定义好定制的GNN层后,就可以像使用普通的GNN层一样使用它了。

# 假设我们已经有了数据
# data.x: 节点特征
# data.edge_index: 边索引

# 初始化模型
model = DegreeWeightedAvgGNNLayer(in_channels=data.num_node_features, out_channels=64)

# 前向传播
output = model(data.x, data.edge_index)

定制聚合函数的注意事项

在定制聚合函数时,需要注意以下几点:

  • 效率: 聚合函数的效率对GNN的整体性能至关重要。尽量使用高效的张量操作,避免使用循环。
  • 可微性: 为了能够进行反向传播,聚合函数必须是可微的。
  • 内存占用: 聚合函数可能会消耗大量的内存,尤其是在处理大型图数据时。
  • 适用性: 不同的聚合函数适用于不同的图结构和学习任务。需要根据具体情况选择合适的聚合函数。
  • 图结构感知: 尽量利用图的结构信息,例如节点度、节点类型、边的权重等。

不同聚合函数的适用场景

聚合函数类型 描述 适用场景 优点 缺点
Sum/Mean/Max/Min 基本的聚合操作,将所有邻居节点的消息求和、平均或取最大/最小值。 节点特征重要性相似的图,或者需要强调特定特征(例如,最大值强调最重要的特征)。 简单易用,计算效率高。 对所有邻居节点同等对待,无法捕捉复杂的节点关系,Sum和Mean容易受异常值影响。
度加权平均 根据节点的度来加权平均邻居节点的消息,度数高的节点的消息权重更高。 节点度分布不均匀的图,例如社交网络。 能够考虑到不同节点的度数,更好地处理度分布不均匀的图。 需要计算节点度,计算量稍大。
注意力机制 使用注意力机制来动态地计算邻居节点的重要性。 需要区分不同邻居节点重要性的图,例如知识图谱。 能够动态地学习邻居节点的重要性,更好地捕捉复杂的节点关系。 计算复杂度较高,需要更多的训练数据。
LSTM序列聚合 使用LSTM来聚合邻居节点的消息,适用于邻居节点具有某种顺序关系的图。 邻居节点具有顺序关系的图,例如时间序列图。 能够利用邻居节点的顺序关系,更好地捕捉时序信息。 计算复杂度很高,需要更多的训练数据,并且需要确定邻居节点的顺序。
基于图结构的聚合函数 利用图的结构信息(例如节点类型、边类型、子图结构等)来定制聚合函数。 具有复杂图结构的图,例如知识图谱、生物网络。 能够充分利用图的结构信息,更好地学习节点和图的表示。 设计和实现比较复杂,需要对图结构有深入的理解。
基于GNN的聚合函数 使用另一个GNN来学习聚合函数。例如,可以使用一个小型GNN来学习如何将邻居节点的消息聚合到目标节点上。 非常复杂的图结构和学习任务,例如图生成、图匹配。 能够学习非常复杂的聚合函数,具有很强的表达能力。 计算复杂度非常高,需要大量的训练数据。

总结

消息传递聚合函数是GNN的核心组件之一。通过定制聚合函数,我们可以更好地控制GNN的行为,使其适应各种复杂的图结构和学习任务。在定制聚合函数时,需要注意效率、可微性、内存占用和适用性等问题。 希望今天的分享能够帮助大家更好地理解和应用GNN。

下一步的方向

未来,我们可以继续探索更高级的聚合函数,例如:

  • 基于图神经网络的聚合函数: 使用另一个GNN来学习聚合函数,实现更复杂的聚合逻辑。
  • 自适应聚合函数: 根据不同的节点和边动态地选择不同的聚合函数。
  • 可解释的聚合函数: 设计可解释的聚合函数,帮助我们理解GNN的决策过程。

定制聚合函数是GNN研究的一个重要方向,相信未来会有更多有趣和有用的方法涌现出来。

更多IT精英技术系列讲座,到智猿学院

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注