area-attention

PyTorch implementation of Area Attention


Keywords
artificial, intelligence, area, attention
License
MIT
Install
pip install area-attention==0.1.0

Documentation

Area Attention

PyTorch implementation of Area Attention [1]. This module lets you attend to areas of the memory, where each area contains a group of items that are either spatially or temporally adjacent. The original TensorFlow implementation can be found here.
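
To make the mechanism concrete, here is a minimal standalone sketch, independent of this package (the helper pool_areas and its explicit loops are illustrative only, not the library's actual implementation). It shows how area keys and values can be derived from a flattened 4x4 memory: every rectangular area up to 2x2 is enumerated, its key taken as the element-wise max over its items (the 'max' key mode) and its value as their mean (the 'mean' value mode):

import torch

def pool_areas(memory, height, width, max_area_height, max_area_width):
    """Enumerate all rectangular areas up to the given size and pool each one.

    memory: (batch, height * width, depth), a flattened 2-D memory grid.
    Returns max-pooled area keys and mean-pooled area values,
    each of shape (batch, num_areas, depth).
    """
    batch, _, depth = memory.shape
    grid = memory.view(batch, height, width, depth)
    keys, values = [], []
    for area_h in range(1, max_area_height + 1):
        for area_w in range(1, max_area_width + 1):
            for top in range(height - area_h + 1):
                for left in range(width - area_w + 1):
                    area = grid[:, top:top + area_h, left:left + area_w, :]
                    area = area.reshape(batch, -1, depth)
                    keys.append(area.max(dim=1).values)  # 'max' key mode
                    values.append(area.mean(dim=1))      # 'mean' value mode
    return torch.stack(keys, dim=1), torch.stack(values, dim=1)

memory = torch.rand(4, 16, 32)  # 4x4 grid flattened, depth 32
area_keys, area_values = pool_areas(memory, height=4, width=4,
                                    max_area_height=2, max_area_width=2)
area_keys.shape  # (4, 49, 32): 16 1x1 + 12 1x2 + 12 2x1 + 9 2x2 areas

Attention then runs over these 49 area keys and values instead of the 16 original items.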

Setup

$ pip install area_attention

Usage

Single-head Area Attention:

import torch

from area_attention import AreaAttention

area_attention = AreaAttention(
    key_query_size=32,
    area_key_mode='max',
    area_value_mode='mean',
    max_area_height=2,
    max_area_width=2,
    memory_height=4,
    memory_width=4,
    dropout_rate=0.2,
    top_k_areas=0
)
q = torch.rand(4, 8, 32)   # queries: (batch, num_queries, key_query_size)
k = torch.rand(4, 16, 32)  # keys: (batch, memory_height * memory_width, key_query_size)
v = torch.rand(4, 16, 64)  # values: (batch, memory_height * memory_width, value_size)
x = area_attention(q, k, v)
x  # torch.Tensor with shape (4, 8, 64)
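Note: setting top_k_areas=0 appears to keep every candidate area; judging from the parameter name, a positive value would presumably restrict attention to the k areas with the highest attention weights.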

Multi-head Area Attention:

import torch

from area_attention import AreaAttention, MultiHeadAreaAttention

area_attention = AreaAttention(
    key_query_size=32,
    area_key_mode='max',
    area_value_mode='mean',
    max_area_height=2,
    max_area_width=2,
    memory_height=4,
    memory_width=4,
    dropout_rate=0.2,
    top_k_areas=0
)
multi_head_area_attention = MultiHeadAreaAttention(
    area_attention=area_attention,
    num_heads=2,
    key_query_size=32,
    key_query_size_hidden=32,
    value_size=64,
    value_size_hidden=64
)
q = torch.rand(4, 8, 32)   # queries: (batch, num_queries, key_query_size)
k = torch.rand(4, 16, 32)  # keys: (batch, memory_height * memory_width, key_query_size)
v = torch.rand(4, 16, 64)  # values: (batch, memory_height * memory_width, value_size)
x = multi_head_area_attention(q, k, v)
x  # torch.Tensor with shape (4, 8, 64)
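
Both heads share the single area_attention module passed in. The key_query_size_hidden and value_size_hidden arguments presumably set the per-head projection sizes (here equal to the input sizes): queries, keys, and values are projected per head, area attention is applied to each head, and the head outputs are combined and projected back to value_size.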

Unit tests

$ python -m pytest tests

Bibliography

[1] Li, Yang, et al. "Area attention." International Conference on Machine Learning. PMLR, 2019.

Citations

@inproceedings{li2019area,
  title={Area attention},
  author={Li, Yang and Kaiser, Lukasz and Bengio, Samy and Si, Si},
  booktitle={International Conference on Machine Learning},
  pages={3846--3855},
  year={2019},
  organization={PMLR}
}