critical-path

Tools for adapting universal language models to specific tasks


Keywords
BERT, google, transformer, squad, nlp
License
Apache-2.0
Install
pip install critical-path==0.0.9

Documentation

critical-path-nlp


Please note: these are tools for rapid prototyping, not brute-force hyperparameter tuning.

Adapted from: Google's BERT

Installation

  1. pip install critical_path
  2. Download a pretrained BERT model - start with BERT-Base Uncased if you're not sure where to begin
  3. Unzip the model and make note of the path

Examples

Current Capabilities

BERT for Question Answering

BERT for Multi-Label Classification

BERT for Single-Label Classification

Future Capabilities

GPT-2 Training and Generation

Usage + Core Components

Configuring BERT

First, define the model paths

base_model_folder_path = "../models/uncased_L-12_H-768_A-12/"  # Folder containing downloaded Base Model
name_of_config_json_file = "bert_config.json"  # Located inside the Base Model folder
name_of_vocab_file = "vocab.txt"  # Located inside the Base Model folder

output_directory = "../models/trained_BERT/" # Trained model and results landing folder

# Multi-Label and Single-Label Specific
data_dir = None  # Directory containing .tsv data - typically for CoLA/MRPC or other datasets with a known structure
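
The config and vocab paths are later built by plain string concatenation, which only works because base_model_folder_path ends with a slash. A minimal sketch of a more forgiving alternative follows. It assumes a hypothetical helper named resolve_model_paths (not part of critical_path) that joins the paths with os.path.join and fails early if a file is missing:

```python
import os


def resolve_model_paths(base_folder, config_name="bert_config.json",
                        vocab_name="vocab.txt"):
    """Hypothetical helper: build and verify the config and vocab paths."""
    paths = {
        "bert_config_file": os.path.join(base_folder, config_name),
        "bert_vocab_file": os.path.join(base_folder, vocab_name),
    }
    # Fail before any model code runs if the download is incomplete
    for name, path in paths.items():
        if not os.path.isfile(path):
            raise FileNotFoundError(f"{name} not found at {path}")
    return paths
```

The dictionary keys mirror the keyword names accepted by Flags.set_model_paths, so the result can be unpacked into that call together with bert_output_dir and data_dir.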

Second, define the model run parameters

"""Settable parameters and their default values

Note: Most default values are perfectly fine
"""

# Administrative
init_checkpoint = None
save_checkpoints_steps = 1000
iterations_per_loop = 1000
do_lower_case = True   

# Technical
batch_size_train = 32
batch_size_eval = 8
batch_size_predict = 8
num_train_epochs = 3.0
max_seq_length = 128
warmup_proportion = 0.1
learning_rate = 3e-5

# SQuAD Specific
doc_stride = 128
max_query_length = 64
n_best_size = 20
max_answer_length = 30
is_squad_v2 = False  # SQuAD 2.0 has examples with no answer, aka "impossible", SQuAD 1.0 does not
verbose_logging = False
null_score_diff_threshold = 0.0
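
To make the Technical parameters concrete: in Google's original BERT scripts, the number of optimizer steps is derived from the dataset size, the training batch size, and the epoch count, and warmup_proportion is the fraction of those steps spent ramping the learning rate up. The sketch below reproduces that arithmetic, assuming critical_path keeps the same convention (this section does not confirm it). The function name training_schedule and the example count of 10,000 are illustrative only.

```python
def training_schedule(num_examples, batch_size_train=32,
                      num_train_epochs=3.0, warmup_proportion=0.1):
    """Return (total_steps, warmup_steps) following Google's BERT convention."""
    num_train_steps = int(num_examples / batch_size_train * num_train_epochs)
    num_warmup_steps = int(num_train_steps * warmup_proportion)
    return num_train_steps, num_warmup_steps


# Example: 10,000 training examples with the default values listed above
total_steps, warmup_steps = training_schedule(10_000)
print(total_steps, warmup_steps)  # 937 93
```

In this example the whole run is shorter than the default save_checkpoints_steps of 1000, so a smaller value may be worth setting on small datasets if intermediate checkpoints are wanted.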

Initialize the configuration handler

from critical_path.BERT.configs import ConfigClassifier

Flags = ConfigClassifier()
Flags.set_model_paths(
    bert_config_file=base_model_folder_path + name_of_config_json_file,
    bert_vocab_file=base_model_folder_path + name_of_vocab_file,
    bert_output_dir=output_directory,
    data_dir=data_dir)

Flags.set_model_params(
    batch_size_train=8, 
    max_seq_length=256,
    num_train_epochs=3)

# Retrieve a handle for the configs
FLAGS = Flags.get_handle()
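
BERT's position embeddings have a fixed size, so max_seq_length cannot exceed the max_position_embeddings value stored in bert_config.json (512 for the released BERT-Base models). A small sketch of checking this before calling set_model_params follows. Note that check_max_seq_length is a hypothetical helper, not a critical_path function.

```python
import json


def check_max_seq_length(bert_config_file, max_seq_length):
    """Hypothetical pre-flight check against the model's position limit."""
    with open(bert_config_file, "r", encoding="utf-8") as f:
        config = json.load(f)
    limit = config["max_position_embeddings"]
    if max_seq_length > limit:
        raise ValueError(
            f"max_seq_length={max_seq_length} exceeds the model limit of {limit}")
    return max_seq_length
```

Running such a check up front surfaces a bad setting immediately instead of after the model graph has started building.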

A single GTX 1070 running BERT-Base Uncased can handle:

Model              max_seq_length  batch_size
BERT-Base Uncased  256             6
...                384             4

For full batch size and sequence length guidelines see Google's recommendations
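
The two measurements above can be kept in code as a simple lookup. This is only a sketch: the figures cover a single GTX 1070 with BERT-Base Uncased, the helper name suggest_batch_size is made up for illustration, and sequence lengths beyond 384 are left undecided rather than extrapolated.

```python
# (max_seq_length, batch_size) pairs taken from the table above
GTX_1070_BERT_BASE = [(256, 6), (384, 4)]


def suggest_batch_size(max_seq_length, table=GTX_1070_BERT_BASE):
    """Return the batch size of the first table row that covers max_seq_length."""
    for seq_len, batch_size in sorted(table):
        if max_seq_length <= seq_len:
            return batch_size
    return None  # longer than anything measured: no recommendation
```

Shorter sequences fall back to the 256-token row, which is a conservative choice since shorter inputs need less memory.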

Using Configured Model

First, create a new model with the configured parameters

"""For Multi-Label Classification"""
from critical_path.BERT.model_multilabel_class import MultiLabelClassifier

model = MultiLabelClassifier(FLAGS)

Second, load your data source

  • SQuAD has dedicated dataloaders
  • Multi-Label Classification has a generic dataloader
    • DataProcessor in /BERT/model_multilabel_class
      • Note: This requires data labels to be in string format
      • labels = [
            ["label_1", "label_2", "label_3"],
            ["label_2"]
        ]
  • Single-Label Classification dataloaders
"""For Multi-Label Classification with a custom .csv reading function"""
from critical_path.BERT.model_multilabel_class import DataProcessor

# The reader is dataset-specific (here read_toxic_data) - see /bert_multilabel_example.py
input_ids, input_text, input_labels, label_list = read_toxic_data(randomize=True)

processor = DataProcessor(label_list=label_list)
train_examples = processor.get_samples(
    input_ids=input_ids,
    input_text=input_text,
    input_labels=input_labels,
    set_type='train')
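
The body of read_toxic_data is not shown in this section. As a rough idea of what such a dataset-specific reader might look like, the sketch below parses a CSV with an id column, a text column, and one 0/1 column per label, and returns the four values used above with labels as lists of strings. The function name read_multilabel_csv, the column names, and the file layout are assumptions for illustration.

```python
import csv
import random


def read_multilabel_csv(file_obj, id_column="id", text_column="comment_text",
                        randomize=False, seed=0):
    """Hypothetical reader: every column other than id/text is a 0/1 label."""
    reader = csv.DictReader(file_obj)
    label_list = [c for c in reader.fieldnames
                  if c not in (id_column, text_column)]
    rows = list(reader)
    if randomize:
        random.Random(seed).shuffle(rows)

    input_ids = [row[id_column] for row in rows]
    input_text = [row[text_column] for row in rows]
    # Keep the names of the labels set to 1, as strings
    input_labels = [[label for label in label_list if row[label] == "1"]
                    for row in rows]
    return input_ids, input_text, input_labels, label_list
```

Deriving label_list from the header keeps the label order stable across runs, which matters if predictions are later matched back to labels by position.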

Third, run your task

"""Train and predict a Multi-Label Classifier"""

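# Toggle the stages to run
do_train = True
do_predict = True

if do_train:
    model.train(train_examples, label_list)

if do_predict:
    # predict_examples is assumed to be prepared the same way as train_examples
    model.predict(predict_examples, label_list)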
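
The return format of model.predict is not documented in this section. Assuming it yields one probability per label for each example (an assumption, not a confirmed behaviour of critical_path), a sketch of mapping those scores back to the string labels could look like this. The name decode_predictions and the 0.5 threshold are illustrative choices.

```python
def decode_predictions(probabilities, label_list, threshold=0.5):
    """Hypothetical helper: keep every label whose score reaches the threshold."""
    decoded = []
    for scores in probabilities:
        if len(scores) != len(label_list):
            raise ValueError("each score row needs one value per label")
        decoded.append([label for label, score in zip(label_list, scores)
                        if score >= threshold])
    return decoded


# Example with made-up scores for two inputs
labels = decode_predictions([[0.9, 0.2, 0.7], [0.1, 0.4, 0.3]],
                            ["label_1", "label_2", "label_3"])
print(labels)  # [['label_1', 'label_3'], []]
```

An empty list is a valid multi-label result: it means no label passed the threshold for that input.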

For full examples please see: