XLNet implemented in Keras


Keywords
glue, keras, language-model, nlp, xlnet
License
MIT
Install
pip install keras-xlnet==0.16.0

Documentation

Keras XLNet



Unofficial implementation of XLNet. The demos "Embedding extraction" and "Embedding extraction with memory" show how to get the outputs of the last transformer layer from pre-trained checkpoints.

Install

pip install keras-xlnet

Usage

Fine-tuning on GLUE

Each task name links to a demo that fine-tunes the base model:

Task Name   Metrics                            Approximate Results on Dev Set
CoLA        Matthews Corr.                     52
SST-2       Accuracy                           93
MRPC        Accuracy / F1                      86 / 89
STS-B       Pearson Corr. / Spearman Corr.     86 / 87
QQP         Accuracy / F1                      90 / 86
MNLI        Accuracy (matched / mismatched)    84 / 84
QNLI        Accuracy                           86
RTE         Accuracy                           64
WNLI        Accuracy                           56

(Only 0s are predicted on the WNLI dataset, so the WNLI score only reflects the share of label-0 examples in the dev set.)

Load Pretrained Checkpoints

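import os
from keras_xlnet import Tokenizer, load_trained_model_from_checkpoint, ATTENTION_TYPE_BI

# Directory of the downloaded pre-trained checkpoint (path elided)
checkpoint_path = '.../xlnet_cased_L-24_H-1024_A-16'

# SentencePiece tokenizer shipped with the checkpoint
tokenizer = Tokenizer(os.path.join(checkpoint_path, 'spiece.model'))
model = load_trained_model_from_checkpoint(
    config_path=os.path.join(checkpoint_path, 'xlnet_config.json'),
    checkpoint_path=os.path.join(checkpoint_path, 'xlnet_model.ckpt'),
    batch_size=16,         # maximum batch size
    memory_len=512,        # maximum number of tokens kept in memory
    target_len=128,        # maximum number of tokens per input segment
    in_train_phase=False,  # False: model for fine-tuning; True: model for language-model training
    attention_type=ATTENTION_TYPE_BI,
)
model.summary()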

The arguments batch_size, memory_len, and target_len are the maximum sizes used to initialize the memories. If in_train_phase is True, the returned model is the one used for training a language model; otherwise, a model for fine-tuning is returned.
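Because these three arguments fix the size of the memories, it can help to check a batch before feeding it. The sketch below is a hypothetical helper (not part of keras_xlnet) that treats each size as an upper bound, as the paragraph above describes, and uses the same defaults as the loading example.

```python
import numpy as np


def check_batch(token_ids, segment_ids, memory_lengths,
                batch_size=16, memory_len=512, target_len=128):
    """Raise ValueError if a batch exceeds the sizes the model was built with.

    Hypothetical helper; each size is treated as an upper bound and the
    defaults mirror the load_trained_model_from_checkpoint call above.
    """
    token_ids = np.asarray(token_ids)
    segment_ids = np.asarray(segment_ids)
    memory_lengths = np.asarray(memory_lengths)
    if token_ids.ndim != 2 or token_ids.shape != segment_ids.shape:
        raise ValueError("token and segment IDs must be 2-D with the same shape")
    rows, cols = token_ids.shape
    if rows > batch_size:
        raise ValueError(f"{rows} rows exceed batch_size={batch_size}")
    if cols > target_len:
        raise ValueError(f"{cols} tokens exceed target_len={target_len}")
    if memory_lengths.shape != (rows, 1):
        raise ValueError("memory lengths must have shape (rows, 1)")
    if memory_lengths.min() < 0 or memory_lengths.max() > memory_len:
        raise ValueError(f"memory lengths must lie in [0, {memory_len}]")
```

For example, check_batch(np.zeros((2, 200)), np.zeros((2, 200)), np.zeros((2, 1))) raises, because 200 tokens exceed target_len=128.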

About I/O

Memories carry hidden states over from one batch to the next, so batches must arrive in order: if memories are used, set shuffle=False whether you call fit or fit_generator.
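To see why order matters, the sketch below splits one long token sequence into consecutive chunks of target_len tokens that must be fed one after another. It is a hypothetical helper with a batch size of 1. It assumes the memory-length input counts the previously fed tokens held in memory, capped at memory_len, and that pad_id=0 is only a placeholder for the tokenizer's real padding ID.

```python
import numpy as np


def iter_memory_batches(token_ids, target_len=128, memory_len=512, pad_id=0):
    """Yield (tokens, segments, memory_length) for one long sequence, in order.

    Hypothetical helper, batch size 1. Each chunk must be fed right after the
    previous one so that the memories hold the preceding tokens.
    """
    token_ids = np.asarray(token_ids)
    seen = 0  # tokens fed so far; assumed to be what the memories hold
    for start in range(0, len(token_ids), target_len):
        chunk = token_ids[start:start + target_len]
        # Pad the final, shorter chunk up to target_len (pad_id is a placeholder)
        chunk = np.pad(chunk, (0, target_len - len(chunk)),
                       mode="constant", constant_values=pad_id)
        tokens = chunk[np.newaxis, :]                        # (1, target_len)
        segments = np.zeros_like(tokens)                     # single segment
        memory_length = np.array([[min(seen, memory_len)]])  # (1, 1)
        yield tokens, segments, memory_length
        seen += target_len
```

Each yielded triple would then be passed, in order, to something like model.predict_on_batch([tokens, segments, memory_length]). Shuffling the chunks would pair each one with memories from an unrelated position.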

in_train_phase is False

3 inputs:

  • IDs of tokens, with shape (batch_size, target_len).
  • IDs of segments, with shape (batch_size, target_len).
  • Length of memories, with shape (batch_size, 1).

1 output:

  • The feature for each token, with shape (batch_size, target_len, units).
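The three inputs above can be packed with NumPy roughly as follows. This is a sketch under stated assumptions: build_inputs is a hypothetical helper, the token-ID lists are assumed to come from the package's Tokenizer, sequences are truncated and padded on the right (check the package's demos for the padding side it expects), and pad_id=0 is a placeholder for the tokenizer's padding ID.

```python
import numpy as np


def build_inputs(encoded, target_len=128, pad_id=0):
    """Pack a list of token-ID lists into the three fine-tuning inputs.

    Hypothetical helper: right-side truncation/padding and pad_id=0 are
    assumptions, not the package's documented behavior.
    """
    batch_size = len(encoded)
    tokens = np.full((batch_size, target_len), pad_id, dtype="int32")
    for i, ids in enumerate(encoded):
        ids = list(ids)[:target_len]   # truncate to target_len
        tokens[i, :len(ids)] = ids     # remaining positions stay pad_id
    segments = np.zeros((batch_size, target_len), dtype="int32")  # one segment
    memory_lengths = np.zeros((batch_size, 1), dtype="int32")     # no memory yet
    return [tokens, segments, memory_lengths]
```

With the fine-tuning model loaded earlier, model.predict(build_inputs(batch)) would then return features of shape (batch_size, target_len, units).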

in_train_phase is True

4 inputs:

  • IDs of tokens, with shape (batch_size, target_len).
  • IDs of segments, with shape (batch_size, target_len).
  • Length of memories, with shape (batch_size, 1).
  • Masks of tokens, with shape (batch_size, target_len).

1 output:

  • The probability of each token in each position, with shape (batch_size, target_len, num_token).
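Since the training-phase output holds a probability for every token at every position, the most likely token is an argmax over the last axis. The sketch below reads predictions only at masked positions. The helper name is hypothetical, and treating a nonzero mask value as "masked" is an assumption about the mask input's convention.

```python
import numpy as np


def predicted_tokens(probs, masks):
    """Return (position, token_id) pairs for the masked positions of each row.

    probs: model output, shape (batch_size, target_len, num_token).
    masks: shape (batch_size, target_len); nonzero is assumed to mean masked.
    """
    best = np.asarray(probs).argmax(axis=-1)  # (batch_size, target_len)
    out = []
    for row_best, row_mask in zip(best, np.asarray(masks)):
        positions = np.flatnonzero(row_mask)  # indices of masked positions
        out.append([(int(p), int(row_best[p])) for p in positions])
    return out
```

For a single row whose per-position argmax is [2, 3, 1] and whose mask is [0, 1, 1], the helper returns [[(1, 3), (2, 1)]].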