CommunityNet
CommunityNet is a hierarchical Graph Neural Network (GNN) designed for graph datasets with community structure (e.g. social networks and molecules). It encodes information at both the within-community level and the inter-community level.
Installation
You can install CommunityNet from PyPI:
$ pip install communitynet
Usage
Before instantiating CommunityNet, you must define a "base" GNN and an "output" GNN. The base GNN creates a vector embedding of each community in an input graph. These embeddings are used as node features in an "inter-community" graph, where each node represents a community and each edge between two communities is the mean of the edges connecting them. This graph is passed to the output GNN, which makes the prediction. Both GNNs can be constructed with the GraphNet and MLP PyTorch modules supplied by the library. For example, to construct the CommunityNet shown in the diagram above, you can do the following:
import torch.nn as nn
from communitynet import GraphNet, MLP, CommunityNet
# Example numbers (arbitrary)
num_node_features = 4
num_edge_features = 2
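# Base GNN: embeds each community into an 8-dimensional vector
base_gnn = GraphNet(in_channels=num_node_features, out_channels=8,
                    num_edge_features=num_edge_features)
# Output GNN: processes the inter-community graph (whose node features are
# the 8-dimensional community embeddings) and maps it to a single output
output_gnn = nn.Sequential(
    GraphNet(in_channels=8, out_channels=4, num_edge_features=num_edge_features),
    MLP(in_channels=4, out_channels=1)
)
community_net = CommunityNet(base_gnn, output_gnn, num_communities=3)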
GraphNet and MLP both have additional hyperparameters (e.g. hidden layers and dropout), which are described in the reference below. The CommunityNet class itself derives from torch.nn.Module, so it can be trained like any other PyTorch model.
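To make the inter-community graph construction described above more concrete, here is a minimal plain-Python sketch of how edge features could be averaged between communities. It is an illustration under assumptions, not the library's implementation: build_inter_community_edges is a hypothetical helper, edges are treated as directed (as in an edge_index), and edges whose endpoints fall in the same community are skipped because they belong to the within-community level.

```python
from collections import defaultdict


def build_inter_community_edges(edge_index, edge_attr, communities):
    """Average the features of all edges that connect two different communities.

    edge_index:  pair of equal-length lists (sources, targets) of node indices
    edge_attr:   list of per-edge feature lists, aligned with edge_index
    communities: list of sets of node indices
    Returns a dict mapping (community_a, community_b) -> mean feature list.
    """
    # Map every node to the index of the community that contains it
    node_to_comm = {}
    for c, nodes in enumerate(communities):
        for n in nodes:
            node_to_comm[n] = c

    sums = {}
    counts = defaultdict(int)
    for src, dst, feats in zip(edge_index[0], edge_index[1], edge_attr):
        a, b = node_to_comm[src], node_to_comm[dst]
        if a == b:
            continue  # within-community edge: not part of the inter-community graph
        key = (a, b)
        if key not in sums:
            sums[key] = [0.0] * len(feats)
        sums[key] = [s + f for s, f in zip(sums[key], feats)]
        counts[key] += 1

    # Divide each accumulated sum by the number of edges between the two communities
    return {key: [s / counts[key] for s in total] for key, total in sums.items()}


edges = build_inter_community_edges(
    edge_index=[[0, 0, 1, 2, 3], [1, 2, 3, 0, 1]],
    edge_attr=[[9.0, 9.0], [1.0, 0.0], [3.0, 2.0], [1.0, 0.0], [3.0, 2.0]],
    communities=[{0, 1}, {2, 3}],
)
print(edges)  # {(0, 1): [2.0, 1.0], (1, 0): [2.0, 1.0]}
```

The edge 0 -> 1 carries [9.0, 9.0] but is ignored because both nodes are in community 0; the two edges from community 0 to community 1 are averaged into a single edge.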
Each graph you submit to CommunityNet must be an instance of torch_geometric.data.Data with an additional communities attribute. data.communities should hold a list of communities, where each community is a set of node indices. For example:
import torch
from torch_geometric.data import Data
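# Undirected edges, listed in both directions: 0-1, 0-2, 1-2, 2-3, 3-4, 3-5, 4-5
edge_index = torch.tensor([[0, 1, 0, 2, 1, 2, 2, 3, 3, 4, 3, 5, 4, 5],
                           [1, 0, 2, 0, 2, 1, 3, 2, 4, 3, 5, 3, 5, 4]],
                          dtype=torch.long)
# One feature per node (6 nodes)
x = torch.tensor([[-1], [0], [1], [0.5], [0.75], [-0.25]], dtype=torch.float)
data = Data(x=x, edge_index=edge_index)
# Two communities: nodes {0, 1, 2} and nodes {3, 4, 5}
data.communities = [{0, 1, 2}, {3, 4, 5}]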
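The README does not spell out exactly what the base GNN receives for each community. One natural reading, and it is only an assumption, is the subgraph induced by that community's nodes, relabelled so its node indices start at 0. The hypothetical helper community_subgraph below sketches that step in plain Python using the edge list from the example above.

```python
def community_subgraph(edge_index, community):
    """Return the edges induced by one community, relabelled to 0..k-1.

    edge_index: pair of equal-length lists (sources, targets)
    community:  set of node indices
    Returns (node_order, local_edge_index).
    """
    node_order = sorted(community)  # fixed order so the relabelling is deterministic
    local_id = {node: i for i, node in enumerate(node_order)}
    sources, targets = [], []
    for src, dst in zip(edge_index[0], edge_index[1]):
        # Keep only edges with both endpoints inside the community
        if src in local_id and dst in local_id:
            sources.append(local_id[src])
            targets.append(local_id[dst])
    return node_order, [sources, targets]


edge_index = [[0, 1, 0, 2, 1, 2, 2, 3, 3, 4, 3, 5, 4, 5],
              [1, 0, 2, 0, 2, 1, 3, 2, 4, 3, 5, 3, 5, 4]]
for community in [{0, 1, 2}, {3, 4, 5}]:
    nodes, local_edges = community_subgraph(edge_index, community)
    print(nodes, local_edges)
```

On tensors, PyTorch Geometric's torch_geometric.utils.subgraph (with relabel_nodes=True) performs the same kind of extraction. Note that the edges 2-3 and 3-2 appear in neither subgraph, since they cross between the two communities.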
Note that every graph in your dataset must have the same number of communities, equal to the num_communities argument passed to CommunityNet.
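Because every graph must carry a correctly sized data.communities, it can help to validate a dataset before training. The hypothetical check_communities helper below is a sketch, not part of the library; the checks it performs (matching count, non-empty sets, in-range node indices) are assumptions drawn from the description above. It does not require communities to be disjoint or to cover every node, since the README does not state either requirement.

```python
def check_communities(communities, num_nodes, num_communities):
    """Raise ValueError if `communities` looks unusable as data.communities.

    Assumed checks: the expected number of communities, every entry is a
    non-empty set, and every index is valid for a graph with `num_nodes` nodes.
    """
    if len(communities) != num_communities:
        raise ValueError(
            f"expected {num_communities} communities, got {len(communities)}"
        )
    for i, community in enumerate(communities):
        if not isinstance(community, set) or not community:
            raise ValueError(f"community {i} must be a non-empty set")
        bad = sorted(n for n in community if not 0 <= n < num_nodes)
        if bad:
            raise ValueError(f"community {i} has out-of-range nodes {bad}")


# Passes silently for the two-community example graph with 6 nodes
check_communities([{0, 1, 2}, {3, 4, 5}], num_nodes=6, num_communities=2)
```

With a real Data object you would call it as check_communities(data.communities, data.num_nodes, num_communities) for each graph in the dataset.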
Reference
GraphNet
PyTorch module that implements a GNN. Uses NNConv (an edge-conditioned convolutional operator) as a filter and global pooling to convert a graph into a vector embedding.
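For reference, the PyTorch Geometric documentation gives the NNConv update for node i as follows, where N(i) is the set of neighbours of i and h_Θ is a neural network that maps each edge's feature vector to an F_in × F_out weight matrix. With NNConv's default "add" aggregation the neighbour terms are summed as written (a bias term is also added by default). Presumably the edge_nn_kwargs parameter below configures h_Θ, but that mapping is an assumption based on its description.

```latex
\mathbf{x}'_i = \mathbf{\Theta}\,\mathbf{x}_i
  + \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \cdot h_{\mathbf{\Theta}}\!\left(\mathbf{e}_{i,j}\right),
\qquad h_{\mathbf{\Theta}}\!\left(\mathbf{e}_{i,j}\right) \in \mathbb{R}^{F_{\text{in}} \times F_{\text{out}}}
```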
Parameters:
- in_channels (int): Number of node features
- out_channels (int): Number of output features
- num_edge_features (int): Number of edge features
- hidden_channels (list, optional (default=[])): List of hidden state sizes; length of list == number of layers
- use_pooling (bool, optional (default=False)): Whether or not to use top-k pooling
- dropout_prob (float, optional (default=0.0)): Dropout probability applied to each GNN layer
- global_pooling (str, optional (default="mean")): Global pooling mode; options are: "mean", "add", and "max"
- activation (torch.nn.Module, optional (default=None)): Activation function used for NNConv
- edge_nn_kwargs (dict, optional (default={})): Dictionary of parameters for the MLP used to process edge features in NNConv
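The three global_pooling modes differ only in how node embeddings are combined into one graph embedding. The plain-Python sketch below shows the per-feature reduction each mode performs; global_pool is a hypothetical helper for illustration, and GraphNet itself operates on tensors (PyTorch Geometric provides global_mean_pool, global_add_pool and global_max_pool for that).

```python
def global_pool(node_embeddings, mode="mean"):
    """Collapse a list of equal-length node embeddings into one graph embedding."""
    if not node_embeddings:
        raise ValueError("cannot pool an empty graph")
    columns = list(zip(*node_embeddings))  # one tuple per feature dimension
    if mode == "mean":
        return [sum(col) / len(col) for col in columns]
    if mode == "add":
        return [sum(col) for col in columns]
    if mode == "max":
        return [max(col) for col in columns]
    raise ValueError(f"unknown pooling mode: {mode!r}")


embeddings = [[1.0, 2.0], [3.0, -4.0]]  # two nodes, two features each
for mode in ("mean", "add", "max"):
    print(mode, global_pool(embeddings, mode))
```

"add" grows with the number of nodes, so it preserves graph size information, while "mean" and "max" are insensitive to it.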
MLP
PyTorch module that implements a multi-layer perceptron. This can be used in an output GNN to convert a graph embedding into a prediction (e.g. a classification or regression output).
Parameters:
- in_channels (int): Number of input features
- out_channels (int): Number of output features
- hidden_channels (list, optional (default=[])): List of hidden state sizes; length of list == number of layers
- h_activation (torch.nn.Module, optional (default=None)): Hidden activation function
- out_activation (torch.nn.Module, optional (default=None)): Output activation function
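Assuming hidden_channels lists the hidden layer sizes in order (so an empty list gives a single layer from in_channels straight to out_channels), the shape of each linear layer follows from chaining the sizes. The hypothetical helper mlp_layer_shapes below sketches that arithmetic; it is not part of the library, and it is likewise an assumption that h_activation follows each hidden layer and out_activation follows the last one.

```python
def mlp_layer_shapes(in_channels, out_channels, hidden_channels=()):
    """Return (in_features, out_features) for each linear layer in the chain."""
    sizes = [in_channels, *hidden_channels, out_channels]
    # Consecutive sizes form the input/output dimensions of each layer
    return list(zip(sizes[:-1], sizes[1:]))


print(mlp_layer_shapes(4, 1))           # [(4, 1)]
print(mlp_layer_shapes(4, 1, [16, 8]))  # [(4, 16), (16, 8), (8, 1)]
```

The same chaining would apply to GraphNet's hidden_channels if its layers are stacked the same way.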
CommunityNet
PyTorch module that implements a hierarchical GNN.
Parameters:
- base_gnn (torch.nn.Module): Base GNN used to process each community
- output_gnn (torch.nn.Module): Output GNN used to process the inter-community graph and produce a prediction
- num_communities (int): Number of communities in each input graph
- num_jobs (int, optional (default=1)): Number of jobs (CPU cores) to distribute the community embedding work across
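The num_jobs option fans the per-community embedding work out to several workers. The sketch below shows that pattern under assumptions: embed_communities and embed_fn are hypothetical names, and a thread pool stands in for whatever mechanism the library actually uses (for CPU-bound pure-Python work, concurrent.futures.ProcessPoolExecutor would be needed to use several cores). The key property is that results keep the order of the input communities, so embedding i can become node i of the inter-community graph.

```python
from concurrent.futures import ThreadPoolExecutor


def embed_communities(communities, embed_fn, num_jobs=1):
    """Apply embed_fn to every community, optionally in parallel.

    Results are returned in the same order as `communities`.
    """
    if num_jobs == 1:
        return [embed_fn(c) for c in communities]
    with ThreadPoolExecutor(max_workers=num_jobs) as pool:
        # Executor.map yields results in input order, not completion order
        return list(pool.map(embed_fn, communities))


# Stand-in "embedding": community size and sum of its node indices
def toy_embed(community):
    return [float(len(community)), float(sum(community))]


print(embed_communities([{0, 1, 2}, {3, 4, 5}], toy_embed, num_jobs=2))
```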