slot-attention

Implementation of Slot Attention in PyTorch


Keywords
attention, artificial, intelligence, artificial-intelligence, attention-mechanism, clustering, deep-learning
License
MIT
Install
pip install slot-attention==0.0.6

Documentation

Slot Attention

Implementation of Slot Attention from Google AI in PyTorch.

Experiments pending.
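The core idea, as described in the paper, is an iterative attention routine in which slots compete for input features: attention is normalized with a softmax over the slot axis, and each slot is then updated with a GRU from the weighted mean of the inputs it claimed. A minimal, self-contained sketch of that update (hyperparameter names and layer sizes here are illustrative, not the library's exact internals):

```python
import torch
import torch.nn as nn

class SlotAttentionSketch(nn.Module):
    def __init__(self, num_slots, dim, iters=3, eps=1e-8):
        super().__init__()
        self.num_slots, self.iters, self.eps = num_slots, iters, eps
        self.scale = dim ** -0.5
        # slots are initialized by sampling from a learned Gaussian
        self.slots_mu = nn.Parameter(torch.randn(1, 1, dim))
        self.slots_logsigma = nn.Parameter(torch.zeros(1, 1, dim))
        self.to_q = nn.Linear(dim, dim)
        self.to_k = nn.Linear(dim, dim)
        self.to_v = nn.Linear(dim, dim)
        self.gru = nn.GRUCell(dim, dim)
        self.mlp = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim))
        self.norm_input = nn.LayerNorm(dim)
        self.norm_slots = nn.LayerNorm(dim)
        self.norm_mlp = nn.LayerNorm(dim)

    def forward(self, inputs, num_slots=None):
        b, n, d = inputs.shape
        n_s = num_slots if num_slots is not None else self.num_slots
        mu = self.slots_mu.expand(b, n_s, -1)
        sigma = self.slots_logsigma.exp().expand(b, n_s, -1)
        slots = mu + sigma * torch.randn_like(mu)

        inputs = self.norm_input(inputs)
        k, v = self.to_k(inputs), self.to_v(inputs)

        for _ in range(self.iters):
            slots_prev = slots
            q = self.to_q(self.norm_slots(slots))
            # softmax over the SLOT axis: slots compete to explain each input
            attn = torch.einsum('bnd,bsd->bns', k, q).mul(self.scale).softmax(dim=-1)
            attn = attn + self.eps
            # renormalize per slot over inputs -> weighted mean of values
            attn = attn / attn.sum(dim=1, keepdim=True)
            updates = torch.einsum('bns,bnd->bsd', attn, v)
            slots = self.gru(
                updates.reshape(-1, d), slots_prev.reshape(-1, d)
            ).reshape(b, n_s, d)
            slots = slots + self.mlp(self.norm_mlp(slots))
        return slots
```

Because the number of slots only enters through the sampled initialization, the same trained weights can be run with a different slot count, which is what the `num_slots` override below relies on.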

Install

$ pip install slot-attention

Usage

import torch
from slot_attention import SlotAttention

slot_attn = SlotAttention(
    num_slots = 5,
    dim = 512,
    iters = 3   # iterations of attention, defaults to 3
)

inputs = torch.randn(2, 1024, 512)
slot_attn(inputs) # (2, 5, 512)

After training, the network is reported to generalize to a slightly different number of slots (clusters). You can override the number of slots by passing the num_slots keyword to forward.

slot_attn(inputs, num_slots = 8) # (2, 8, 512)
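In practice, the inputs of shape (batch, num_inputs, dim) are typically a flattened CNN feature map (the paper also adds a positional encoding). A minimal sketch of producing such a set of vectors, where the encoder and image sizes are hypothetical and chosen only to match the (2, 1024, 512) shape used above:

```python
import torch
import torch.nn as nn

# Hypothetical single-layer encoder, for illustration only.
encoder = nn.Conv2d(3, 512, kernel_size=5, padding=2)

images = torch.randn(2, 3, 32, 32)
feats = encoder(images)                    # (2, 512, 32, 32)
inputs = feats.flatten(2).transpose(1, 2)  # (2, 1024, 512): a set of 32*32 vectors
# `inputs` now has the (batch, num_inputs, dim) shape expected by SlotAttention,
# so it can be passed directly, e.g. slot_attn(inputs).
```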

Citation

@misc{locatello2020objectcentric,
    title = {Object-Centric Learning with Slot Attention},
    author = {Francesco Locatello and Dirk Weissenborn and Thomas Unterthiner and Aravindh Mahendran and Georg Heigold and Jakob Uszkoreit and Alexey Dosovitskiy and Thomas Kipf},
    year = {2020},
    eprint = {2006.15055},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}