ws-benchmark

a weak supervision learning benchmark


Keywords
machine-learning, ai, weak-supervision, benchmark-framework, data-centric-ai, data-programming, dataset, deep-learning, nlp, robust-learning, sequence-labeling, weakly-supervised-learning
License
Apache-2.0
Install
pip install ws-benchmark==1.1.2rc0


🔧 New

3/18/22

  1. Wrench is now available as ws-benchmark: run pip install ws-benchmark for a quick install.

2/13/22

  1. Added a script to generate LFs for any tabular dataset, as well as 5 new tabular datasets, namely mushroom, spambase, PhishingWebsites, Bioresponse, and bank-marketing.

11/04/21

  1. (beta) Added parallel_fit for torch models to support PyTorch DistributedDataParallel (example)

10/15/21

  1. A batch of new methods: WeaSEL, ImplyLoss, ASTRA, MeanTeacher, Meta-Weight-Net, Learning-to-Reweight
  2. Added support for image classification (dataset class / torchvision backbone) as well as the DomainNet and Animals-with-Attributes2 datasets (check out the datasets folder)

🔧 What is it?

Wrench is a benchmark platform containing diverse weak supervision tasks. It also provides a common, easy-to-use framework for developing and evaluating your own weak supervision models within the benchmark.

For more information, check out our publications:

If you find this repository helpful, feel free to cite our publication:

@inproceedings{
zhang2021wrench,
title={{WRENCH}: A Comprehensive Benchmark for Weak Supervision},
author={Jieyu Zhang and Yue Yu and Yinghao Li and Yujing Wang and Yaming Yang and Mao Yang and Alexander Ratner},
booktitle={Thirty-fifth Conference on Neural Information Processing Systems Datasets and Benchmarks Track},
year={2021},
url={https://openreview.net/forum?id=Q9SKS5k8io}
}

🔧 What is weak supervision?

Weak Supervision is a paradigm for automated training data creation without manual annotations.
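
For example, a typical weak supervision source is a labeling function: a small heuristic that votes for a label or abstains. A minimal, generic sketch in plain Python (not tied to Wrench's API; the keyword rule is purely illustrative):

# Toy labeling function for YouTube spam detection: vote SPAM if a keyword
# appears, otherwise abstain. Many such noisy votes are aggregated by a
# label model instead of relying on any single rule.
ABSTAIN, HAM, SPAM = -1, 0, 1

def lf_contains_check_out(text: str) -> int:
    return SPAM if "check out" in text.lower() else ABSTAIN

print(lf_contains_check_out("Check out my channel!"))  # 1 (SPAM)
print(lf_contains_check_out("Great video, thanks."))   # -1 (abstain)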

For a brief overview, please check out this blog.

For more context, please check out this survey.

To track recent advances in weak supervision, please follow this repo.

🔧 Installation

[1] Install Anaconda (instructions here: https://www.anaconda.com/download/)

[2] Clone the repository:

git clone https://github.com/JieyuZ2/wrench.git
cd wrench

[3] Create virtual environment:

conda env create -f environment.yml
source activate wrench

If this does not work, or if you want to use only a subset of Wrench's modules, check out this wiki page

🔧 Available Datasets

The datasets can be downloaded via this link.

Note that some datasets may have more training examples than reported in the README/paper because we include the dev set, whose indices can be found in labeled_id.json if it exists.
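
If you need to separate the dev examples, here is a minimal sketch of one way to do it. It assumes labeled_id.json is a JSON list of example ids and train.json maps ids to examples; these layout details are assumptions, so see the dataset documentation below for the actual format:

import json
import os

dataset_path = '../datasets/youtube'  # hypothetical local path

with open(os.path.join(dataset_path, 'labeled_id.json')) as f:
    dev_ids = {str(i) for i in json.load(f)}  # assumed: a list of example ids

with open(os.path.join(dataset_path, 'train.json')) as f:
    train = json.load(f)  # assumed: a dict mapping example id -> example

dev_examples = {i: ex for i, ex in train.items() if i in dev_ids}
train_examples = {i: ex for i, ex in train.items() if i not in dev_ids}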

Documentation of the dataset format and usage can be found in this wiki page.

classification:

| Name | Task | # class | # LF | # train | # validation | # test | data source | LF source |
|------|------|---------|------|---------|--------------|--------|-------------|-----------|
| Census | income classification | 2 | 83 | 10083 | 5561 | 16281 | link | link |
| Youtube | spam classification | 2 | 10 | 1586 | 120 | 250 | link | link |
| SMS | spam classification | 2 | 73 | 4571 | 500 | 500 | link | link |
| IMDB | sentiment classification | 2 | 8 | 20000 | 2500 | 2500 | link | link |
| Yelp | sentiment classification | 2 | 8 | 30400 | 3800 | 3800 | link | link |
| AGNews | topic classification | 4 | 9 | 96000 | 12000 | 12000 | link | link |
| TREC | question classification | 6 | 68 | 4965 | 500 | 500 | link | link |
| Spouse | relation classification | 2 | 9 | 22254 | 2801 | 2701 | link | link |
| SemEval | relation classification | 9 | 164 | 1749 | 200 | 692 | link | link |
| CDR | bio relation classification | 2 | 33 | 8430 | 920 | 4673 | link | link |
| Chemprot | chemical relation classification | 10 | 26 | 12861 | 1607 | 1607 | link | link |
| Commercial | video frame classification | 2 | 4 | 64130 | 9479 | 7496 | link | link |
| Tennis Rally | video frame classification | 2 | 6 | 6959 | 746 | 1098 | link | link |
| Basketball | video frame classification | 2 | 4 | 17970 | 1064 | 1222 | link | link |
| DomainNet | image classification | - | - | - | - | - | link | link |

sequence tagging:

| Name | # class | # LF | # train | # validation | # test | data source | LF source |
|------|---------|------|---------|--------------|--------|-------------|-----------|
| CoNLL-03 | 4 | 16 | 14041 | 3250 | 3453 | link | link |
| WikiGold | 4 | 16 | 1355 | 169 | 170 | link | link |
| OntoNotes 5.0 | 18 | 17 | 115812 | 5000 | 22897 | link | link |
| BC5CDR | 2 | 9 | 500 | 500 | 500 | link | link |
| NCBI-Disease | 1 | 5 | 592 | 99 | 99 | link | link |
| Laptop-Review | 1 | 3 | 2436 | 609 | 800 | link | link |
| MIT-Restaurant | 8 | 16 | 7159 | 500 | 1521 | link | link |
| MIT-Movies | 12 | 7 | 9241 | 500 | 2441 | link | link |

The detailed documentation is coming soon.

🔧 Available Models

If you find any of the implementations wrong or problematic, don't hesitate to raise an issue or open a pull request; we really appreciate it!

TODO-list: check this out!

classification:

| Model | Model Type | Reference | Link to Wrench |
|-------|------------|-----------|----------------|
| Majority Voting | Label Model | -- | link |
| Weighted Majority Voting | Label Model | -- | link |
| Dawid-Skene | Label Model | link | link |
| Data Programming | Label Model | link | link |
| MeTaL | Label Model | link | link |
| FlyingSquid | Label Model | link | link |
| Logistic Regression | End Model | -- | link |
| MLP | End Model | -- | link |
| BERT | End Model | link | link |
| COSINE | End Model | link | link |
| Denoise | Joint Model | link | link |
| WeaSEL | Joint Model | link | link |

sequence tagging:

| Model | Model Type | Reference | Link to Wrench |
|-------|------------|-----------|----------------|
| Hidden Markov Model | Label Model | link | link |
| Conditional Hidden Markov Model | Label Model | link | link |
| LSTM-CNNs-CRF | End Model | link | link |
| BERT-CRF | End Model | link | link |
| LSTM-ConNet | Joint Model | link | link |
| BERT-ConNet | Joint Model | link | link |

classification-to-sequence-tagging wrapper:

Wrench also provides a SeqLabelModelWrapper that adapts a label model designed for classification tasks to sequence tagging tasks; see the sketch below.
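
A hedged usage sketch follows; the import path, constructor argument, and metric name are assumptions for illustration only (the wrapper class name comes from the repository, but please consult the code for the actual API):

# NOTE: import path and argument names below are assumptions, not the confirmed API.
from wrench.dataset import load_dataset
from wrench.labelmodel import MajorityVoting
from wrench.seq_labelmodel import SeqLabelModelWrapper  # assumed module path

#### Load a sequence tagging dataset
train_data, valid_data, test_data = load_dataset('../datasets', 'conll', extract_feature=False)

#### Wrap a classification label model so it produces token-level (sequence) labels
label_model = SeqLabelModelWrapper(label_model_class=MajorityVoting)  # assumed argument name
label_model.fit(dataset_train=train_data, dataset_valid=valid_data)
metric_value = label_model.test(test_data, 'f1_seq')  # metric name is an assumption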

methods from related domains:

Robust Learning methods as end model:

| Model | Model Type | Reference | Link to Wrench |
|-------|------------|-----------|----------------|
| Meta-Weight-Net | End Model | link | link |
| Learning2ReWeight | End Model | link | link |

Semi-Supervised Learning methods as end model:

| Model | Model Type | Reference | Link to Wrench |
|-------|------------|-----------|----------------|
| MeanTeacher | End Model | link | link |

Weak Supervision with cleaned labels (Semi-Weak Supervision):

| Model | Model Type | Reference | Link to Wrench |
|-------|------------|-----------|----------------|
| ImplyLoss | Joint Model | link | link |
| ASTRA | Joint Model | link | link |

🔧 Quick examples

🔧 Label model with parallel grid search for hyper-parameters

import logging
import numpy as np
import pprint

from wrench.dataset import load_dataset
from wrench.logging import LoggingHandler
from wrench.search import grid_search
from wrench import labelmodel 
from wrench.evaluation import AverageMeter

#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
logger = logging.getLogger(__name__)

#### Load dataset 
dataset_home = '../datasets'
data = 'youtube'
train_data, valid_data, test_data = load_dataset(dataset_home, data, extract_feature=False)


#### Specify the hyper-parameter search space for grid search
search_space = {
    'Snorkel': {
        'lr': np.logspace(-5, -1, num=5, base=10),
        'l2': np.logspace(-5, -1, num=5, base=10),
        'n_epochs': [5, 10, 50, 100, 200],
    }
}

#### Initialize label model
label_model_name = 'Snorkel'
label_model = getattr(labelmodel, label_model_name)

#### Search best hyper-parameters using validation set in parallel
n_trials = 100
n_repeats = 5
target = 'acc'
searched_paras = grid_search(label_model(), dataset_train=train_data, dataset_valid=valid_data,
                             metric=target, direction='auto', search_space=search_space[label_model_name],
                             n_repeats=n_repeats, n_trials=n_trials, parallel=True)

#### Evaluate the label model with searched hyper-parameters and average meter
meter = AverageMeter(names=[target])
for i in range(n_repeats):
    model = label_model(**searched_paras)
    history = model.fit(dataset_train=train_data, dataset_valid=valid_data)
    metric_value = model.test(test_data, target)
    meter.update(target=metric_value)

metrics = meter.get_results()
pprint.pprint(metrics)

For detailed guidance on grid_search, please check out this wiki page.

🔧 Run a standard supervised learning pipeline

import logging
import torch

from wrench.dataset import load_dataset
from wrench.logging import LoggingHandler
from wrench.endmodel import MLPModel

#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
logger = logging.getLogger(__name__)

#### Load dataset 
dataset_home = '../datasets'
data = 'youtube'

#### Extract data features using pre-trained BERT model and cache it
extract_fn = 'bert'
model_name = 'bert-base-cased'
train_data, valid_data, test_data = load_dataset(dataset_home, data, extract_feature=True, extract_fn=extract_fn,
                                                 cache_name=extract_fn, model_name=model_name)


#### Train an MLP classifier
device = torch.device('cuda:0')
n_steps = 100000
batch_size = 128
test_batch_size = 1000 
patience = 200
evaluation_step = 50
target='acc'

model = MLPModel(n_steps=n_steps, batch_size=batch_size, test_batch_size=test_batch_size)
history = model.fit(dataset_train=train_data, dataset_valid=valid_data, device=device, metric=target, 
                    patience=patience, evaluation_step=evaluation_step)

#### Evaluate the trained model
metric_value = model.test(test_data, target)

🔧 Build a two-stage weak supervision pipeline

import logging
import torch

from wrench.dataset import load_dataset
from wrench.logging import LoggingHandler
from wrench.endmodel import MLPModel
from wrench.labelmodel import MajorityVoting

#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
logger = logging.getLogger(__name__)

#### Load dataset 
dataset_home = '../datasets'
data = 'youtube'

#### Extract data features using pre-trained BERT model and cache it
extract_fn = 'bert'
model_name = 'bert-base-cased'
train_data, valid_data, test_data = load_dataset(dataset_home, data, extract_feature=True, extract_fn=extract_fn,
                                                 cache_name=extract_fn, model_name=model_name)

#### Generate soft training label via a label model
#### The weak labels provided by supervision sources are already encoded in the dataset object
label_model = MajorityVoting()
label_model.fit(train_data, valid_data)
soft_label = label_model.predict_proba(train_data)


#### Train an MLP classifier with soft labels
device = torch.device('cuda:0')
n_steps = 100000
batch_size = 128
test_batch_size = 1000 
patience = 200
evaluation_step = 50
target='acc'

model = MLPModel(n_steps=n_steps, batch_size=batch_size, test_batch_size=test_batch_size)
history = model.fit(dataset_train=train_data, dataset_valid=valid_data, y_train=soft_label, 
                    device=device, metric=target, patience=patience, evaluation_step=evaluation_step)

#### Evaluate the trained model
metric_value = model.test(test_data, target)

#### We can also train an MLP classifier with hard labels
from snorkel.utils import probs_to_preds
hard_label = probs_to_preds(soft_label)
model = MLPModel(n_steps=n_steps, batch_size=batch_size, test_batch_size=test_batch_size)
model.fit(dataset_train=train_data, dataset_valid=valid_data, y_train=hard_label, 
          device=device, metric=target, patience=patience, evaluation_step=evaluation_step)

🔧 Procedural labeling function generator

import logging
import torch

from wrench.dataset import load_dataset
from wrench.logging import LoggingHandler
from wrench.synthetic import ConditionalIndependentGenerator, NGramLFGenerator
from wrench.labelmodel import FlyingSquid

#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
logger = logging.getLogger(__name__)


#### Generate synthetic dataset
generator = ConditionalIndependentGenerator(
    n_class=2,
    n_lfs=10,
    alpha=0.75, # mean accuracy
    beta=0.1, # mean propensity
    alpha_radius=0.2, # radius of accuracy
    beta_radius=0.1 # radius of propensity
)
train_data = generator.generate_split('train', 10000)
valid_data = generator.generate_split('valid', 1000)
test_data = generator.generate_split('test', 1000)

#### Evaluate label model on synthetic dataset
label_model = FlyingSquid()
label_model.fit(dataset_train=train_data, dataset_valid=valid_data)
target_value = label_model.test(test_data, metric_fn='auc')

#### Load dataset 
dataset_home = '../datasets'
data = 'youtube'

#### Load real-world dataset
train_data, valid_data, test_data = load_dataset(dataset_home, data, extract_feature=False)

#### Generate procedural labeling functions
generator = NGramLFGenerator(dataset=train_data, min_acc_gain=0.1, min_support=0.01, ngram_range=(1, 2))
applier = generator.generate(mode='correlated', n_lfs=10)
L_test = applier.apply(test_data)
L_train = applier.apply(train_data)


#### Evaluate label model on real-world dataset with semi-synthetic labeling functions
label_model = FlyingSquid()
label_model.fit(dataset_train=L_train, dataset_valid=valid_data)
target_value = label_model.test(L_test, metric_fn='auc')

🔧 Contact

Contact person: Jieyu Zhang, jieyuzhang97@gmail.com

Don't hesitate to send us an e-mail if you have any questions.

We're also open to any collaboration!

🔧 Contributing Dataset and Model

We sincerely welcome any contribution to the datasets or models!