transformers-visualizer

Explain your 🤗 transformers without effort! Display the internal behavior of your model.


Keywords
machine, learning, natural, language, processing, nlp, explainability, transformers, model, interpretability, ai, explainable-ai, huggingface, huggingface-transformers, transformer
License
Apache-2.0
Install
pip install transformers-visualizer==0.2.2

Documentation

Transformers visualizer

Explain your 🤗 transformers without effort!


Transformers visualizer is a Python package designed to work with the 🤗 transformers package. Given a model and a tokenizer, it offers multiple ways to explain your model by plotting its internal behavior.

This package is mostly based on the Captum tutorials [1] [2].

Installation

pip install transformers-visualizer

Quickstart

Let's define a model, a tokenizer and a text input for the following examples.

from transformers import AutoModel, AutoTokenizer

model_name = "bert-base-uncased"
model = AutoModel.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
text = "The dominant sequence transduction models are based on complex recurrent or convolutional neural networks that include an encoder and a decoder."

Visualizers

Attention matrices of a specific layer

from transformers_visualizer import TokenToTokenAttentions

visualizer = TokenToTokenAttentions(model, tokenizer)
visualizer(text)

Instead of calling the visualizer directly via __call__, you can use its compute method. Both work in place, but compute returns the visualizer, which enables method chaining.
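
For example, assuming the visualizer defined above, both of the following compute the attention matrices in place; the second form chains plot (described below) thanks to compute returning the visualizer:

# Call the visualizer directly (in place)
visualizer(text)

# Or use compute, which returns the visualizer and can be chained
visualizer.compute(text).plot()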

The plot method accepts a layer index as a parameter to specify which layer of your model you want to plot. By default, the last layer is plotted.

import matplotlib.pyplot as plt

visualizer.plot(layer_index=6)
plt.savefig("token_to_token.jpg")

(Figure: token-to-token attention matrices of layer 6)

Attention matrices normalized across head axis

You can specify the order used by torch.linalg.norm in the __call__ and compute methods. By default, an L2 norm is applied.

from transformers_visualizer import TokenToTokenNormalizedAttentions

visualizer = TokenToTokenNormalizedAttentions(model, tokenizer)
visualizer.compute(text).plot()

(Figure: token-to-token attention matrices normalized across the head axis)
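
If you want a different norm, the sketch below passes an explicit order when computing. Note that the parameter name order is an assumption based on the description above and may differ from the actual signature:

from transformers_visualizer import TokenToTokenNormalizedAttentions

visualizer = TokenToTokenNormalizedAttentions(model, tokenizer)
# "order" is assumed to mirror the ord argument of torch.linalg.norm (here an L1 norm)
visualizer.compute(text, order=1).plot()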

Plotting

The plot method can also skip special tokens via the skip_special_tokens parameter, which defaults to False.
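
For example, with the visualizer defined above, a minimal sketch hiding special tokens such as [CLS] and [SEP]:

visualizer.plot(skip_special_tokens=True)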

You can also import the plotting functions directly:

from transformers_visualizer.plotting import plot_token_to_token, plot_token_to_token_specific_dimension

These functions, as well as a visualizer's plot method, accept the following parameters (see the sketch after this list).

  • figsize (Tuple[int, int]): Figure size of the plot. Defaults to (20, 20).
  • ticks_fontsize (int): Ticks fontsize. Defaults to 7.
  • title_fontsize (int): Title fontsize. Defaults to 9.
  • cmap (str): Colormap. Defaults to "viridis".
  • colorbar (bool): Display colorbars. Defaults to True.
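
For instance, a minimal sketch combining several of these parameters with a visualizer's plot method:

import matplotlib.pyplot as plt

# Smaller figure, a different colormap, and no colorbars
visualizer.plot(figsize=(10, 10), cmap="plasma", colorbar=False)
plt.savefig("custom_plot.jpg")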

Upcoming features

  • Add an option to mask special tokens.
  • Add an option to specify head/layer indices to plot.
  • Add other plotting backends such as Plotly, Bokeh, Altair.
  • Implement other visualizers such as vector norm.

References

  • [1] Captum's BERT tutorial (part 1)
  • [2] Captum's BERT tutorial (part 2)

Acknowledgements