custom-gnn-layers
This is a small collection of custom graph convolutional and pooling layers I've made for various projects. Everything here is built on the PyTorch Geometric library, and each layer can be used like a regular PyTorch module.
Convolutional Layers
EdgeAttentionConv
EdgeAttentionConv is an edge-conditioned filter with an attention mechanism. It's the same as NNConv, except an attention coefficient for each message is calculated from the edge features. The idea is that messages from some neighbors may be more important than others, depending on their connection with the root node. Node embeddings are updated like so:
where W_r and W_g are trainable weight matrices, and h_e is a neural network (e.g. an MLP). W_r transforms the root node features, and W_g is used to calculate the attention coefficient.
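The update equation itself is missing from this copy, so what follows is only a sketch of one plausible reading, in plain PyTorch. It assumes that messages are computed as in NNConv (x_j multiplied by the per-edge matrix produced by h_e), that the attention coefficient is a softmax of W_g e_ij over each node's incoming edges, and that messages flow from edge_index[0] to edge_index[1] as in PyTorch Geometric. EdgeAttentionConvSketch, segment_softmax and the extra num_edge_features argument are hypothetical names for illustration; the layer in this repo may normalize or combine the terms differently.

```python
import torch
from torch import nn


def segment_softmax(scores, index, num_nodes):
    """Softmax of `scores` within each group of edges that share a target node."""
    # Per-node maximum, subtracted for numerical stability
    max_per_node = torch.full((num_nodes,), float("-inf"), dtype=scores.dtype)
    max_per_node = max_per_node.scatter_reduce(0, index, scores, reduce="amax")
    exp = (scores - max_per_node[index]).exp()
    # Per-node sum of exponentials
    denom = scores.new_zeros(num_nodes).index_add_(0, index, exp)
    return exp / denom[index]


class EdgeAttentionConvSketch(nn.Module):
    """Hypothetical illustration, NOT the repository's EdgeAttentionConv.

    Assumed update: x_i' = W_r x_i + sum_{j in N(i)} alpha_ij * (x_j @ h_e(e_ij)) + b,
    with alpha_ij = softmax over i's incoming edges of W_g e_ij.
    """

    def __init__(self, in_channels, out_channels, edge_nn, num_edge_features,
                 root_weight=True, bias=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.edge_nn = edge_nn  # h_e: [E, F] -> [E, in_channels * out_channels]
        # W_g: one attention score per edge (num_edge_features is an assumed extra argument)
        self.att = nn.Linear(num_edge_features, 1, bias=False)
        # W_r: transform of the root node features
        self.root = nn.Linear(in_channels, out_channels, bias=False) if root_weight else None
        self.bias = nn.Parameter(torch.zeros(out_channels)) if bias else None

    def forward(self, x, edge_index, edge_attr):
        src, dst = edge_index  # PyG convention: messages flow from src (j) to dst (i)
        # NNConv-style message: each edge gets its own [in, out] weight matrix
        weight = self.edge_nn(edge_attr).view(-1, self.in_channels, self.out_channels)
        messages = torch.bmm(x[src].unsqueeze(1), weight).squeeze(1)  # [E, out]
        # Attention coefficient per message, computed from edge features only
        alpha = segment_softmax(self.att(edge_attr).squeeze(-1), dst, x.size(0))
        # Sum the weighted messages into their target nodes
        out = x.new_zeros(x.size(0), self.out_channels)
        out.index_add_(0, dst, alpha.unsqueeze(-1) * messages)
        if self.root is not None:
            out = out + self.root(x)
        if self.bias is not None:
            out = out + self.bias
        return out
```

A softmax makes the coefficients of each node's incoming messages sum to 1; a per-edge sigmoid gate would be an equally plausible alternative that lets each message be kept or suppressed independently.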
Parameters:
- in_channels (int): Size of each input node embedding.
- out_channels (int): Size of each output node embedding.
- edge_nn (torch.nn.Module): A neural network h_e that maps edge features edge_attr of shape [-1, num_edge_features] to shape [-1, in_channels * out_channels].
- root_weight (bool, optional): If set to False, the layer will not add the transformed root node features to the output. (default: True)
- bias (bool, optional): If set to False, the layer will not learn an additive bias. (default: True)
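Any module that honors the edge_nn shape contract above should work. A minimal sketch, assuming a two-layer MLP as h_e (the hidden size of 16 is arbitrary) and assuming the layer reshapes the flat output into one [in_channels, out_channels] matrix per edge, the way PyTorch Geometric's NNConv does:

```python
import torch
from torch import nn

in_channels, out_channels, num_edge_features = 1, 4, 2

# A two-layer MLP as h_e; the hidden size is an arbitrary choice
edge_nn = nn.Sequential(
    nn.Linear(num_edge_features, 16),
    nn.ReLU(),
    nn.Linear(16, in_channels * out_channels),
)

# Edge features are fed through Linear layers, so they must be floating point
edge_attr = torch.tensor([[0, 1], [1, 0], [1, 0], [0, 1]], dtype=torch.float)

flat = edge_nn(edge_attr)                            # [num_edges, in_channels * out_channels]
weights = flat.view(-1, in_channels, out_channels)   # one [in, out] matrix per edge
```

A plain torch.nn.Linear(num_edge_features, in_channels * out_channels) satisfies the same contract; the MLP just lets the per-edge weights depend non-linearly on the edge features.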
Example:
import torch
from gnn_layers import EdgeAttentionConv

# Convolutional layer
conv = EdgeAttentionConv(
    in_channels=1,
    out_channels=4,
    edge_nn=torch.nn.Linear(2, 1 * 4),
)

# Your input graph
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]], dtype=torch.long)
# Edge features must be floats, since they are fed through edge_nn (a Linear layer)
edge_attr = torch.tensor([[0, 1], [1, 0], [1, 0], [0, 1]], dtype=torch.float)

# Your output graph
x = conv(x, edge_index, edge_attr)  # Shape is now [3, 4]
ImageConv
ImageConv is an edge-conditioned filter for graphs where the node features are n-dimensional matrices, such as 2D or 3D images, rather than vectors. The filter applies a (non-graph) convolution, e.g. torch.nn.Conv2d or torch.nn.Conv3d, to transform the node features. Node embeddings are updated like so:
where φ_r and φ_m are convolutional layers, and W_e is a weight matrix.
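The update equation is missing here as well, so this is a minimal 2D-only sketch under an assumed form: φ_r is applied to the root image, φ_m to each neighbor image, and W_e e_ij scales each output channel of the message before the messages are aggregated with aggr. ImageConvSketch is a hypothetical name, and it uses torch.Tensor.scatter_reduce (PyTorch 1.12 or later) for the "add", "mean" and "max" schemes rather than PyTorch Geometric's MessagePassing, so it is not the repo's implementation.

```python
import torch
from torch import nn


class ImageConvSketch(nn.Module):
    """Hypothetical 2D illustration, NOT the repository's ImageConv.

    Assumed update: x_i' = phi_r(x_i) + AGGR_{j in N(i)} (W_e e_ij) * phi_m(x_j),
    where W_e e_ij supplies one scale factor per output channel.
    """

    def __init__(self, in_channels, out_channels, kernel_size, num_edge_attr, aggr="add"):
        super().__init__()
        self.phi_r = nn.Conv2d(in_channels, out_channels, kernel_size)  # root convolution
        self.phi_m = nn.Conv2d(in_channels, out_channels, kernel_size, bias=False)  # message convolution
        self.W_e = nn.Linear(num_edge_attr, out_channels, bias=False)  # edge weight matrix
        # Map the README's aggregation names onto scatter_reduce modes
        self.reduce = {"add": "sum", "mean": "mean", "max": "amax"}[aggr]

    def forward(self, x, edge_index, edge_attr):
        src, dst = edge_index  # messages flow from src (j) to dst (i)
        root = self.phi_r(x)  # [N, C_out, H', W']
        scale = self.W_e(edge_attr)[:, :, None, None]  # [E, C_out, 1, 1], broadcast over pixels
        messages = (scale * self.phi_m(x)[src]).flatten(1)  # [E, C_out * H' * W']
        index = dst.unsqueeze(-1).expand_as(messages)
        # Nodes with no incoming edges keep the zero initialization
        agg = messages.new_zeros(x.size(0), messages.size(1)).scatter_reduce(
            0, index, messages, reduce=self.reduce, include_self=False)
        return root + agg.view_as(root)
```

Flattening each message image lets a single scatter_reduce call handle all three aggregation schemes; the result is reshaped back to image form before the root term is added.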
Parameters:
- in_channels (int): Number of channels in the input node image.
- out_channels (int): Number of channels in the output node image.
- image_dims (tuple): Dimensions of the input node image, e.g. (4, 4) for a 4x4 image.
- kernel_size (tuple): Size of the convolving kernel.
- num_edge_attr (int): Number of edge features.
- bias (bool, optional): If set to False, the layer will not learn an additive bias. (default: True)
- aggr (str, optional): The aggregation scheme to use ("add", "mean" or "max"). (default: "add")
- **kwargs (optional): Additional arguments for torch.nn.Conv1d, torch.nn.Conv2d or torch.nn.Conv3d.
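Because kernel_size (and any stride, padding or dilation passed through **kwargs) changes the spatial size of each node image, it helps to know the output dimensions in advance. A small helper applying the output-size formula from the torch.nn.Conv2d documentation; it assumes ImageConv forwards these arguments to the convolution unchanged and uses the same stride, padding and dilation in every dimension. conv_output_dims is a hypothetical name, and the code is plain Python:

```python
def conv_output_dims(image_dims, kernel_size, stride=1, padding=0, dilation=1):
    """Spatial size after a torch.nn.ConvNd layer (same settings in every dimension)."""
    return tuple(
        (n + 2 * padding - dilation * (k - 1) - 1) // stride + 1
        for n, k in zip(image_dims, kernel_size)
    )
```

With the defaults, an 8x8 image and a (2, 2) kernel give (7, 7), which matches the [3, 4, 7, 7] output shape in the example below; padding=1 with a (3, 3) kernel would keep the image at 8x8.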
Example:
import torch
from gnn_layers import ImageConv

# Convolutional layer
conv = ImageConv(
    in_channels=1,
    out_channels=4,
    image_dims=(8, 8),
    kernel_size=(2, 2),
    num_edge_attr=2,
)

# Your input graph
x = torch.randn((3, 1, 8, 8), dtype=torch.float)
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]], dtype=torch.long)
# Edge features must be floats to be multiplied by the float weight matrix
edge_attr = torch.tensor([[0, 1], [1, 0], [1, 0], [0, 1]], dtype=torch.float)

# Your output graph
x = conv(x, edge_index, edge_attr)  # Shape is now [3, 4, 7, 7]
Pooling Layers
GlobalAttentionPool
To-do.