gnn-layers

Some custom GNN layers for PyTorch Geometric


Keywords
gnn, graph-neural-network, convolution, pooling, pytorch, graph
License
MIT
Install
pip install gnn-layers==1.0.2

Documentation

custom-gnn-layers

This is a little collection of the custom graph convolutional and pooling layers I've made for various projects. Everything here is built on the PyTorch Geometric library and can be used like a regular PyTorch module.

Convolutional Layers

EdgeAttentionConv

EdgeAttentionConv is an edge-conditioned filter with an attention mechanism. It's the same as NNConv, except an attention coefficient for each message is calculated from the edge features. The idea is that messages from some neighbors may be more important than others, depending on their connection with the root node. Node embeddings are updated like so:

where W_r and W_g are trainable weight matrices, and h_e is a neural network (e.g. an MLP). W_r is used to transform the root node features and W_g is used to calculate an attention coefficient.
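
To make the description above concrete, here is a rough sketch of one such update written in plain PyTorch. It is not the layer's actual implementation: the sigmoid gate, the sum aggregation, and the shape of W_g are assumptions made for illustration, and only the roles of W_r, W_g, and h_e come from the text above.

```python
import torch

torch.manual_seed(0)

num_nodes, in_channels, out_channels, num_edge_features = 3, 1, 4, 2

# Toy graph: 3 nodes, 4 directed edges, 2 features per edge
x = torch.tensor([[-1.0], [0.0], [1.0]])
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])
edge_attr = torch.tensor([[0.0, 1.0], [1.0, 0.0], [1.0, 0.0], [0.0, 1.0]])

h_e = torch.nn.Linear(num_edge_features, in_channels * out_channels)  # edge network
W_r = torch.nn.Linear(in_channels, out_channels, bias=False)          # root transform
W_g = torch.nn.Linear(num_edge_features, 1, bias=False)               # ASSUMED: one scalar score per edge

src, dst = edge_index

# NNConv-style message: each edge gets its own [in_channels, out_channels] weight matrix
weight = h_e(edge_attr).view(-1, in_channels, out_channels)
messages = torch.bmm(x[src].unsqueeze(1), weight).squeeze(1)  # [num_edges, out_channels]

# ASSUMED: attention coefficient is a sigmoid gate on the edge features
alpha = torch.sigmoid(W_g(edge_attr))  # [num_edges, 1]

# ASSUMED: weighted messages are summed at each target node
aggregated = torch.zeros(num_nodes, out_channels).index_add(0, dst, alpha * messages)

out = W_r(x) + aggregated  # [num_nodes, out_channels]
```

The point of the sketch is only that the attention coefficient scales each message before aggregation, so an edge with a low score contributes little to its target node.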

Parameters:

  • in_channels (int): Size of each input node embedding.
  • out_channels (int): Size of each output node embedding.
  • edge_nn (torch.nn.Module): A neural network h_e that maps edge features edge_attr of shape [-1, num_edge_features] to shape [-1, in_channels * out_channels].
  • root_weight (bool, optional): If set to False, the layer will not add the transformed root node features to the output. (default: True)
  • bias (bool, optional): If set to False, the layer will not learn an additive bias. (default: True)
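
As an illustration of the edge_nn shape requirement, the sketch below builds a small MLP that maps edge features to in_channels * out_channels values. The hidden width of 16 and the ReLU activation are arbitrary choices for this sketch, not something the library requires.

```python
import torch

in_channels, out_channels, num_edge_features = 1, 4, 2
hidden = 16  # hypothetical hidden width, chosen only for this sketch

# Maps [-1, num_edge_features] to [-1, in_channels * out_channels]
edge_nn = torch.nn.Sequential(
    torch.nn.Linear(num_edge_features, hidden),
    torch.nn.ReLU(),
    torch.nn.Linear(hidden, in_channels * out_channels),
)

# Five edges with random float features
edge_attr = torch.rand(5, num_edge_features)
edge_weights = edge_nn(edge_attr)  # one row of in_channels * out_channels values per edge
```

Any module with the same input and output sizes should work in its place, including the single torch.nn.Linear used in the example that follows.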

Example:

import torch
from gnn_layers import EdgeAttentionConv

# Convolutional layer
conv = EdgeAttentionConv(
  in_channels=1,
  out_channels=4,
  edge_nn=torch.nn.Linear(2, 1 * 4),
)

# Your input graph
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]], dtype=torch.long)
edge_attr = torch.tensor([[0, 1], [1, 0], [1, 0], [0, 1]], dtype=torch.float)  # float, since edge_nn is a Linear layer

# Your output graph
x = conv(x, edge_index, edge_attr) # Shape is now [3, 4]

ImageConv

ImageConv is an edge-conditioned filter for graphs where the node features are n-dimensional matrices, such as 2D or 3D images, rather than vectors. The filter applies a (non-graph) convolution, i.e. torch.nn.Conv2d or torch.nn.Conv3d, to transform the node features. Node embeddings are updated like so:

where φ_r and φ_m are convolutional layers, and W_e is a weight matrix.

Parameters:

  • in_channels (int): Number of channels in the input node image.
  • out_channels (int): Number of channels in the output node image.
  • image_dims (tuple): Dimensions of the input node image as a tuple, e.g. for a 4x4 image, set to (4, 4).
  • kernel_size (tuple): Size of the convolving kernel.
  • num_edge_attr (int): Number of edge features.
  • bias (bool, optional): If set to False, the layer will not learn an additive bias. (default: True)
  • aggr (str, optional): The aggregation scheme to use ("add", "mean", "max"). (default: "add")
  • **kwargs (optional): Additional arguments for torch.nn.Conv1d, torch.nn.Conv2d, or torch.nn.Conv3d.

Example:

import torch
from gnn_layers import ImageConv

# Convolutional layer
conv = ImageConv(
  in_channels=1,
  out_channels=4,
  image_dims=(8, 8),
  kernel_size=(2, 2),
  num_edge_attr=2
)

# Your input graph
x = torch.randn((3, 1, 8, 8), dtype=torch.float)
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]], dtype=torch.long)
edge_attr = torch.tensor([[0, 1], [1, 0], [1, 0], [0, 1]], dtype=torch.float)  # float, so it can be multiplied by a weight matrix

# Your output graph
x = conv(x, edge_index, edge_attr) # Shape is now [3, 4, 7, 7]
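
The 7x7 output size in the example above follows from the usual output-size formula for torch.nn.Conv2d and torch.nn.Conv3d. The helper below is a small sketch of that arithmetic. conv_output_dims is a hypothetical name and not part of gnn_layers, and it assumes ImageConv passes stride, padding, and dilation through to the underlying convolution unchanged.

```python
def conv_output_dims(image_dims, kernel_size, stride=1, padding=0, dilation=1):
    """Spatial size after a convolution, per dimension (hypothetical helper).

    Uses the formula from the torch.nn.Conv2d documentation:
    floor((size + 2*padding - dilation*(kernel - 1) - 1) / stride) + 1
    """
    return tuple(
        (size + 2 * padding - dilation * (k - 1) - 1) // stride + 1
        for size, k in zip(image_dims, kernel_size)
    )


# The 8x8 image and 2x2 kernel from the example above
print(conv_output_dims((8, 8), (2, 2)))  # (7, 7)
```

This can be handy when stacking several ImageConv layers, since image_dims for each layer has to match the output size of the previous one.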

Pooling Layers

GlobalAttentionPool

To-do.