WRENCH: Weak supeRvision bENCHmark

Overview


🔧 What is it?

Wrench is a benchmark platform containing diverse weak supervision tasks. It also provides a common, easy-to-use framework for developing and evaluating your own weak supervision models within the benchmark.

For more information, check out our publications: (coming soon!)

🔧 What is weak supervision?

Weak Supervision is a paradigm for automated training data creation without manual annotations.

For a brief overview, please check out this blog.

To track recent advances in weak supervision, please follow this repo.
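
As a rough illustration of the idea, the sketch below writes a few labeling functions (LFs) by hand and applies them to unlabeled text, producing a matrix of weak labels instead of manual annotations. The function names, the heuristics, and the `-1` abstain convention are assumptions chosen for this example; they are not taken from Wrench itself.

```python
import numpy as np

# Hypothetical label convention for this sketch: -1 means "abstain".
ABSTAIN, HAM, SPAM = -1, 0, 1

def lf_contains_link(text):
    """Vote SPAM when the comment contains a URL, otherwise abstain."""
    return SPAM if "http" in text.lower() else ABSTAIN

def lf_subscribe(text):
    """Vote SPAM when the comment asks people to subscribe."""
    return SPAM if "subscribe" in text.lower() else ABSTAIN

def lf_short_comment(text):
    """Vote HAM for very short comments (fewer than 5 words)."""
    return HAM if len(text.split()) < 5 else ABSTAIN

LABELING_FUNCTIONS = [lf_contains_link, lf_subscribe, lf_short_comment]

def apply_lfs(texts):
    """Return an (n_examples, n_lfs) matrix of weak labels."""
    return np.array([[lf(t) for lf in LABELING_FUNCTIONS] for t in texts])

comments = [
    "Check out my channel http://example.com",
    "Please subscribe to my channel for more videos",
    "Love this song",
    "This brings back so many memories from my childhood",
]
L = apply_lfs(comments)
print(L)
```

Each row holds the votes for one example and each column one labeling function. A label model then aggregates these noisy, partially abstaining votes into training labels.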

🔧 Installation

[1] Install Anaconda (instructions: https://www.anaconda.com/download/)

[2] Clone the repository:

git clone https://github.com/JieyuZ2/wrench.git
cd wrench

[3] Create virtual environment:

conda env create -f environment.yml
source activate wrench

🔧 Available Datasets

The datasets can be downloaded via this.

classification:

| Name | Task | # class | # LF | # train | # validation | # test | data source | LF source |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| Census | income classification | 2 | 83 | 10083 | 5561 | 16281 | link | link |
| Youtube | spam classification | 2 | 10 | 1586 | 120 | 250 | link | link |
| SMS | spam classification | 2 | 73 | 4571 | 500 | 500 | link | link |
| IMDB | sentiment classification | 2 | 8 | 20000 | 2500 | 2500 | link | link |
| Yelp | sentiment classification | 2 | 8 | 30400 | 3800 | 3800 | link | link |
| AGNews | topic classification | 4 | 9 | 96000 | 12000 | 12000 | link | link |
| TREC | question classification | 6 | 68 | 4965 | 500 | 500 | link | link |
| Spouse | relation classification | 2 | 9 | 22254 | 2801 | 2701 | link | link |
| SemEval | relation classification | 9 | 164 | 1749 | 200 | 692 | link | link |
| CDR | bio relation classification | 2 | 33 | 8430 | 920 | 4673 | link | link |
| Chemprot | chemical relation classification | 10 | 26 | 12861 | 1607 | 1607 | link | link |
| Commercial | video frame classification | 2 | 4 | 64130 | 9479 | 7496 | link | link |
| Tennis Rally | video frame classification | 2 | 6 | 6959 | 746 | 1098 | link | link |
| Basketball | video frame classification | 2 | 4 | 17970 | 1064 | 1222 | link | link |

sequence tagging:

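| Name | # class | # LF | # train | # validation | # test | data source | LF source |
| --- | --- | --- | --- | --- | --- | --- | --- |
| CoNLL-03 | 4 | 16 | 14041 | 3250 | 3453 | link | link |
| WikiGold | 4 | 16 | 1355 | 169 | 170 | link | link |
| OntoNotes 5.0 | 18 | 17 | 115812 | 5000 | 22897 | link | link |
| BC5CDR | 2 | 9 | 500 | 500 | 500 | link | link |
| NCBI-Disease | 1 | 5 | 592 | 99 | 99 | link | link |
| Laptop-Review | 1 | 3 | 2436 | 609 | 800 | link | link |
| MIT-Restaurant | 8 | 16 | 7159 | 500 | 1521 | link | link |
| MIT-Movies | 12 | 7 | 9241 | 500 | 2441 | link | link |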

The detailed documentation is coming soon.

🔧 Available Models

classification:

| Model | Model Type | Reference | Link to Wrench |
| --- | --- | --- | --- |
| Majority Voting | Label Model | -- | link |
| Weighted Majority Voting | Label Model | -- | link |
| Dawid-Skene | Label Model | link | link |
| Data Programming | Label Model | link | link |
| MeTaL | Label Model | link | link |
| FlyingSquid | Label Model | link | link |
| Logistic Regression | End Model | -- | link |
| MLP | End Model | -- | link |
| Pre-trained Language Model | End Model | link | link |
| COSINE | End Model | link | link |
| Denoise | Joint Model | link | link |

sequence tagging:

| Model | Model Type | Reference | Link to Wrench |
| --- | --- | --- | --- |
| Hidden Markov Model | Label Model | link | link |
| Conditional Hidden Markov Model | Label Model | link | link |
| LSTM-CNNs-CRF | End Model | link | link |
| Pre-trained Language Model | End Model | link | link |
| ConNet | Joint Model | link | link |
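
To make the "Label Model" role concrete, here is a minimal sketch of the simplest entry in the classification table, majority voting, written in plain NumPy. It is an illustration under stated assumptions (abstains encoded as `-1`, uniform probabilities when every labeling function abstains) and is not Wrench's own `MajorityVoting` implementation, which may handle ties and abstains differently.

```python
import numpy as np

ABSTAIN = -1  # assumed encoding for "this labeling function did not vote"

def majority_vote_proba(L, n_class):
    """Turn an (n_examples, n_lfs) weak-label matrix into soft labels.

    Each row's probabilities are the normalized vote counts per class.
    Rows where every labeling function abstains fall back to uniform.
    """
    L = np.asarray(L)
    counts = np.zeros((L.shape[0], n_class))
    for c in range(n_class):
        counts[:, c] = (L == c).sum(axis=1)
    totals = counts.sum(axis=1, keepdims=True)
    uniform = np.full_like(counts, 1.0 / n_class)
    # np.maximum avoids division by zero on all-abstain rows.
    return np.where(totals > 0, counts / np.maximum(totals, 1), uniform)

L = [
    [1, 1, 0],     # two votes for class 1, one for class 0
    [-1, 0, -1],   # a single vote for class 0
    [-1, -1, -1],  # no votes at all
]
proba = majority_vote_proba(L, n_class=2)
print(proba)
```

The other label models in the table replace the equal-weight vote with learned estimates of how reliable each labeling function is.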

🔧 Quick examples

🔧 Label model with parallel grid search for hyper-parameters

import logging
import numpy as np
import pprint

from wrench.dataset import load_dataset
from wrench.logging import LoggingHandler
from wrench.search import grid_search
from wrench import labelmodel 
from wrench.evaluation import AverageMeter

#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
logger = logging.getLogger(__name__)

#### Load dataset 
dataset_home = '../datasets'
data = 'youtube'
train_data, valid_data, test_data = load_dataset(dataset_home, data, extract_feature=False)


#### Specify the hyper-parameter search space for grid search
search_space = {
    'Snorkel': {
        'lr': np.logspace(-5, -1, num=5, base=10),
        'l2': np.logspace(-5, -1, num=5, base=10),
        'n_epochs': [5, 10, 50, 100, 200],
    }
}

#### Initialize label model
label_model_name = 'Snorkel'
label_model = getattr(labelmodel, label_model_name)

#### Search best hyper-parameters using validation set in parallel
n_trials = 100
n_repeats = 5
target = 'acc'
searched_paras = grid_search(label_model(), dataset_train=train_data, dataset_valid=valid_data,
                             metric=target, direction='auto', search_space=search_space[label_model_name],
                             n_repeats=n_repeats, n_trials=n_trials, parallel=True)

#### Evaluate the label model with searched hyper-parameters and average meter
meter = AverageMeter(names=[target])
for i in range(n_repeats):
    model = label_model(**searched_paras)
    history = model.fit(dataset_train=train_data, dataset_valid=valid_data)
    metric_value = model.test(test_data, target)
    # Use the metric name stored in `target` as the keyword, matching names=[target]
    meter.update(**{target: metric_value})

metrics = meter.get_results()
pprint.pprint(metrics)
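
The search space above contains 5 × 5 × 5 = 125 combinations, while `n_trials` is 100, so presumably not every combination is evaluated. The sketch below only illustrates that arithmetic: it enumerates the grid with `itertools.product` and draws a subset of trials. The helpers `enumerate_grid` and `sample_trials` are hypothetical, and sampling without replacement is an assumption; how `grid_search` actually selects trials is not described here.

```python
import itertools
import random

import numpy as np

search_space = {
    'lr': np.logspace(-5, -1, num=5, base=10),
    'l2': np.logspace(-5, -1, num=5, base=10),
    'n_epochs': [5, 10, 50, 100, 200],
}

def enumerate_grid(space):
    """Return every combination in the search space as a list of dicts."""
    names = list(space)
    value_lists = [space[name] for name in names]
    return [dict(zip(names, values)) for values in itertools.product(*value_lists)]

def sample_trials(grid, n_trials, seed=0):
    """Pick at most n_trials distinct configurations (assumed strategy)."""
    rng = random.Random(seed)
    return rng.sample(grid, min(n_trials, len(grid)))

grid = enumerate_grid(search_space)
trials = sample_trials(grid, n_trials=100)
print(len(grid), len(trials))
```

Each sampled dict has the same keys as the search space, so it could be passed to a model constructor with `**config`.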

🔧 Run a standard supervised learning pipeline

import logging
import torch

from wrench.dataset import load_dataset
from wrench.logging import LoggingHandler
from wrench.endmodel import MLPModel

#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
logger = logging.getLogger(__name__)

#### Load dataset 
dataset_home = '../datasets'
data = 'youtube'

#### Extract data features using pre-trained BERT model and cache it
extract_fn = 'bert'
model_name = 'bert-base-cased'
train_data, valid_data, test_data = load_dataset(dataset_home, data, extract_feature=True, extract_fn=extract_fn,
                                                 cache_name=extract_fn, model_name=model_name)


#### Train an MLP classifier
device = torch.device('cuda:0')
n_steps = 100000
batch_size = 128
test_batch_size = 1000 
patience = 200
evaluation_step = 50
target = 'acc'

model = MLPModel(n_steps=n_steps, batch_size=batch_size, test_batch_size=test_batch_size)
history = model.fit(dataset_train=train_data, dataset_valid=valid_data, device=device, metric=target, 
                    patience=patience, evaluation_step=evaluation_step)

#### Evaluate the trained model
metric_value = model.test(test_data, target)
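
The `n_steps`, `evaluation_step`, and `patience` arguments in the example above suggest validation-based early stopping. The following sketch shows one common way such a loop is structured; it is a generic illustration, not Wrench's training loop. In particular, counting `patience` in evaluations rather than in training steps is an assumption, and `train_step` and `evaluate` are hypothetical callbacks.

```python
def train_with_early_stopping(train_step, evaluate, n_steps, evaluation_step, patience):
    """Run up to n_steps training steps, evaluating every evaluation_step steps.

    Stops once `patience` consecutive evaluations fail to improve the best
    validation value. Returns (best_value, best_step, last_step).
    """
    best_value, best_step, bad_evals = float('-inf'), 0, 0
    step = 0
    for step in range(1, n_steps + 1):
        train_step(step)
        if step % evaluation_step != 0:
            continue
        value = evaluate(step)
        if value > best_value:
            best_value, best_step, bad_evals = value, step, 0
        else:
            bad_evals += 1
            if bad_evals >= patience:
                break
    return best_value, best_step, step

# Toy validation curve that peaks at step 150 and then degrades.
result = train_with_early_stopping(
    train_step=lambda step: None,
    evaluate=lambda step: 1.0 - abs(step - 150) / 1000,
    n_steps=100000,
    evaluation_step=50,
    patience=3,
)
print(result)
```

With a large `n_steps` such as 100000, a stopping rule of this kind is what keeps the run short when the validation metric plateaus.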

🔧 Build a two-stage weak supervision pipeline

import logging
import torch

from wrench.dataset import load_dataset
from wrench.logging import LoggingHandler
from wrench.endmodel import MLPModel
from wrench.labelmodel import MajorityVoting

#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
logger = logging.getLogger(__name__)

#### Load dataset 
dataset_home = '../datasets'
data = 'youtube'

#### Extract data features using pre-trained BERT model and cache it
extract_fn = 'bert'
model_name = 'bert-base-cased'
train_data, valid_data, test_data = load_dataset(dataset_home, data, extract_feature=True, extract_fn=extract_fn,
                                                 cache_name=extract_fn, model_name=model_name)

#### Generate soft training labels via a label model
#### The weak labels provided by supervision sources are already encoded in the dataset object
label_model = MajorityVoting()
label_model.fit(train_data, valid_data)
soft_label = label_model.predict_proba(train_data)


#### Train an MLP classifier with soft labels
device = torch.device('cuda:0')
n_steps = 100000
batch_size = 128
test_batch_size = 1000 
patience = 200
evaluation_step = 50
target = 'acc'

model = MLPModel(n_steps=n_steps, batch_size=batch_size, test_batch_size=test_batch_size)
history = model.fit(dataset_train=train_data, dataset_valid=valid_data, y_train=soft_label, 
                    device=device, metric=target, patience=patience, evaluation_step=evaluation_step)

#### Evaluate the trained model
metric_value = model.test(test_data, target)

#### We can also train an MLP classifier with hard labels
from snorkel.utils import probs_to_preds
hard_label = probs_to_preds(soft_label)
model = MLPModel(n_steps=n_steps, batch_size=batch_size, test_batch_size=test_batch_size)
model.fit(dataset_train=train_data, dataset_valid=valid_data, y_train=hard_label, 
          device=device, metric=target, patience=patience, evaluation_step=evaluation_step)
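
The difference between the soft and hard labels used above can be shown in a few lines of NumPy. This is a simplified stand-in for the conversion step, not the library routine: it uses a plain `argmax`, which resolves ties toward the lowest class index, whereas the library function may break ties differently. The name `soft_to_hard` is hypothetical.

```python
import numpy as np

def soft_to_hard(probs):
    """Convert (n_examples, n_class) probabilities to integer class labels."""
    return np.asarray(probs).argmax(axis=1)

soft_label = np.array([
    [0.9, 0.1],  # confident class 0
    [0.2, 0.8],  # confident class 1
    [0.5, 0.5],  # tie: argmax picks class 0
])
hard_label = soft_to_hard(soft_label)

# One-hot view of the hard labels, for comparison with the soft version.
one_hot = np.eye(soft_label.shape[1])[hard_label]
print(hard_label)
print(one_hot)
```

Soft labels keep the label model's uncertainty (the tied third row stays at 0.5/0.5), while hard labels discard it.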

🔧 Procedural labeling function generator

import logging
import torch

from wrench.dataset import load_dataset
from wrench.logging import LoggingHandler
from wrench.synthetic import ConditionalIndependentGenerator, NGramLFGenerator
from wrench.labelmodel import FlyingSquid

#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
logger = logging.getLogger(__name__)


#### Generate synthetic dataset
generator = ConditionalIndependentGenerator(
    n_class=2,
    n_lfs=10,
    alpha=0.75, # mean accuracy
    beta=0.1, # mean propensity
    alpha_radius=0.2, # radius of accuracy
    beta_radius=0.1 # radius of propensity
)
train_data = generator.generate_split('train', 10000)
valid_data = generator.generate_split('valid', 1000)
test_data = generator.generate_split('test', 1000)

#### Evaluate label model on synthetic dataset
label_model = FlyingSquid()
label_model.fit(dataset_train=train_data, dataset_valid=valid_data)
target_value = label_model.test(test_data, metric_fn='auc')

#### Load dataset 
dataset_home = '../datasets'
data = 'youtube'

#### Load real-world dataset
train_data, valid_data, test_data = load_dataset(dataset_home, data, extract_feature=False)

#### Generate procedural labeling functions
generator = NGramLFGenerator(dataset=train_data, min_acc_gain=0.1, min_support=0.01, ngram_range=(1, 2))
applier = generator.generate(mode='correlated', n_lfs=10)
L_test = applier.apply(test_data)
L_train = applier.apply(train_data)


#### Evaluate label model on real-world dataset with semi-synthetic labeling functions
label_model = FlyingSquid()
label_model.fit(dataset_train=L_train, dataset_valid=valid_data)
target_value = label_model.test(L_test, metric_fn='auc')
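
For intuition about the synthetic generator's parameters, the sketch below simulates conditionally independent labeling functions in NumPy. It assumes that each LF's accuracy is drawn uniformly from `alpha ± alpha_radius` and its propensity (probability of not abstaining) from `beta ± beta_radius`, following the comments in the example above. The function `generate_weak_labels` is hypothetical and is not the `ConditionalIndependentGenerator` implementation.

```python
import numpy as np

ABSTAIN = -1  # assumed encoding for an abstaining labeling function

def generate_weak_labels(n_data, n_class, n_lfs, alpha, beta,
                         alpha_radius, beta_radius, seed=0):
    """Sample true labels y and an (n_data, n_lfs) weak-label matrix L."""
    rng = np.random.default_rng(seed)
    accuracy = rng.uniform(alpha - alpha_radius, alpha + alpha_radius, size=n_lfs)
    propensity = rng.uniform(beta - beta_radius, beta + beta_radius, size=n_lfs)
    y = rng.integers(0, n_class, size=n_data)
    L = np.full((n_data, n_lfs), ABSTAIN)
    for j in range(n_lfs):
        fires = rng.random(n_data) < propensity[j]
        correct = rng.random(n_data) < accuracy[j]
        # A wrong vote is the true label shifted by a non-zero offset.
        offset = rng.integers(1, n_class, size=n_data)
        wrong = (y + offset) % n_class
        votes = np.where(correct, y, wrong)
        L[fires, j] = votes[fires]
    return L, y

L_syn, y_syn = generate_weak_labels(
    n_data=10000, n_class=2, n_lfs=10,
    alpha=0.75, beta=0.1, alpha_radius=0.2, beta_radius=0.1,
)
fired = L_syn != ABSTAIN
coverage = fired.mean()
vote_accuracy = (L_syn == y_syn[:, None])[fired].mean()
print(coverage, vote_accuracy)
```

Given the true label, each LF votes independently of the others, which is the conditional-independence assumption that many label models rely on.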
Comments
  • ModuleNotFoundError: No module named 'tokenizations'

    Hi, I ran into some problems when trying to install the library. I tried pip install ws-benchmark==1.1.2rc0 as suggested in the documentation; the installation succeeded, but when I ran the code I got the error ModuleNotFoundError: No module named 'tokenizations'. I then cloned the repository and tried to create the environment using conda env create -f environment.yml, but the installation failed with the following error: FileNotFoundError: [Errno 2] No such file or directory: '/home/naiqing/miniconda3/envs/wrench/lib/python3.6/site-packages/huggingface_hub-0.0.16-py3.8.egg'. Do you have any idea what might cause the problem and how I can fix it?

    opened by Gnaiqing 12
  • Is there a limitation of using dataset for different algs?

    First, thank you for building this awesome benchmark. When I try the examples with different datasets (e.g., astra with the youtube dataset), I get errors like this:

        loss = cross_entropy_with_probs(predict_l, batch['labels'].to(device))
    KeyError: 'labels'
    

    Can this be fixed?

    opened by mrbeann 8
  • Python Package Installation Fails

    Installing the ws-benchmark Python package fails due to a dependency conflict (see the stack trace below).

    Tested on system:

    • OS: ubuntu
    • Python: 3.8.13
    • Clean VE

    Command to replicate:

    • pip install ws-benchmark

    Stack Trace:

    ERROR: Cannot install ws-benchmark and ws-benchmark==1.1.1 because these package versions have conflicting dependencies.
    
    The conflict is caused by:
        ws-benchmark 1.1.1 depends on networkx==2.7
        snorkel 0.9.7 depends on networkx<2.4 and >=2.2
    
    To fix this you could try to:
    1. loosen the range of package versions you've specified
    2. remove package versions to allow pip attempt to solve the dependency conflict
    
    ERROR: ResolutionImpossible: for help visit https://pip.pypa.io/en/latest/topics/dependency-resolution/#dealing-with-dependency-conflicts
    
    opened by bradleyfowler123 4
  • Using Multiple GPUs

    Hi,

    Is it possible to use multiple GPUs for the experiments, or will it be in future releases? It would be a nice feature if it is not possible right now.

    Best regards.

    opened by tolgayan 4
  • Running scripts

    Hi, I am trying to run some models on the IMDB dataset.

    MLP:

    import logging
    import torch
    import numpy as np
    from wrench.dataset import load_dataset
    from wrench.labelmodel import Snorkel
    from wrench.logging import LoggingHandler
    from wrench.search import grid_search
    from wrench.endmodel import EndClassifierModel
    
    #### Just some code to print debug information to stdout
    logging.basicConfig(format='%(asctime)s - %(message)s',
                        datefmt='%Y-%m-%d %H:%M:%S',
                        level=logging.INFO,
                        handlers=[LoggingHandler()])
    
    logger = logging.getLogger(__name__)
    
    device = torch.device('cuda')
    
    if __name__ == '__main__':
        #### Load dataset
        dataset_path = '../datasets/'
        data = "imdb"
        bert_model_name = "bert-base-cased"
        train_data, valid_data, test_data = load_dataset(
            dataset_path,
            data,
            extract_feature=True,
            extract_fn='bert',  # extract bert embedding
            model_name=bert_model_name,
            cache_name='bert',
            dataset_type="TextDataset"
        )
    
        #### Run label model: Snorkel
        label_model = Snorkel(
            lr=0.005,
            l2=0,
            n_epochs=200,
            seed=123
        )
        label_model.fit(
            dataset_train=train_data,
            dataset_valid=valid_data
        )
    
        acc = label_model.test(test_data, 'acc')
        logger.info(f'label model test acc: {acc}')
    
        #### Filter out uncovered training data
        aggregated_hard_labels = label_model.predict(train_data)
        aggregated_soft_labels = label_model.predict_proba(train_data)
    
        #### Search Space
        search_space = {
            'optimizer_lr': np.logspace(-5, -1, num=5, base=10),
            'optimizer_weight_decay': np.logspace(-5, -1, num=5, base=10),
        }
    
        #### Initialize the model: MLP
        model = EndClassifierModel(
            batch_size=8,
            real_batch_size=8,
            test_batch_size=8,
            backbone='MLP',
            optimizer='Adam'
        )
    
        #### Search best hyper-parameters using validation set in parallel
        n_trials = 20
        n_repeats = 1
        searched_paras = grid_search(
            model,
            dataset_train=train_data,
            y_train=aggregated_soft_labels,
            dataset_valid=valid_data,
            metric='acc',
            direction='auto',
            search_space=search_space,
            n_repeats=n_repeats,
            n_trials=n_trials,
            parallel=True,
            device=device,
        )
    
    
        #### Run end model: MLP
        model = EndClassifierModel(
            batch_size=8,
            real_batch_size=8,
            test_batch_size=8,
            backbone='MLP',
            optimizer='Adam',
            **searched_paras
        )
        model.fit(
            dataset_train=train_data,
            y_train=aggregated_soft_labels,
            dataset_valid=valid_data,
            metric='acc',
            device=device
        )
    
        logger.info(model.predict(test_data).tolist())
    
        acc = model.test(test_data, 'acc')
        logger.info(f'end model (MLP) test acc: {acc}')
    
    

    for which I am getting the following output:

    100%|██████████| 20000/20000 [00:00<00:00, 902651.16it/s]
    100%|██████████| 2500/2500 [00:00<00:00, 852639.45it/s]
    100%|██████████| 2500/2500 [00:00<00:00, 829503.99it/s]
    Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight']
    - This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
    - This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
    100%|██████████| 20000/20000 [1:42:45<00:00,  3.24it/s]  
    Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight']
    - This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
    - This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
    100%|██████████| 2500/2500 [13:24<00:00,  3.11it/s]
    Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight']
    - This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
    - This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
    100%|██████████| 2500/2500 [13:50<00:00,  3.01it/s]
    [I 2021-10-23 22:24:36,807] A new study created in memory with name: no-name-9e4ad09c-ea4a-4ee8-80c2-7633429e4038
    huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
    To disable this warning, you can either:
            - Avoid using `tokenizers` before the fork if possible
            - Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
    2021-10-23 20:14:19 - loading data from ../datasets/imdb/train.json
    2021-10-23 20:14:19 - loading data from ../datasets/imdb/valid.json
    2021-10-23 20:14:19 - loading data from ../datasets/imdb/test.json
    2021-10-23 21:57:10 - saving features into ../datasets/imdb/train_bert.pkl
    2021-10-23 22:10:40 - saving features into ../datasets/imdb/valid_bert.pkl
    2021-10-23 22:24:36 - saving features into ../datasets/imdb/test_bert.pkl
    2021-10-23 22:24:36 - label model test acc: 0.716
    huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
    To disable this warning, you can either:
            - Avoid using `tokenizers` before the fork if possible
            - Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
    100%|██████████| 1/1 [00:37<00:00, 37.48s/it]
    [I 2021-10-23 22:25:14,563] Trial 0 finished with value: 0.5012 and parameters: {'optimizer_lr': 0.001, 'optimizer_weight_decay': 0.0001}. Best is trial 0 with value: 0.5012.
    100%|██████████| 1/1 [00:23<00:00, 23.70s/it]
    [I 2021-10-23 22:25:38,448] Trial 1 finished with value: 0.496 and parameters: {'optimizer_lr': 1e-05, 'optimizer_weight_decay': 0.1}. Best is trial 0 with value: 0.5012.
    100%|██████████| 1/1 [00:14<00:00, 14.53s/it]
    [I 2021-10-23 22:25:53,171] Trial 2 finished with value: 0.5004 and parameters: {'optimizer_lr': 0.1, 'optimizer_weight_decay': 0.001}. Best is trial 0 with value: 0.5012.
    100%|██████████| 1/1 [00:43<00:00, 43.73s/it]
    [I 2021-10-23 22:26:37,071] Trial 3 finished with value: 0.5088 and parameters: {'optimizer_lr': 0.001, 'optimizer_weight_decay': 0.001}. Best is trial 3 with value: 0.5088.
    100%|██████████| 1/1 [00:18<00:00, 18.85s/it]
    [I 2021-10-23 22:26:56,161] Trial 4 finished with value: 0.488 and parameters: {'optimizer_lr': 0.001, 'optimizer_weight_decay': 0.1}. Best is trial 3 with value: 0.5088.
    100%|██████████| 1/1 [00:38<00:00, 38.81s/it]
    [I 2021-10-23 22:27:35,214] Trial 5 finished with value: 0.4948 and parameters: {'optimizer_lr': 0.0001, 'optimizer_weight_decay': 0.1}. Best is trial 3 with value: 0.5088.
    100%|██████████| 1/1 [00:38<00:00, 38.15s/it]
    [I 2021-10-23 22:28:13,614] Trial 6 finished with value: 0.5024 and parameters: {'optimizer_lr': 0.01, 'optimizer_weight_decay': 0.01}. Best is trial 3 with value: 0.5088.
    100%|██████████| 1/1 [00:15<00:00, 15.47s/it]
    [I 2021-10-23 22:28:29,335] Trial 7 finished with value: 0.4996 and parameters: {'optimizer_lr': 0.1, 'optimizer_weight_decay': 1e-05}. Best is trial 3 with value: 0.5088.
    100%|██████████| 1/1 [00:22<00:00, 22.49s/it]
    [I 2021-10-23 22:28:52,093] Trial 8 finished with value: 0.5008 and parameters: {'optimizer_lr': 0.0001, 'optimizer_weight_decay': 1e-05}. Best is trial 3 with value: 0.5088.
    100%|██████████| 1/1 [00:40<00:00, 40.25s/it]
    [I 2021-10-23 22:29:32,594] Trial 9 finished with value: 0.5008 and parameters: {'optimizer_lr': 0.0001, 'optimizer_weight_decay': 0.0001}. Best is trial 3 with value: 0.5088.
    100%|██████████| 1/1 [00:39<00:00, 39.06s/it]
    [I 2021-10-23 22:30:11,902] Trial 10 finished with value: 0.5116 and parameters: {'optimizer_lr': 1e-05, 'optimizer_weight_decay': 1e-05}. Best is trial 10 with value: 0.5116.
    100%|██████████| 1/1 [00:43<00:00, 43.46s/it]
    [I 2021-10-23 22:30:55,531] Trial 11 finished with value: 0.4912 and parameters: {'optimizer_lr': 0.001, 'optimizer_weight_decay': 1e-05}. Best is trial 10 with value: 0.5116.
    100%|██████████| 1/1 [00:23<00:00, 23.41s/it]
    [I 2021-10-23 22:31:19,095] Trial 12 finished with value: 0.4956 and parameters: {'optimizer_lr': 0.001, 'optimizer_weight_decay': 0.01}. Best is trial 10 with value: 0.5116.
    100%|██████████| 1/1 [00:22<00:00, 22.12s/it]
    [I 2021-10-23 22:31:41,374] Trial 13 finished with value: 0.492 and parameters: {'optimizer_lr': 1e-05, 'optimizer_weight_decay': 0.01}. Best is trial 10 with value: 0.5116.
    100%|██████████| 1/1 [00:15<00:00, 15.78s/it]
    [I 2021-10-23 22:31:57,283] Trial 14 finished with value: 0.5044 and parameters: {'optimizer_lr': 0.1, 'optimizer_weight_decay': 0.0001}. Best is trial 10 with value: 0.5116.
    100%|██████████| 1/1 [00:37<00:00, 37.28s/it]
    [I 2021-10-23 22:32:34,728] Trial 15 finished with value: 0.488 and parameters: {'optimizer_lr': 1e-05, 'optimizer_weight_decay': 0.001}. Best is trial 10 with value: 0.5116.
    100%|██████████| 1/1 [00:16<00:00, 16.04s/it]
    [I 2021-10-23 22:32:50,934] Trial 16 finished with value: 0.4924 and parameters: {'optimizer_lr': 0.0001, 'optimizer_weight_decay': 0.001}. Best is trial 10 with value: 0.5116.
    100%|██████████| 1/1 [00:19<00:00, 19.65s/it]
    [I 2021-10-23 22:33:10,753] Trial 17 finished with value: 0.5156 and parameters: {'optimizer_lr': 0.1, 'optimizer_weight_decay': 0.1}. Best is trial 17 with value: 0.5156.
    100%|██████████| 1/1 [00:15<00:00, 15.41s/it]
    [I 2021-10-23 22:33:26,345] Trial 18 finished with value: 0.5068 and parameters: {'optimizer_lr': 0.01, 'optimizer_weight_decay': 0.001}. Best is trial 17 with value: 0.5156.
    100%|██████████| 1/1 [00:16<00:00, 16.75s/it]
    [I 2021-10-23 22:33:43,222] Trial 19 finished with value: 0.498 and parameters: {'optimizer_lr': 0.0001, 'optimizer_weight_decay': 0.01}. Best is trial 17 with value: 0.5156.
    [TRAIN]:  15%|█████▌                               | 1499/10000 [00:21<02:04, 68.19steps/s, loss=4.02, val_acc=0.5, best_val_acc=0.508, best_step=500]
    2021-10-23 22:33:43 - [END: BEST VAL / PARAMS] Best value: 0.5156, Best paras: {'optimizer_lr': 0.1, 'optimizer_weight_decay': 0.1}
    2021-10-23 22:33:43 - 
    ==========[hyper parameters]==========
    {
        "batch_size": 8,
        "real_batch_size": 8,
        "test_batch_size": 8,
        "n_steps": 10000,
        "grad_norm": -1,
        "use_lr_scheduler": false,
        "binary_mode": false
    }
    ==========[optimizer config]==========
    {
        "name": "Adam",
        "paras": {
            "lr": 0.1,
            "weight_decay": 0.1
        }
    }
    ==========[backbone config]==========
    {
        "name": "MLP",
        "paras": {
            "hidden_size": 100,
            "dropout": 0.0
        }
    }
    
    2021-10-23 22:34:09 - [INFO] early stop @ step 1500!
    2021-10-23 22:34:09 - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
    2021-10-23 22:34:09 - end model (MLP) test acc: 0.5004
    

    COSINE:

    import logging
    import torch
    from wrench.dataset import load_dataset
    from wrench.logging import LoggingHandler
    from wrench.labelmodel import Snorkel
    from wrench.endmodel import Cosine
    
    #### Just some code to print debug information to stdout
    logging.basicConfig(format='%(asctime)s - %(message)s',
                        datefmt='%Y-%m-%d %H:%M:%S',
                        level=logging.INFO,
                        handlers=[LoggingHandler()])
    
    logger = logging.getLogger(__name__)
    
    device = torch.device('cuda')
    
    if __name__ == '__main__':
        #### Load dataset
        dataset_path = '../datasets/'
        data = "imdb"
        bert_model_name = "bert-base-cased"
        train_data, valid_data, test_data = load_dataset(
            dataset_path,
            data,
            extract_feature=True,
            extract_fn='bert',  # extract bert embedding
            model_name=bert_model_name,
            cache_name='bert',
            dataset_type="TextDataset"
        )
    
        #### Run label model: Snorkel
        label_model = Snorkel(
            lr=0.005,
            l2=0,
            n_epochs=200,
            seed=123
        )
        label_model.fit(
            dataset_train=train_data,
            dataset_valid=valid_data
        )
    
        acc = label_model.test(test_data, 'acc')
        logger.info(f'label model test acc: {acc}')
    
        #### Filter out uncovered training data
        aggregated_hard_labels = label_model.predict(train_data)
        aggregated_soft_labels = label_model.predict_proba(train_data)
    
    
        # COSINE
        model = Cosine(
            teacher_update=100,
            margin=1.0,
            thresh=0.6,
            lr=1e-5,
            mu=1.0,
            lamda=0.05,
            backbone='BERT',
            backbone_model_name=bert_model_name,
            batch_size=8,
            real_batch_size=8,
            test_batch_size=8,
        )
    
        model.fit(dataset_train=train_data,
                  dataset_valid=valid_data,
                  y_train=aggregated_hard_labels,
                  evaluation_step=10,
                  metric='acc',
                  patience=50,
                  device=device)
    
        acc = model.test(test_data, 'acc')
    
        logger.info(model.predict(test_data))
    
        logger.info(f'end model (COSINE) test acc: {acc}')
    

    for which I am getting the following output:

    100%|██████████| 20000/20000 [00:00<00:00, 899119.81it/s]
    100%|██████████| 2500/2500 [00:00<00:00, 423667.07it/s]
    100%|██████████| 2500/2500 [00:00<00:00, 802645.44it/s]
    Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias']
    - This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
    - This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
    100%|██████████| 20000/20000 [1:47:44<00:00,  3.09it/s]  
    Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias']
    - This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
    - This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
    100%|██████████| 2500/2500 [14:22<00:00,  2.90it/s]
    Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias']
    - This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
    - This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
    100%|██████████| 2500/2500 [13:33<00:00,  3.07it/s] 
    Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias']
    - This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
    - This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
    [TRAIN] COSINE pretrain stage:   5%|▊               | 509/10000 [21:19<6:37:40,  2.51s/steps, loss=0.605, val_acc=0.5, best_val_acc=0.5, best_step=10]
    [TRAIN] COSINE distillation stage:   0%|                                                                                 | 0/10000 [03:05<?, ?steps/s]
    2021-10-23 20:14:13 - loading data from ../datasets/imdb/train.json
    2021-10-23 20:14:13 - loading data from ../datasets/imdb/valid.json
    2021-10-23 20:14:14 - loading data from ../datasets/imdb/test.json
    2021-10-23 22:02:05 - saving features into ../datasets/imdb/train_bert.pkl
    2021-10-23 22:16:34 - saving features into ../datasets/imdb/valid_bert.pkl
    2021-10-23 22:30:14 - saving features into ../datasets/imdb/test_bert.pkl
    2021-10-23 22:30:14 - label model test acc: 0.716
    2021-10-23 22:30:17 - 
    ==========[hyper parameters]==========
    {
        "teacher_update": 100,
        "margin": 1.0,
        "mu": 1.0,
        "thresh": 0.6,
        "lamda": 0.05,
        "batch_size": 8,
        "real_batch_size": 8,
        "test_batch_size": 8,
        "n_steps": 10000,
        "grad_norm": -1,
        "use_lr_scheduler": false,
        "binary_mode": false
    }
    ==========[optimizer config]==========
    {
        "name": "Adam",
        "paras": {
            "lr": 0.001,
            "weight_decay": 0.0
        }
    }
    ==========[backbone config]==========
    {
        "name": "BERT",
        "paras": {
            "model_name": "bert-base-cased",
            "max_tokens": 512,
            "fine_tune_layers": -1
        }
    }
    ==========[label model_config config]==========
    {
        "name": "MajorityVoting",
        "paras": {}
    }
    
    2021-10-23 22:51:52 - [INFO] early stop @ step 510!
    2021-10-23 22:55:20 - early stop because all the data are filtered!
    2021-10-23 22:56:06 - [1 1 1 ... 1 1 1]
    2021-10-23 22:56:06 - end model (COSINE) test acc: 0.5
    

    As can be seen, in both runs the label model reaches a test accuracy of 0.716, but the end model (MLP) test accuracy is 0.5004 and the end model (COSINE) test accuracy is 0.5.

    Am I doing something completely wrong? Could you please tell me whether I am running the code correctly, or whether there is an issue with the hyperparameters?

    I would greatly appreciate it if you could give me some advice. I would also be very glad if you could include an example script for running the COSINE model.

    Thanks for the benchmark, I really appreciate it!
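
    The COSINE log above prints `[1 1 1 ... 1 1 1]` for the test predictions, which suggests the end model may be predicting a single class; if the test set is roughly balanced, that alone would give an accuracy near 0.5. A minimal NumPy sketch for checking this is shown below. `prediction_summary` is a hypothetical helper written for illustration, not part of wrench, and the arrays are toy data rather than the actual IMDB outputs.

```python
import numpy as np

def prediction_summary(y_pred, y_true):
    """Summarize how predictions are distributed over classes.

    Hypothetical helper (not part of wrench): returns the number of
    predictions per class, the accuracy, and the accuracy a constant
    "always predict the most frequent class" baseline would achieve.
    """
    y_pred = np.asarray(y_pred)
    y_true = np.asarray(y_true)
    n_class = int(max(y_pred.max(), y_true.max())) + 1
    pred_counts = np.bincount(y_pred, minlength=n_class)
    true_counts = np.bincount(y_true, minlength=n_class)
    return {
        "pred_counts": pred_counts.tolist(),
        "accuracy": float((y_pred == y_true).mean()),
        "majority_baseline": float(true_counts.max() / true_counts.sum()),
    }

# Toy data: the model always predicts class 1 on a balanced test set.
summary = prediction_summary(y_pred=[1, 1, 1, 1], y_true=[0, 1, 0, 1])
print(summary)
```

    If all predictions fall into one class and the accuracy equals the majority baseline, the end model has collapsed to a constant prediction, and the training setup (labels passed to `fit`, learning rate, early stopping) is worth inspecting before the model itself.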

    opened by viheheb757 4
  • Reproducing Table 11 for classification

    Reproducing Table 11 for classification

    Thanks for this package @JieyuZ2 -- do you happen to have an orchestration script for reproducing Table 11 (and therefore Table 3) in the Wrench paper?

    opened by pmangg 3
  • No module named 'wrench.classification.self_training'

    No module named 'wrench.classification.self_training'

    Hi, I am trying to run run_denoise.py but I am getting the following error:

    Traceback (most recent call last):
      File "run_denoise.py", line 5, in <module>
        from wrench.classification import Denoise
      File "/gpfs/space/home/wrench/wrench/classification/__init__.py", line 4, in <module>
        from .self_training import LDSelfTrain, DDSelfTrain
    ModuleNotFoundError: No module named 'wrench.classification.self_training'
    

    Could you please add LDSelfTrain and DDSelfTrain classes?

    opened by andreaspung 3
  • Questions on the use of ground-truth labels for validation

    Questions on the use of ground-truth labels for validation

    Thanks for putting up the benchmark! This is really great work! It seems that both the label model and the end model use the ground-truth labels for validation. For example, the base label model uses the ground-truth labels of the validation set to calculate the class balance weights: https://github.com/JieyuZ2/wrench/blob/544119e781d010797cf153307aa1090361c99522/wrench/basemodel.py#L286

    I have a few questions regarding this:

    (1) A valid baseline for the label models would be a classifier trained on the validation set with the weak labels of LFs as features and the ground-truth labels as the target. Given that the validation set for most datasets is actually not small, I feel the trained model might be a pretty strong baseline compared to other unsupervised label models.

    (2) Similar to how we combine the weak labels on the training set to get aggregated labels, we could also get aggregated labels for the validation set. Then, the aggregated labels instead of the ground-truth labels of the validation set could be used for validation purposes for the end model. Wouldn't this be a more realistic setting? Especially considering that the purpose of weak supervision is to replace human labeling with programmatic labeling.

    I appreciate any explanations. Thanks!
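
    Question (2) above can be illustrated with a small sketch: aggregating the weak labels of a validation set by majority vote, so that the resulting labels could stand in for ground truth during validation. This is a plain NumPy illustration, not wrench's own label-model code; the `majority_vote` function, the `-1` abstain convention, and the toy matrix are all assumptions made for the example.

```python
import numpy as np

ABSTAIN = -1  # assumed convention: a labeling function that does not fire

def majority_vote(L, n_class, abstain=ABSTAIN):
    """Aggregate a weak-label matrix L (n_examples x n_LFs) by majority vote.

    Examples on which every labeling function abstains receive `abstain`.
    Ties are broken in favor of the smallest class index.
    """
    L = np.asarray(L)
    # votes[i, c] = number of labeling functions voting class c on example i
    votes = np.stack([(L == c).sum(axis=1) for c in range(n_class)], axis=1)
    labels = votes.argmax(axis=1)
    labels[votes.sum(axis=1) == 0] = abstain
    return labels

# Toy validation set: 4 examples, 3 labeling functions, 2 classes.
L_valid = np.array([
    [1, 1, -1],
    [0, -1, 0],
    [-1, -1, -1],
    [1, 0, 0],
])
y_valid_weak = majority_vote(L_valid, n_class=2)
covered = y_valid_weak != ABSTAIN  # examples usable for validation
print(y_valid_weak, covered)
```

    In this setting only the covered examples would be usable for model selection, and any validation metric would be measured against noisy labels rather than ground truth.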

    opened by wurenzhi 2
  • Clarifying dataset download links

    Clarifying dataset download links

    Great work on the benchmark!

    Under the "Available Datasets" section on the main README, you provide 2 links for downloading the WRENCH datasets:

    One point of confusion is that the expanded datasets on the Google Drive link differ from those in the direct-download zip file. For example, classification/youtube/train.json on Google Drive has 1686 instances, while the same file in the zip has 1586, matching the statistics reported in the README. Can you make it unambiguous in the documentation which download is the correct one?

    opened by jason-fries 2
  • Fix retained probabilities

    Fix retained probabilities

    This pull request fixes a bug that led to the wrong probabilities being stored along with the predictions of each labeling function.

    Previously, all probabilities (2d tensor of size batch by classes) were saved alongside the class predictions. However, what was supposed to be saved is the probability associated with each prediction of the model.
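
    The difference between the two behaviors can be sketched as follows. This is an illustration in NumPy with made-up values, not the code changed by the pull request (which, per the description, operates on tensors): for each row of a batch-by-classes probability array, only the probability of the predicted class is kept.

```python
import numpy as np

# Made-up model output for a batch of 3 examples and 4 classes.
probs = np.array([
    [0.10, 0.70, 0.10, 0.10],
    [0.60, 0.20, 0.10, 0.10],
    [0.25, 0.25, 0.40, 0.10],
])

# Predicted class per example.
preds = probs.argmax(axis=1)

# Probability associated with each prediction: one value per example,
# instead of the full (batch, classes) array.
pred_probs = probs[np.arange(len(probs)), preds]

print(preds)       # shape (3,)
print(pred_probs)  # shape (3,)
```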

    opened by benbo 2
  • New Release

    New Release

    Hi! Love the repo, super useful so far and really easy interface to use. Thanks for putting it together!

    I was wondering if there were plans to cut another release any time soon? We use the v1.0 tag for making sure the version is consistent across multiple builds. Noticed a few bug fixes and QOL improvements since the last release, and those would be nice to have marked at a new tag.

    opened by rsmith49 2
  • Numba 0.43 doesn't work with newer Python versions

    Numba 0.43 doesn't work with newer Python versions

    The numba package 0.43, specified here, doesn't work with Python 3.9. Upgrading the package to the latest version (0.54) resolves the issue. Traceback:

    /home/murphy/anaconda3/envs/taskbase/lib/python3.9/site-packages/llvmlite/llvmpy/__init__.py:3: UserWarning: The module `llvmlite.llvmpy` is deprecated and will be removed in the future.
      warnings.warn(
    /home/murphy/anaconda3/envs/taskbase/lib/python3.9/site-packages/llvmlite/llvmpy/core.py:8: UserWarning: The module `llvmlite.llvmpy.core` is deprecated and will be removed in the future. Equivalent functionality is provided by `llvmlite.ir`.
      warnings.warn(
    Traceback (most recent call last):
      File "<frozen importlib._bootstrap>", line 1007, in _find_and_load
      File "<frozen importlib._bootstrap>", line 986, in _find_and_load_unlocked
      File "<frozen importlib._bootstrap>", line 680, in _load_unlocked
      File "<frozen importlib._bootstrap_external>", line 850, in exec_module
      File "<frozen importlib._bootstrap>", line 228, in _call_with_frames_removed
      File "/home/murphy/anaconda3/envs/taskbase/lib/python3.9/site-packages/wrench/labelmodel/__init__.py", line 1, in <module>
        from .dawid_skene import DawidSkene
      File "/home/murphy/anaconda3/envs/taskbase/lib/python3.9/site-packages/wrench/labelmodel/dawid_skene.py", line 6, in <module>
        from numba import njit, prange
      File "/home/murphy/anaconda3/envs/taskbase/lib/python3.9/site-packages/numba/__init__.py", line 25, in <module>
        from .decorators import autojit, cfunc, generated_jit, jit, njit, stencil
      File "/home/murphy/anaconda3/envs/taskbase/lib/python3.9/site-packages/numba/decorators.py", line 12, in <module>
        from .targets import registry
      File "/home/murphy/anaconda3/envs/taskbase/lib/python3.9/site-packages/numba/targets/registry.py", line 5, in <module>
        from . import cpu
      File "/home/murphy/anaconda3/envs/taskbase/lib/python3.9/site-packages/numba/targets/cpu.py", line 9, in <module>
        from numba import _dynfunc, config
    ImportError: /home/murphy/anaconda3/envs/taskbase/lib/python3.9/site-packages/numba/_dynfunc.cpython-39-x86_64-linux-gnu.so: undefined symbol: _PyObject_GC_UNTRACK
    
    opened by susuzheng 0
  • COSINE for token classification?

    COSINE for token classification?

    Hi,

    I would like to know whether the code for the COSINE weak-supervision technique can already perform token classification. If not, what changes would I need to make to build a weakly supervised training pipeline using weakly labeled and unlabeled datasets?

    opened by KrishnanJothi 0
  • Balance in Dawid Skene is obsolete.

    Balance in Dawid Skene is obsolete.

    https://github.com/JieyuZ2/wrench/blob/6d8397956533fc6c2fe50e93fcfe0a2303bdd05f/wrench/labelmodel/dawid_skene.py#L55

    I realized this balance variable is used nowhere in this file. If it is intended, I think it should be removed from input parameters.

    opened by ch-shin 1
  • Balance sum to 1

    Balance sum to 1

    https://github.com/JieyuZ2/wrench/blob/ab717ac26a76649c8fdb946a28dffe7e682c80ba/wrench/basemodel.py#L303

    Hi, I found a minor issue: the class prior computed by this function does not sum to 1. I hope you can revise it.
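
    For reference, a class prior that is guaranteed to sum to 1 can be obtained by normalizing the (optionally smoothed) class counts. The sketch below is a generic NumPy illustration, not the function linked above; `class_prior` and its additive `smoothing` parameter are assumptions made for the example.

```python
import numpy as np

def class_prior(y, n_class, smoothing=1.0):
    """Estimate a class prior from integer labels, normalized to sum to 1.

    Hypothetical helper: `smoothing` is added to every class count so that
    classes absent from `y` still receive non-zero prior mass.
    """
    counts = np.bincount(np.asarray(y), minlength=n_class).astype(float)
    counts += smoothing
    return counts / counts.sum()

prior = class_prior([0, 0, 1, 2, 2, 2], n_class=4)
print(prior, prior.sum())
```

    Dividing by the total after smoothing, rather than by the number of examples, is what keeps the result a valid probability distribution.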

    opened by Gnaiqing 0
  • about COSINE endmodel

    about COSINE endmodel

    Hi @JieyuZ2 and @yinxiangshi, I am trying to run the COSINE end model, but I am having trouble reproducing the results in the COSINE paper. Although I used the suggested hyperparameters, I still get only a marginal benefit with wrench, and I'm not sure what is wrong. Can you share the scripts you used when evaluating COSINE? Thanks.

    opened by Gnaiqing 0
  • Recommended parameters to use for each algorithms and datasets.

    Recommended parameters to use for each algorithms and datasets.

    I've tried several combinations of algorithms and datasets, but I found it hard to get results similar to those in the paper. I suspect this is due to inappropriate parameter settings, so it would be great if this repo could provide some recommended parameters. This matters especially for the newly added algorithms, where it is hard to judge whether they produce the right results.

    opened by mrbeann 0
Releases(v1.1)
  • v1.1(Nov 9, 2021)

    What's new:

    • A bunch of new methods: WeaSEL, ImplyLoss, ASTRA, MeanTeacher, Meta-Weight-Net, Learning-to-Reweight
    • A new EndClassifierModel model which unifies all the classification backbones
    • Two new datasets on image classification
    • Support torch native amp for inference in the validation step
    • Support training on multiple GPUs via torch's DistributedDataParallel and the new parallel_fit function
    • Fixed some bugs
  • v1.0(Sep 7, 2021)

Owner
Jieyu Zhang
CS PhD