transformer-viewer

Simple text visualization for transformer models


License
MIT
Install
pip install transformer-viewer==0.0.2

Documentation

Transformer Viewer

Simple visualization for PyTorch models. This is a test version for classification tasks.
Implementation of the paper: Axiomatic Attribution for Deep Networks (integrated gradients)
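
The paper referenced above attributes a model's output to its input features by averaging gradients along a straight path from a baseline to the input. The following is a minimal pure-Python sketch of that idea on a toy logistic model, not the code used by this package. The toy model, its weights `w`, the all-zero baseline, and the midpoint approximation are all assumptions made for illustration.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def model(x, w):
    # Toy differentiable model: logistic regression without a bias term.
    return sigmoid(sum(wi * xi for wi, xi in zip(w, x)))


def model_grad(x, w):
    # Analytic gradient of the toy model with respect to the input x.
    p = model(x, w)
    return [p * (1.0 - p) * wi for wi in w]


def integrated_gradients(x, baseline, w, steps=20):
    # Approximate the path integral with a midpoint Riemann sum.
    total = [0.0] * len(x)
    for k in range(steps):
        alpha = (k + 0.5) / steps
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        grad = model_grad(point, w)
        total = [t + g for t, g in zip(total, grad)]
    # Scale the averaged gradient by the distance from the baseline.
    return [(xi - b) * t / steps for xi, b, t in zip(x, baseline, total)]


w = [2.0, -1.0, 0.5]            # hypothetical weights
x = [1.0, 2.0, 3.0]             # hypothetical input
baseline = [0.0, 0.0, 0.0]      # assumed all-zero baseline

attributions = integrated_gradients(x, baseline, w)

# Completeness check: attributions should sum to f(x) - f(baseline).
completeness_gap = abs(sum(attributions) - (model(x, w) - model(baseline, w)))
print(attributions, completeness_gap)
```

A positive attribution means the feature pushed the output up relative to the baseline, and a negative one means it pushed the output down. In a text model the same computation would typically be applied to token embeddings rather than raw inputs.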

Requirements:

  • python > 3.6
  • pytorch > 1.4
  • Colr > 0.9

Installation:

pip install transformer_viewer

How to use:

import torch
import json

from transformer_viewer import Glimpse


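# Load the id-to-label and id-to-word dictionaries (JSON keys are strings)
with open(PATH_ID2LABEL, encoding="utf-8") as infile:
    id2label = json.load(infile)
with open(PATH_ID2WORD, encoding="utf-8") as infile:
    id2word = json.load(infile)

# Invert id2word so that each word maps to its integer id
word2id = dict()
for k, v in id2word.items():
    word2id[v] = int(k)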

# Load model (MyModel stands for your own classifier class)
model = MyModel(*args, **kwargs)
model.load_state_dict(torch.load(PATH_MODEL))

tokenizer = lambda text: [int(word2id[item]) for item in text.split(" ")]
special_tokens = [1, 2] # 1 for <eos>, 2 for <pad>

viewer = Glimpse(model, "embeddings", id2word, id2label, tokenizer, special_tokens, loss_pos=0)

viewer.color_bar()

[Figure: color bar]

# Headline, roughly: "Urumqi adds a new city-center tourist destination"
viewer.view("乌鲁木齐市 新增 一处 城市 中心 旅游 目的地", "travel")

[Figure: true example]

# Celebrity-gossip headline about the divers Guo Jingjing and Wu Minxia
viewer.view("郭晶晶 曾 撮合 吴敏霞 与 章子怡 前男友 , 拒绝 豪门 平淡 才 是 真", "sports")

[Figure: wrong example]

Parameters:

Glimpse(model, embed_name, id2word, id2label, tokenizer, special_tokens, loss_pos=None, step=20)

parameter       type      description                                   example
model           object    PyTorch model
embed_name      str       name of the embedding layer                   'embeddings'
id2word         dict      mapping from id to token                      {1: '你好', 2: '再见'}
id2label        dict      mapping from id to label                      {1: 'sports', 2: 'travel'}
tokenizer       function  converts a text into a list of token ids      split
special_tokens  list      ids of the special tokens                     [1, 2]
loss_pos        int       position of the loss in the model's output    0
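
As an illustration of the `tokenizer`, `id2word`, and `special_tokens` arguments described above, here is a small self-contained sketch of a whitespace tokenizer built from an id-to-word mapping. The toy vocabulary, the `<unk>` entry, and the fallback to `UNK_ID` for unknown words are assumptions for this example only; the package itself only requires a callable that turns a text into a list of ids.

```python
import json

# Toy vocabulary shaped like a JSON id2word file, where keys are strings.
id2word = json.loads(
    '{"0": "<unk>", "1": "<eos>", "2": "<pad>", "3": "我", "4": "爱", "5": "中国"}'
)

# Invert the mapping: word -> integer id.
word2id = {word: int(idx) for idx, word in id2word.items()}

# Hypothetical fallback id for words missing from the vocabulary.
UNK_ID = word2id["<unk>"]


def tokenizer(text):
    # Split on whitespace and look up each token's id.
    return [word2id.get(token, UNK_ID) for token in text.split()]


# Ids that mark special tokens such as <eos> and <pad>.
special_tokens = [word2id["<eos>"], word2id["<pad>"]]

token_ids = tokenizer("我 爱 中国")
print(token_ids, special_tokens)
```

Using a named function instead of a lambda makes it easier to add handling for unknown words or extra preprocessing later.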

view(text, label)

parameter  type        description              example
text       str         input text               '我 爱 中国'
label      int or str  label of the input text  'car'