WRENCH: Weak supeRvision bENCHmark

Overview


🔧 What is it?

Wrench is a benchmark platform containing diverse weak supervision tasks. It also provides a common, easy-to-use framework for developing and evaluating your own weak supervision models within the benchmark.

For more information, check out our publications:

If you find this repository helpful, feel free to cite our publication:

@misc{zhang2021wrench,
      title={WRENCH: A Comprehensive Benchmark for Weak Supervision}, 
      author={Jieyu Zhang and Yue Yu and Yinghao Li and Yujing Wang and Yaming Yang and Mao Yang and Alexander Ratner},
      year={2021},
      eprint={2109.11377},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

🔧 What is weak supervision?

Weak Supervision is a paradigm for automated training data creation without manual annotations.

For a brief overview, please check out this blog.

To track recent advances in weak supervision, please follow this repo.
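
As a rough illustration of the idea (not Wrench code), the sketch below applies a few hand-written labeling functions to raw comments and aggregates their votes by majority. The function names, the label constants and the example comments are all hypothetical and were chosen only for this illustration.

```python
from collections import Counter

# Hypothetical label constants for a binary spam task
ABSTAIN, HAM, SPAM = -1, 0, 1


def lf_contains_link(text):
    """Vote SPAM when the comment contains a URL, otherwise abstain."""
    return SPAM if "http" in text.lower() else ABSTAIN


def lf_subscribe(text):
    """Vote SPAM when the comment asks the reader to subscribe."""
    return SPAM if "subscribe" in text.lower() else ABSTAIN


def lf_short_comment(text):
    """Vote HAM for very short comments, otherwise abstain."""
    return HAM if len(text.split()) <= 3 else ABSTAIN


LFS = [lf_contains_link, lf_subscribe, lf_short_comment]


def majority_vote(votes):
    """Return the most frequent non-abstain vote, or ABSTAIN on ties or no votes."""
    counts = Counter(v for v in votes if v != ABSTAIN)
    if not counts:
        return ABSTAIN
    top = max(counts.values())
    winners = [label for label, c in counts.items() if c == top]
    return winners[0] if len(winners) == 1 else ABSTAIN


comments = [
    "Subscribe to my channel http://example.com",
    "Great song!",
    "I keep coming back to this video every single year",
]

# One row of weak labels per comment, one column per labeling function
weak_labels = [[lf(c) for lf in LFS] for c in comments]
train_labels = [majority_vote(v) for v in weak_labels]

print(weak_labels)   # [[1, 1, -1], [-1, -1, 0], [-1, -1, -1]]
print(train_labels)  # [1, 0, -1]
```

The third comment receives no vote at all, which is why label models and end models usually have to deal with uncovered examples.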

🔧 Installation

[1] Install Anaconda. Instructions: https://www.anaconda.com/download/

[2] Clone the repository:

git clone https://github.com/JieyuZ2/wrench.git
cd wrench

[3] Create virtual environment:

conda env create -f environment.yml
source activate wrench

If this does not work, or if you want to use only a subset of Wrench's modules, check out this wiki page.

🔧 Available Datasets

The datasets can be downloaded via this link, or from the command line:

pip install gdown 
gdown https://drive.google.com/uc?id=19wMFmpoo_0ORhBzB6n16B1nRRX508AnJ
unzip datasets.zip
rm datasets.zip 

Documentation of the dataset format and usage can be found on this wiki page.

classification:

Name | Task | # class | # LF | # train | # validation | # test | data source | LF source
Census | income classification | 2 | 83 | 10083 | 5561 | 16281 | link | link
Youtube | spam classification | 2 | 10 | 1586 | 120 | 250 | link | link
SMS | spam classification | 2 | 73 | 4571 | 500 | 500 | link | link
IMDB | sentiment classification | 2 | 8 | 20000 | 2500 | 2500 | link | link
Yelp | sentiment classification | 2 | 8 | 30400 | 3800 | 3800 | link | link
AGNews | topic classification | 4 | 9 | 96000 | 12000 | 12000 | link | link
TREC | question classification | 6 | 68 | 4965 | 500 | 500 | link | link
Spouse | relation classification | 2 | 9 | 22254 | 2801 | 2701 | link | link
SemEval | relation classification | 9 | 164 | 1749 | 200 | 692 | link | link
CDR | bio relation classification | 2 | 33 | 8430 | 920 | 4673 | link | link
Chemprot | chemical relation classification | 10 | 26 | 12861 | 1607 | 1607 | link | link
Commercial | video frame classification | 2 | 4 | 64130 | 9479 | 7496 | link | link
Tennis Rally | video frame classification | 2 | 6 | 6959 | 746 | 1098 | link | link
Basketball | video frame classification | 2 | 4 | 17970 | 1064 | 1222 | link | link

sequence tagging:

Name | # class | # LF | # train | # validation | # test | data source | LF source
CoNLL-03 | 4 | 16 | 14041 | 3250 | 3453 | link | link
WikiGold | 4 | 16 | 1355 | 169 | 170 | link | link
OntoNotes 5.0 | 18 | 17 | 115812 | 5000 | 22897 | link | link
BC5CDR | 2 | 9 | 500 | 500 | 500 | link | link
NCBI-Disease | 1 | 5 | 592 | 99 | 99 | link | link
Laptop-Review | 1 | 3 | 2436 | 609 | 800 | link | link
MIT-Restaurant | 8 | 16 | 7159 | 500 | 1521 | link | link
MIT-Movies | 12 | 7 | 9241 | 500 | 2441 | link | link

The detailed documentation is coming soon.

🔧 Available Models

classification:

Model | Model Type | Reference | Link to Wrench
Majority Voting | Label Model | -- | link
Weighted Majority Voting | Label Model | -- | link
Dawid-Skene | Label Model | link | link
Data Programming | Label Model | link | link
MeTaL | Label Model | link | link
FlyingSquid | Label Model | link | link
Logistic Regression | End Model | -- | link
MLP | End Model | -- | link
BERT | End Model | link | link
COSINE | End Model | link | link
Denoise | Joint Model | link | link

sequence tagging:

Model | Model Type | Reference | Link to Wrench
Hidden Markov Model | Label Model | link | link
Conditional Hidden Markov Model | Label Model | link | link
LSTM-CNNs-CRF | End Model | link | link
BERT-CRF | End Model | link | link
LSTM-ConNet | Joint Model | link | link
BERT-ConNet | Joint Model | link | link

classification-to-sequence-tagging wrapper:

Wrench also provides a SeqLabelModelWrapper that adapts a label model built for classification tasks to sequence tagging tasks.
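
One way such an adaptation can work is to treat every token as an independent classification example: flatten the sequences, run the classification label model on the tokens, then regroup the predictions. The sketch below only illustrates that idea under this assumption. It does not use or reproduce SeqLabelModelWrapper, and `tag_sequences`, `majority_vote` and the tag ids are hypothetical names made up for the example.

```python
from collections import Counter

ABSTAIN = -1
O_TAG = 0  # hypothetical id of the "outside" tag, used when no source votes


def majority_vote(votes):
    """Token-level label model: most frequent non-abstain vote (smallest id on ties)."""
    counts = Counter(v for v in votes if v != ABSTAIN)
    if not counts:
        return O_TAG
    top = max(counts.values())
    return min(label for label, c in counts.items() if c == top)


def tag_sequences(weak_label_seqs, token_label_model):
    """Apply a classification-style label model to every token of every sequence."""
    lengths = [len(seq) for seq in weak_label_seqs]
    # Flatten: one classification example per token
    flat_votes = [votes for seq in weak_label_seqs for votes in seq]
    flat_preds = [token_label_model(votes) for votes in flat_votes]
    # Regroup the flat predictions into the original sequence lengths
    tagged, start = [], 0
    for n in lengths:
        tagged.append(flat_preds[start:start + n])
        start += n
    return tagged


# Two sentences; each token carries the votes of three supervision sources
weak_label_seqs = [
    [[1, 1, -1], [-1, -1, -1]],
    [[2, -1, 2], [0, 1, 1], [-1, 0, -1]],
]
print(tag_sequences(weak_label_seqs, majority_vote))  # [[1, 0], [2, 1, 0]]
```

A per-token treatment like this ignores dependencies between neighbouring tags, which is what the sequence-specific label models in the table above are designed to capture.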

🔧 Quick examples

🔧 Label model with parallel grid search for hyper-parameters

import logging
import numpy as np
import pprint

from wrench.dataset import load_dataset
from wrench.logging import LoggingHandler
from wrench.search import grid_search
from wrench import labelmodel 
from wrench.evaluation import AverageMeter

#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
logger = logging.getLogger(__name__)

#### Load dataset 
dataset_home = '../datasets'
data = 'youtube'
train_data, valid_data, test_data = load_dataset(dataset_home, data, extract_feature=False)


#### Specify the hyper-parameter search space for grid search
search_space = {
    'Snorkel': {
        'lr': np.logspace(-5, -1, num=5, base=10),
        'l2': np.logspace(-5, -1, num=5, base=10),
        'n_epochs': [5, 10, 50, 100, 200],
    }
}

#### Initialize label model
label_model_name = 'Snorkel'
label_model = getattr(labelmodel, label_model_name)

#### Search best hyper-parameters using validation set in parallel
n_trials = 100
n_repeats = 5
target = 'acc'
searched_paras = grid_search(label_model(), dataset_train=train_data, dataset_valid=valid_data,
                             metric=target, direction='auto', search_space=search_space[label_model_name],
                             n_repeats=n_repeats, n_trials=n_trials, parallel=True)

#### Evaluate the label model with searched hyper-parameters and average meter
meter = AverageMeter(names=[target])
for i in range(n_repeats):
    model = label_model(**searched_paras)
    history = model.fit(dataset_train=train_data, dataset_valid=valid_data)
    metric_value = model.test(test_data, target)
    # Pass the metric under its own name ('acc'), not under the literal keyword 'target'
    meter.update(**{target: metric_value})

metrics = meter.get_results()
pprint.pprint(metrics)

For detailed guidance on grid_search, please check out this wiki page.
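
To get a feel for the size of the search space used above, the following standalone sketch enumerates the same grid with the standard library. It is not Wrench's grid_search implementation. In particular, drawing a random subset when the grid is larger than n_trials is an assumption made for this example, and `enumerate_grid` is a hypothetical helper.

```python
import itertools
import random

# Same values as np.logspace(-5, -1, num=5, base=10) in the example above
search_space = {
    'lr': [10.0 ** e for e in range(-5, 0)],
    'l2': [10.0 ** e for e in range(-5, 0)],
    'n_epochs': [5, 10, 50, 100, 200],
}


def enumerate_grid(space):
    """Return every combination of hyper-parameter values as a dict."""
    names = list(space)
    value_lists = [space[name] for name in names]
    return [dict(zip(names, values)) for values in itertools.product(*value_lists)]


grid = enumerate_grid(search_space)

# Assumed policy: if the grid exceeds the trial budget, try a random subset
n_trials = 100
rng = random.Random(0)
trials = grid if len(grid) <= n_trials else rng.sample(grid, n_trials)

print(len(grid))    # 125 combinations (5 * 5 * 5)
print(len(trials))  # 100 combinations fit in the budget
```

With n_trials = 100 the budget is smaller than the 125 possible combinations, so not every grid point can be evaluated.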

🔧 Run a standard supervised learning pipeline

import logging
import torch

from wrench.dataset import load_dataset
from wrench.logging import LoggingHandler
from wrench.endmodel import MLPModel

#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
logger = logging.getLogger(__name__)

#### Load dataset 
dataset_home = '../datasets'
data = 'youtube'

#### Extract data features using pre-trained BERT model and cache it
extract_fn = 'bert'
model_name = 'bert-base-cased'
train_data, valid_data, test_data = load_dataset(dataset_home, data, extract_feature=True, extract_fn=extract_fn,
                                                 cache_name=extract_fn, model_name=model_name)


#### Train an MLP classifier
device = torch.device('cuda:0')
n_steps = 100000
batch_size = 128
test_batch_size = 1000 
patience = 200
evaluation_step = 50
target='acc'

model = MLPModel(n_steps=n_steps, batch_size=batch_size, test_batch_size=test_batch_size)
history = model.fit(dataset_train=train_data, dataset_valid=valid_data, device=device, metric=target, 
                    patience=patience, evaluation_step=evaluation_step)

#### Evaluate the trained model
metric_value = model.test(test_data, target)

🔧 Build a two-stage weak supervision pipeline

import logging
import torch

from wrench.dataset import load_dataset
from wrench.logging import LoggingHandler
from wrench.endmodel import MLPModel
from wrench.labelmodel import MajorityVoting

#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
logger = logging.getLogger(__name__)

#### Load dataset 
dataset_home = '../datasets'
data = 'youtube'

#### Extract data features using pre-trained BERT model and cache it
extract_fn = 'bert'
model_name = 'bert-base-cased'
train_data, valid_data, test_data = load_dataset(dataset_home, data, extract_feature=True, extract_fn=extract_fn,
                                                 cache_name=extract_fn, model_name=model_name)

#### Generate soft training labels via a label model
#### The weak labels provided by supervision sources are already encoded in the dataset object
label_model = MajorityVoting()
label_model.fit(train_data, valid_data)
soft_label = label_model.predict_proba(train_data)


#### Train an MLP classifier with soft labels
device = torch.device('cuda:0')
n_steps = 100000
batch_size = 128
test_batch_size = 1000 
patience = 200
evaluation_step = 50
target='acc'

model = MLPModel(n_steps=n_steps, batch_size=batch_size, test_batch_size=test_batch_size)
history = model.fit(dataset_train=train_data, dataset_valid=valid_data, y_train=soft_label, 
                    device=device, metric=target, patience=patience, evaluation_step=evaluation_step)

#### Evaluate the trained model
metric_value = model.test(test_data, target)

#### We can also train an MLP classifier with hard labels
from snorkel.utils import probs_to_preds
hard_label = probs_to_preds(soft_label)
model = MLPModel(n_steps=n_steps, batch_size=batch_size, test_batch_size=test_batch_size)
model.fit(dataset_train=train_data, dataset_valid=valid_data, y_train=hard_label, 
          device=device, metric=target, patience=patience, evaluation_step=evaluation_step)
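
The difference between soft and hard labels in the example above can be shown with plain Python. This is only a sketch of the conversion: it takes the most probable class per row and resolves ties by picking the lowest class index, which may differ from how snorkel's probs_to_preds breaks ties. `probs_to_hard_labels` is a hypothetical helper name.

```python
def probs_to_hard_labels(probs):
    """Pick the most probable class per example (lowest index wins ties)."""
    return [max(range(len(row)), key=row.__getitem__) for row in probs]


# Soft labels: one probability distribution over the classes per example
soft_label = [
    [0.9, 0.1],
    [0.2, 0.8],
    [0.5, 0.5],  # tie between the two classes
]

hard_label = probs_to_hard_labels(soft_label)
print(hard_label)  # [0, 1, 0]
```

Hard labels discard the label model's uncertainty, as the tied third example shows, whereas soft labels pass it on to the end model.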

🔧 Procedural labeling function generator

import logging
import torch

from wrench.dataset import load_dataset
from wrench.logging import LoggingHandler
from wrench.synthetic import ConditionalIndependentGenerator, NGramLFGenerator
from wrench.labelmodel import FlyingSquid

#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
logger = logging.getLogger(__name__)


#### Generate synthetic dataset
generator = ConditionalIndependentGenerator(
    n_class=2,
    n_lfs=10,
    alpha=0.75, # mean accuracy
    beta=0.1, # mean propensity
    alpha_radius=0.2, # radius of accuracy
    beta_radius=0.1 # radius of propensity
)
train_data = generator.generate_split('train', 10000)
valid_data = generator.generate_split('valid', 1000)
test_data = generator.generate_split('test', 1000)

#### Evaluate label model on synthetic dataset
label_model = FlyingSquid()
label_model.fit(dataset_train=train_data, dataset_valid=valid_data)
target_value = label_model.test(test_data, metric_fn='auc')

#### Load dataset 
dataset_home = '../datasets'
data = 'youtube'

#### Load real-world dataset
train_data, valid_data, test_data = load_dataset(dataset_home, data, extract_feature=False)

#### Generate procedural labeling functions
generator = NGramLFGenerator(dataset=train_data, min_acc_gain=0.1, min_support=0.01, ngram_range=(1, 2))
applier = generator.generate(mode='correlated', n_lfs=10)
L_test = applier.apply(test_data)
L_train = applier.apply(train_data)


#### Evaluate label model on real-world dataset with semi-synthetic labeling functions
label_model = FlyingSquid()
label_model.fit(dataset_train=L_train, dataset_valid=valid_data)
target_value = label_model.test(L_test, metric_fn='auc')
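
The parameters of the synthetic generator above (mean accuracy, mean propensity and their radii) can be read as follows: each labeling function gets its own accuracy and propensity, and then votes independently of the others given the true label. The sketch below illustrates that reading. It is not the ConditionalIndependentGenerator implementation, and drawing the per-function values uniformly within the radius is an assumption, as are the helper names.

```python
import random

ABSTAIN = -1


def sample_lf_params(n_lfs, alpha, beta, alpha_radius, beta_radius, rng):
    """Draw an (accuracy, propensity) pair per labeling function.

    Assumption: values are uniform within +/- radius of the mean.
    """
    return [
        (rng.uniform(alpha - alpha_radius, alpha + alpha_radius),
         rng.uniform(beta - beta_radius, beta + beta_radius))
        for _ in range(n_lfs)
    ]


def generate_split(n_data, n_class, lf_params, rng):
    """Sample true labels and conditionally independent weak labels."""
    labels, weak_labels = [], []
    for _ in range(n_data):
        y = rng.randrange(n_class)
        votes = []
        for accuracy, propensity in lf_params:
            if rng.random() >= propensity:
                votes.append(ABSTAIN)  # the labeling function does not fire
            elif rng.random() < accuracy:
                votes.append(y)        # fires and is correct
            else:
                # fires and is wrong: pick one of the other classes
                votes.append(rng.choice([c for c in range(n_class) if c != y]))
        labels.append(y)
        weak_labels.append(votes)
    return labels, weak_labels


rng = random.Random(0)
lf_params = sample_lf_params(n_lfs=10, alpha=0.75, beta=0.1,
                             alpha_radius=0.2, beta_radius=0.1, rng=rng)
labels, weak_labels = generate_split(n_data=1000, n_class=2,
                                     lf_params=lf_params, rng=rng)

covered = sum(any(v != ABSTAIN for v in votes) for votes in weak_labels)
print(f"{covered} of {len(labels)} examples received at least one vote")
```

With a mean propensity of 0.1, each labeling function fires on roughly one example in ten, so a noticeable share of examples receive no vote at all.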

🔧 Contact

Contact person: Jieyu Zhang, [email protected]

Don't hesitate to send us an e-mail if you have any questions.

We're also open to any collaboration!

🔧 Contributing Dataset and Model

We sincerely welcome any contribution to the datasets or models!

Comments
  • ModuleNotFoundError: No module named 'tokenizations'

    Hi, I faced some problems when trying to install the library. I tried to use pip install ws-benchmark==1.1.2rc0 as suggested in the document, the installation was successful but when I run the code I faced the error ModuleNotFoundError: No module named 'tokenizations'. Then I tried to clone the repository and create the environment using conda env create -f environment.yml, but the installation failed due to the following error FileNotFoundError: [Errno 2] No such file or directory: '/home/naiqing/miniconda3/envs/wrench/lib/python3.6/site-packages/huggingface_hub-0.0.16-py3.8.egg'. Do you have ideas on what might cause the problem and how can I fix it?

    opened by Gnaiqing 12
  • Is there a limitation of using dataset for different algs?

    Firstly, thank you for building this awesome benchmark. While I try the example with different datasets (e.g., I try astra with youtube dataset), I got some errors like this,

        loss = cross_entropy_with_probs(predict_l, batch['labels'].to(device))
    KeyError: 'labels'
    

    Can this be fixed?

    opened by mrbeann 8
  • Python Package Installation Fails

    Installing ws-benchmark python package fails due to dependency conflict (see stack trace below).

    Tested on system:

    • OS: ubuntu
    • Python: 3.8.13
    • Clean VE

    Command to replicate:

    • pip install ws-benchmark

    Stack Trace:

    ERROR: Cannot install ws-benchmark and ws-benchmark==1.1.1 because these package versions have conflicting dependencies.
    
    The conflict is caused by:
        ws-benchmark 1.1.1 depends on networkx==2.7
        snorkel 0.9.7 depends on networkx<2.4 and >=2.2
    
    To fix this you could try to:
    1. loosen the range of package versions you've specified
    2. remove package versions to allow pip attempt to solve the dependency conflict
    
    ERROR: ResolutionImpossible: for help visit https://pip.pypa.io/en/latest/topics/dependency-resolution/#dealing-with-dependency-conflicts
    
    opened by bradleyfowler123 4
  • Using Multiple GPUs

    Hi,

    Is it possible to use multiple GPUs for the experiments, or will it be in future releases? It would be a nice feature if it is not possible right now.

    Best regards.

    opened by tolgayan 4
  • Running scripts

    Hi, I am trying to run some models on the IMDB dataset.

    MLP:

    import logging
    import torch
    import numpy as np
    from wrench.dataset import load_dataset
    from wrench.labelmodel import Snorkel
    from wrench.logging import LoggingHandler
    from wrench.search import grid_search
    from wrench.endmodel import EndClassifierModel
    
    #### Just some code to print debug information to stdout
    logging.basicConfig(format='%(asctime)s - %(message)s',
                        datefmt='%Y-%m-%d %H:%M:%S',
                        level=logging.INFO,
                        handlers=[LoggingHandler()])
    
    logger = logging.getLogger(__name__)
    
    device = torch.device('cuda')
    
    if __name__ == '__main__':
        #### Load dataset
        dataset_path = '../datasets/'
        data = "imdb"
        bert_model_name = "bert-base-cased"
        train_data, valid_data, test_data = load_dataset(
            dataset_path,
            data,
            extract_feature=True,
            extract_fn='bert',  # extract bert embedding
            model_name=bert_model_name,
            cache_name='bert',
            dataset_type="TextDataset"
        )
    
        #### Run label model: Snorkel
        label_model = Snorkel(
            lr=0.005,
            l2=0,
            n_epochs=200,
            seed=123
        )
        label_model.fit(
            dataset_train=train_data,
            dataset_valid=valid_data
        )
    
        acc = label_model.test(test_data, 'acc')
        logger.info(f'label model test acc: {acc}')
    
        #### Filter out uncovered training data
        aggregated_hard_labels = label_model.predict(train_data)
        aggregated_soft_labels = label_model.predict_proba(train_data)
    
        #### Search Space
        search_space = {
            'optimizer_lr': np.logspace(-5, -1, num=5, base=10),
            'optimizer_weight_decay': np.logspace(-5, -1, num=5, base=10),
        }
    
        #### Initialize the model: MLP
        model = EndClassifierModel(
            batch_size=8,
            real_batch_size=8,
            test_batch_size=8,
            backbone='MLP',
            optimizer='Adam'
        )
    
        #### Search best hyper-parameters using validation set in parallel
        n_trials = 20
        n_repeats = 1
        searched_paras = grid_search(
            model,
            dataset_train=train_data,
            y_train=aggregated_soft_labels,
            dataset_valid=valid_data,
            metric='acc',
            direction='auto',
            search_space=search_space,
            n_repeats=n_repeats,
            n_trials=n_trials,
            parallel=True,
            device=device,
        )
    
    
        #### Run end model: MLP
        model = EndClassifierModel(
            batch_size=8,
            real_batch_size=8,
            test_batch_size=8,
            backbone='MLP',
            optimizer='Adam',
            **searched_paras
        )
        model.fit(
            dataset_train=train_data,
            y_train=aggregated_soft_labels,
            dataset_valid=valid_data,
            metric='acc',
            device=device
        )
    
        logger.info(model.predict(test_data).tolist())
    
        acc = model.test(test_data, 'acc')
        logger.info(f'end model (MLP) test acc: {acc}')
    
    

    for which I am getting the following output:

    100%|██████████| 20000/20000 [00:00<00:00, 902651.16it/s]
    100%|██████████| 2500/2500 [00:00<00:00, 852639.45it/s]
    100%|██████████| 2500/2500 [00:00<00:00, 829503.99it/s]
    Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight']
    - This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
    - This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
    100%|██████████| 20000/20000 [1:42:45<00:00,  3.24it/s]
    Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight']
    - This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
    - This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
    100%|██████████| 2500/2500 [13:24<00:00,  3.11it/s]
    Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight']
    - This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
    - This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
    100%|██████████| 2500/2500 [13:50<00:00,  3.01it/s]
    [I 2021-10-23 22:24:36,807] A new study created in memory with name: no-name-9e4ad09c-ea4a-4ee8-80c2-7633429e4038
    huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
    To disable this warning, you can either:
            - Avoid using `tokenizers` before the fork if possible
            - Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
    2021-10-23 20:14:19 - loading data from ../datasets/imdb/train.json
    2021-10-23 20:14:19 - loading data from ../datasets/imdb/valid.json
    2021-10-23 20:14:19 - loading data from ../datasets/imdb/test.json
    2021-10-23 21:57:10 - saving features into ../datasets/imdb/train_bert.pkl
    2021-10-23 22:10:40 - saving features into ../datasets/imdb/valid_bert.pkl
    2021-10-23 22:24:36 - saving features into ../datasets/imdb/test_bert.pkl
    2021-10-23 22:24:36 - label model test acc: 0.716
    huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
    To disable this warning, you can either:
            - Avoid using `tokenizers` before the fork if possible
            - Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
    100%|██████████| 1/1 [00:37<00:00, 37.48s/it]
    [I 2021-10-23 22:25:14,563] Trial 0 finished with value: 0.5012 and parameters: {'optimizer_lr': 0.001, 'optimizer_weight_decay': 0.0001}. Best is trial 0 with value: 0.5012.
    100%|██████████| 1/1 [00:23<00:00, 23.70s/it]
    [I 2021-10-23 22:25:38,448] Trial 1 finished with value: 0.496 and parameters: {'optimizer_lr': 1e-05, 'optimizer_weight_decay': 0.1}. Best is trial 0 with value: 0.5012.
    100%|██████████| 1/1 [00:14<00:00, 14.53s/it]
    [I 2021-10-23 22:25:53,171] Trial 2 finished with value: 0.5004 and parameters: {'optimizer_lr': 0.1, 'optimizer_weight_decay': 0.001}. Best is trial 0 with value: 0.5012.
    100%|██████████| 1/1 [00:43<00:00, 43.73s/it]
    [I 2021-10-23 22:26:37,071] Trial 3 finished with value: 0.5088 and parameters: {'optimizer_lr': 0.001, 'optimizer_weight_decay': 0.001}. Best is trial 3 with value: 0.5088.
    100%|██████████| 1/1 [00:18<00:00, 18.85s/it]
    [I 2021-10-23 22:26:56,161] Trial 4 finished with value: 0.488 and parameters: {'optimizer_lr': 0.001, 'optimizer_weight_decay': 0.1}. Best is trial 3 with value: 0.5088.
    100%|██████████| 1/1 [00:38<00:00, 38.81s/it]
    [I 2021-10-23 22:27:35,214] Trial 5 finished with value: 0.4948 and parameters: {'optimizer_lr': 0.0001, 'optimizer_weight_decay': 0.1}. Best is trial 3 with value: 0.5088.
    100%|██████████| 1/1 [00:38<00:00, 38.15s/it]
    [I 2021-10-23 22:28:13,614] Trial 6 finished with value: 0.5024 and parameters: {'optimizer_lr': 0.01, 'optimizer_weight_decay': 0.01}. Best is trial 3 with value: 0.5088.
    100%|██████████| 1/1 [00:15<00:00, 15.47s/it]
    [I 2021-10-23 22:28:29,335] Trial 7 finished with value: 0.4996 and parameters: {'optimizer_lr': 0.1, 'optimizer_weight_decay': 1e-05}. Best is trial 3 with value: 0.5088.
    100%|██████████| 1/1 [00:22<00:00, 22.49s/it]
    [I 2021-10-23 22:28:52,093] Trial 8 finished with value: 0.5008 and parameters: {'optimizer_lr': 0.0001, 'optimizer_weight_decay': 1e-05}. Best is trial 3 with value: 0.5088.
    100%|██████████| 1/1 [00:40<00:00, 40.25s/it]
    [I 2021-10-23 22:29:32,594] Trial 9 finished with value: 0.5008 and parameters: {'optimizer_lr': 0.0001, 'optimizer_weight_decay': 0.0001}. Best is trial 3 with value: 0.5088.
    100%|██████████| 1/1 [00:39<00:00, 39.06s/it]
    [I 2021-10-23 22:30:11,902] Trial 10 finished with value: 0.5116 and parameters: {'optimizer_lr': 1e-05, 'optimizer_weight_decay': 1e-05}. Best is trial 10 with value: 0.5116.
    100%|██████████| 1/1 [00:43<00:00, 43.46s/it]
    [I 2021-10-23 22:30:55,531] Trial 11 finished with value: 0.4912 and parameters: {'optimizer_lr': 0.001, 'optimizer_weight_decay': 1e-05}. Best is trial 10 with value: 0.5116.
    100%|██████████| 1/1 [00:23<00:00, 23.41s/it]
    [I 2021-10-23 22:31:19,095] Trial 12 finished with value: 0.4956 and parameters: {'optimizer_lr': 0.001, 'optimizer_weight_decay': 0.01}. Best is trial 10 with value: 0.5116.
    100%|██████████| 1/1 [00:22<00:00, 22.12s/it]
    [I 2021-10-23 22:31:41,374] Trial 13 finished with value: 0.492 and parameters: {'optimizer_lr': 1e-05, 'optimizer_weight_decay': 0.01}. Best is trial 10 with value: 0.5116.
    100%|██████████| 1/1 [00:15<00:00, 15.78s/it]
    [I 2021-10-23 22:31:57,283] Trial 14 finished with value: 0.5044 and parameters: {'optimizer_lr': 0.1, 'optimizer_weight_decay': 0.0001}. Best is trial 10 with value: 0.5116.
    100%|██████████| 1/1 [00:37<00:00, 37.28s/it]
    [I 2021-10-23 22:32:34,728] Trial 15 finished with value: 0.488 and parameters: {'optimizer_lr': 1e-05, 'optimizer_weight_decay': 0.001}. Best is trial 10 with value: 0.5116.
    100%|██████████| 1/1 [00:16<00:00, 16.04s/it]
    [I 2021-10-23 22:32:50,934] Trial 16 finished with value: 0.4924 and parameters: {'optimizer_lr': 0.0001, 'optimizer_weight_decay': 0.001}. Best is trial 10 with value: 0.5116.
    100%|██████████| 1/1 [00:19<00:00, 19.65s/it]
    [I 2021-10-23 22:33:10,753] Trial 17 finished with value: 0.5156 and parameters: {'optimizer_lr': 0.1, 'optimizer_weight_decay': 0.1}. Best is trial 17 with value: 0.5156.
    100%|██████████| 1/1 [00:15<00:00, 15.41s/it]
    [I 2021-10-23 22:33:26,345] Trial 18 finished with value: 0.5068 and parameters: {'optimizer_lr': 0.01, 'optimizer_weight_decay': 0.001}. Best is trial 17 with value: 0.5156.
    100%|██████████| 1/1 [00:16<00:00, 16.75s/it]
    [I 2021-10-23 22:33:43,222] Trial 19 finished with value: 0.498 and parameters: {'optimizer_lr': 0.0001, 'optimizer_weight_decay': 0.01}. Best is trial 17 with value: 0.5156.
    [TRAIN]:  15%|█████▌                               | 1499/10000 [00:21<02:04, 68.19steps/s, loss=4.02, val_acc=0.5, best_val_acc=0.508, best_step=500]
    2021-10-23 22:33:43 - [END: BEST VAL / PARAMS] Best value: 0.5156, Best paras: {'optimizer_lr': 0.1, 'optimizer_weight_decay': 0.1}
    2021-10-23 22:33:43 - 
    ==========[hyper parameters]==========
    {
        "batch_size": 8,
        "real_batch_size": 8,
        "test_batch_size": 8,
        "n_steps": 10000,
        "grad_norm": -1,
        "use_lr_scheduler": false,
        "binary_mode": false
    }
    ==========[optimizer config]==========
    {
        "name": "Adam",
        "paras": {
            "lr": 0.1,
            "weight_decay": 0.1
        }
    }
    ==========[backbone config]==========
    {
        "name": "MLP",
        "paras": {
            "hidden_size": 100,
            "dropout": 0.0
        }
    }
    
    2021-10-23 22:34:09 - [INFO] early stop @ step 1500!
    2021-10-23 22:34:09 - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
    2021-10-23 22:34:09 - end model (MLP) test acc: 0.5004
    

    COSINE:

    import logging
    import torch
    from wrench.dataset import load_dataset
    from wrench.logging import LoggingHandler
    from wrench.labelmodel import Snorkel
    from wrench.endmodel import Cosine
    
    #### Just some code to print debug information to stdout
    logging.basicConfig(format='%(asctime)s - %(message)s',
                        datefmt='%Y-%m-%d %H:%M:%S',
                        level=logging.INFO,
                        handlers=[LoggingHandler()])
    
    logger = logging.getLogger(__name__)
    
    device = torch.device('cuda')
    
    if __name__ == '__main__':
        #### Load dataset
        dataset_path = '../datasets/'
        data = "imdb"
        bert_model_name = "bert-base-cased"
        train_data, valid_data, test_data = load_dataset(
            dataset_path,
            data,
            extract_feature=True,
            extract_fn='bert',  # extract bert embedding
            model_name=bert_model_name,
            cache_name='bert',
            dataset_type="TextDataset"
        )
    
        #### Run label model: Snorkel
        label_model = Snorkel(
            lr=0.005,
            l2=0,
            n_epochs=200,
            seed=123
        )
        label_model.fit(
            dataset_train=train_data,
            dataset_valid=valid_data
        )
    
        acc = label_model.test(test_data, 'acc')
        logger.info(f'label model test acc: {acc}')
    
        #### Filter out uncovered training data
        aggregated_hard_labels = label_model.predict(train_data)
        aggregated_soft_labels = label_model.predict_proba(train_data)
    
    
        # COSINE
        model = Cosine(
            teacher_update=100,
            margin=1.0,
            thresh=0.6,
            lr=1e-5,
            mu=1.0,
            lamda=0.05,
            backbone='BERT',
            backbone_model_name=bert_model_name,
            batch_size=8,
            real_batch_size=8,
            test_batch_size=8,
        )
    
        model.fit(dataset_train=train_data,
                  dataset_valid=valid_data,
                  y_train=aggregated_hard_labels,
                  evaluation_step=10,
                  metric='acc',
                  patience=50,
                  device=device)
    
        acc = model.test(test_data, 'acc')
    
        logger.info(model.predict(test_data))
    
        logger.info(f'end model (COSINE) test acc: {acc}')
    

    for which I am getting the following output:

    100%|██████████| 20000/20000 [00:00<00:00, 899119.81it/s]
    100%|██████████| 2500/2500 [00:00<00:00, 423667.07it/s]
    100%|██████████| 2500/2500 [00:00<00:00, 802645.44it/s]
    Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias']
    - This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
    - This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
    100%|██████████| 20000/20000 [1:47:44<00:00,  3.09it/s]
    Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias']
    - This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
    - This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
    100%|██████████| 2500/2500 [14:22<00:00,  2.90it/s]
    Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias']
    - This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
    - This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
    100%|██████████| 2500/2500 [13:33<00:00,  3.07it/s]
    Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias']
    - This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
    - This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
    [TRAIN] COSINE pretrain stage:   5%|▊               | 509/10000 [21:19<6:37:40,  2.51s/steps, loss=0.605, val_acc=0.5, best_val_acc=0.5, best_step=10]
    [TRAIN] COSINE distillation stage:   0%|                                                                                 | 0/10000 [03:05<?, ?steps/s]
    2021-10-23 20:14:13 - loading data from ../datasets/imdb/train.json
    2021-10-23 20:14:13 - loading data from ../datasets/imdb/valid.json
    2021-10-23 20:14:14 - loading data from ../datasets/imdb/test.json
    2021-10-23 22:02:05 - saving features into ../datasets/imdb/train_bert.pkl
    2021-10-23 22:16:34 - saving features into ../datasets/imdb/valid_bert.pkl
    2021-10-23 22:30:14 - saving features into ../datasets/imdb/test_bert.pkl
    2021-10-23 22:30:14 - label model test acc: 0.716
    2021-10-23 22:30:17 - 
    ==========[hyper parameters]==========
    {
        "teacher_update": 100,
        "margin": 1.0,
        "mu": 1.0,
        "thresh": 0.6,
        "lamda": 0.05,
        "batch_size": 8,
        "real_batch_size": 8,
        "test_batch_size": 8,
        "n_steps": 10000,
        "grad_norm": -1,
        "use_lr_scheduler": false,
        "binary_mode": false
    }
    ==========[optimizer config]==========
    {
        "name": "Adam",
        "paras": {
            "lr": 0.001,
            "weight_decay": 0.0
        }
    }
    ==========[backbone config]==========
    {
        "name": "BERT",
        "paras": {
            "model_name": "bert-base-cased",
            "max_tokens": 512,
            "fine_tune_layers": -1
        }
    }
    ==========[label model_config config]==========
    {
        "name": "MajorityVoting",
        "paras": {}
    }
    
    2021-10-23 22:51:52 - [INFO] early stop @ step 510!
    2021-10-23 22:55:20 - early stop because all the data are filtered!
    2021-10-23 22:56:06 - [1 1 1 ... 1 1 1]
    2021-10-23 22:56:06 - end model (COSINE) test acc: 0.5
    

    As can be seen, in both runs the label model reaches a test accuracy of 0.716, but the end models do not: the MLP reaches 0.5004 and COSINE reaches 0.5.

    Am I doing something completely wrong? Could you please tell me whether I am running the code correctly, or whether there is an issue with the hyperparameters?

    I would greatly appreciate it if you could give me some advice. I would also be very glad if you could include an example script for running the COSINE model.

    Thanks for the benchmark, I really appreciate it!
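A test accuracy of about 0.5 on a two-class task is what a model that predicts almost a single class would score on a balanced test set, and the COSINE log above prints `[1 1 1 ... 1 1 1]`. The following is a minimal sketch, not part of Wrench, for checking whether a prediction list has collapsed to one class. The helper name `prediction_summary` and the sample data are hypothetical.

```python
from collections import Counter

def prediction_summary(preds):
    """Return ({class: fraction}, share of the most frequent class)."""
    counts = Counter(int(p) for p in preds)
    total = sum(counts.values())
    if total == 0:
        raise ValueError("no predictions given")
    fractions = {label: n / total for label, n in counts.items()}
    return fractions, max(fractions.values())

# Hypothetical predictions shaped like the COSINE output: every example gets class 1.
preds = [1] * 2500
fractions, top_share = prediction_summary(preds)
print(fractions, top_share)
```

A `top_share` close to 1.0 suggests the end model has degenerated to a constant prediction, which points at training (learning rate, label quality, early stopping) rather than at the evaluation code.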

    opened by viheheb757 4
  • Reproducing Table 11 for classification

    Thanks for this package @JieyuZ2 -- do you happen to have an orchestration script for reproducing Table 11 (and therefore Table 3) in the Wrench paper?

    opened by pmangg 3
  • No module named 'wrench.classification.self_training'

    Hi, I am trying to run run_denoise.py but I am getting the following error:

    Traceback (most recent call last):
      File "run_denoise.py", line 5, in <module>
        from wrench.classification import Denoise
      File "/gpfs/space/home/wrench/wrench/classification/__init__.py", line 4, in <module>
        from .self_training import LDSelfTrain, DDSelfTrain
    ModuleNotFoundError: No module named 'wrench.classification.self_training'
    

    Could you please add LDSelfTrain and DDSelfTrain classes?

    opened by andreaspung 3
  • Questions on the use of ground-truth labels for validation

    Thanks for putting up the benchmark! This is really great work! It seems that both the label model and the end model use the ground-truth labels for validation. For example, the base label model uses the ground-truth labels of the validation set to calculate the class balance weights: https://github.com/JieyuZ2/wrench/blob/544119e781d010797cf153307aa1090361c99522/wrench/basemodel.py#L286 I have a few questions regarding this:

    (1) A valid baseline for the label models would be a classifier trained on the validation set, with the weak labels of the LFs as features and the ground-truth labels as the target. Given that the validation set for most datasets is actually not small, I feel this trained model might be a pretty strong baseline compared to the other, unsupervised label models.

    (2) Similar to how we combine the weak labels on the training set to get aggregated labels, we could also get aggregated labels for the validation set. The aggregated labels, instead of the ground-truth labels of the validation set, could then be used to validate the end model. Wouldn't this be a more realistic setting, especially considering that the purpose of weak supervision is to replace human labeling with programmatic labeling?
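The second suggestion can be sketched in a few lines. This is an illustrative example only, using a plain majority vote as the aggregator and assuming `-1` marks an abstaining labeling function. The function names and the toy data are hypothetical and do not come from Wrench.

```python
from collections import Counter

ABSTAIN = -1  # assumed convention for "labeling function did not fire"

def majority_vote(weak_labels, abstain=ABSTAIN):
    """Aggregate one example's weak labels; abstain if no LF fired or on a tie."""
    votes = Counter(label for label in weak_labels if label != abstain)
    if not votes:
        return abstain
    ranked = votes.most_common(2)
    if len(ranked) == 2 and ranked[0][1] == ranked[1][1]:
        return abstain
    return ranked[0][0]

def accuracy_against(preds, targets, abstain=ABSTAIN):
    """Accuracy measured only on examples whose target is not an abstain."""
    pairs = [(p, t) for p, t in zip(preds, targets) if t != abstain]
    if not pairs:
        return float("nan")
    return sum(p == t for p, t in pairs) / len(pairs)

# Toy validation set: four examples, three labeling functions each.
valid_weak_labels = [[1, 1, -1], [0, -1, -1], [-1, -1, -1], [0, 1, -1]]
valid_targets = [majority_vote(row) for row in valid_weak_labels]
end_model_preds = [1, 1, 0, 0]

print(valid_targets)                                     # [1, 0, -1, -1]
print(accuracy_against(end_model_preds, valid_targets))  # 0.5
```

Examples on which the aggregator abstains are skipped here. Whether to skip them or break ties some other way is a design choice that changes the validation signal.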

    I appreciate any explanations. Thanks!

    opened by wurenzhi 2
  • Clarifying dataset download links

    Great work on the benchmark!

    Under the "Available Datasets" section on the main README, you provide 2 links for downloading the WRENCH datasets:

    One point of confusion is that expanded datasets found on the Google drive link are different than the direct download zip file. For example, classification/youtube/train.json on Google drive has 1686 instances while the zip file contains 1586 for the same file, matching the statistics reported on the README. Can you make the correct file download unambiguous in the documentation?
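A quick way to compare the two downloads is to count the entries in the same file from each. The sketch below assumes only that `train.json` holds one top-level entry per instance (dict key or list item). The record layout in the demo is made up, and `count_instances` is a hypothetical helper.

```python
import json
import os
import tempfile

def count_instances(path):
    """Count top-level entries in a JSON file (dict keys or list items)."""
    with open(path, encoding="utf-8") as f:
        return len(json.load(f))

# Self-contained demo on a throwaway file; the records here are placeholders.
with tempfile.TemporaryDirectory() as tmp:
    demo_path = os.path.join(tmp, "train.json")
    with open(demo_path, "w", encoding="utf-8") as f:
        json.dump({"0": {"label": 1}, "1": {"label": 0}, "2": {"label": 1}}, f)
    demo_count = count_instances(demo_path)

print(demo_count)  # 3
```

Running `count_instances` on `classification/youtube/train.json` from each source would show which copy matches the 1586 training instances listed in the README.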

    opened by jason-fries 2
  • Fix retained probabilities

    This pull request removes a bug which led to the wrong probabilities being stored along with the predictions of each labeling function.

    Previously, all probabilities (2d tensor of size batch by classes) were saved alongside the class predictions. However, what was supposed to be saved is the probability associated with each prediction of the model.
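The difference between the two behaviours can be illustrated without any framework. This is a sketch of the idea only, not the code from the pull request. It uses plain Python lists in place of tensors, and `predicted_with_confidence` is a hypothetical name.

```python
def predicted_with_confidence(probs):
    """probs is a batch-by-classes list of probability rows.

    Returns (predictions, confidences): for each row, the index of the most
    probable class and the single probability attached to that class.
    """
    preds, confs = [], []
    for row in probs:
        best = max(range(len(row)), key=row.__getitem__)
        preds.append(best)
        confs.append(row[best])
    return preds, confs

probs = [[0.9, 0.1], [0.2, 0.8], [0.5, 0.5]]
preds, confs = predicted_with_confidence(probs)
print(preds)  # [0, 1, 0]
print(confs)  # [0.9, 0.8, 0.5]
```

Storing `confs` keeps one number per prediction, whereas storing `probs` keeps the whole batch-by-classes table. With a PyTorch tensor, `probs.max(dim=1)` yields the same pair as `(values, indices)`.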

    opened by benbo 2
  • New Release

    Hi! Love the repo, super useful so far and really easy interface to use. Thanks for putting it together!

    I was wondering if there were plans to cut another release any time soon? We use the v1.0 tag for making sure the version is consistent across multiple builds. Noticed a few bug fixes and QOL improvements since the last release, and those would be nice to have marked at a new tag.

    opened by rsmith49 2
  • Numba 0.43 doesn't work with newer Python versions

    The numba package 0.43, specified here, doesn't work with Python 3.9. Upgrading the package to the latest version (0.54) resolves the issue. Traceback:

    /home/murphy/anaconda3/envs/taskbase/lib/python3.9/site-packages/llvmlite/llvmpy/__init__.py:3: UserWarning: The module `llvmlite.llvmpy` is deprecated and will be removed in the future.
      warnings.warn(
    /home/murphy/anaconda3/envs/taskbase/lib/python3.9/site-packages/llvmlite/llvmpy/core.py:8: UserWarning: The module `llvmlite.llvmpy.core` is deprecated and will be removed in the future. Equivalent functionality is provided by `llvmlite.ir`.
      warnings.warn(
    Traceback (most recent call last):
      File "<frozen importlib._bootstrap>", line 1007, in _find_and_load
      File "<frozen importlib._bootstrap>", line 986, in _find_and_load_unlocked
      File "<frozen importlib._bootstrap>", line 680, in _load_unlocked
      File "<frozen importlib._bootstrap_external>", line 850, in exec_module
      File "<frozen importlib._bootstrap>", line 228, in _call_with_frames_removed
      File "/home/murphy/anaconda3/envs/taskbase/lib/python3.9/site-packages/wrench/labelmodel/__init__.py", line 1, in <module>
        from .dawid_skene import DawidSkene
      File "/home/murphy/anaconda3/envs/taskbase/lib/python3.9/site-packages/wrench/labelmodel/dawid_skene.py", line 6, in <module>
        from numba import njit, prange
      File "/home/murphy/anaconda3/envs/taskbase/lib/python3.9/site-packages/numba/__init__.py", line 25, in <module>
        from .decorators import autojit, cfunc, generated_jit, jit, njit, stencil
      File "/home/murphy/anaconda3/envs/taskbase/lib/python3.9/site-packages/numba/decorators.py", line 12, in <module>
        from .targets import registry
      File "/home/murphy/anaconda3/envs/taskbase/lib/python3.9/site-packages/numba/targets/registry.py", line 5, in <module>
        from . import cpu
      File "/home/murphy/anaconda3/envs/taskbase/lib/python3.9/site-packages/numba/targets/cpu.py", line 9, in <module>
        from numba import _dynfunc, config
    ImportError: /home/murphy/anaconda3/envs/taskbase/lib/python3.9/site-packages/numba/_dynfunc.cpython-39-x86_64-linux-gnu.so: undefined symbol: _PyObject_GC_UNTRACK
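To fail early with a readable message instead of the import error above, the installed Numba version can be checked before importing Wrench. This is a minimal sketch. The 0.54 threshold is taken from this report, not from Numba's compatibility table, and the helper names are hypothetical.

```python
from importlib import metadata

def version_tuple(version):
    """Turn '0.54.1' into (0, 54, 1), stopping at the first non-numeric part."""
    parts = []
    for piece in version.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)

def numba_is_recent_enough(minimum=(0, 54)):
    """Return True/False for the installed numba, or None if it is not installed."""
    try:
        installed = metadata.version("numba")
    except metadata.PackageNotFoundError:
        return None
    return version_tuple(installed) >= minimum

print(numba_is_recent_enough())
```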
    
    opened by susuzheng 0
  • COSINE for token classification?

    Hi,

    I would like to know whether the code for the COSINE weak-supervision technique can already perform token classification. If not, what changes would I need to make to build a weakly supervised training pipeline that uses weakly labeled and unlabeled datasets?

    opened by KrishnanJothi 0
  • Balance in Dawid Skene is obsolete.

    https://github.com/JieyuZ2/wrench/blob/6d8397956533fc6c2fe50e93fcfe0a2303bdd05f/wrench/labelmodel/dawid_skene.py#L55

    I realized that this balance variable is not used anywhere in this file. If that is intended, I think it should be removed from the input parameters.

    opened by ch-shin 1
  • Balance sum to 1

    https://github.com/JieyuZ2/wrench/blob/ab717ac26a76649c8fdb946a28dffe7e682c80ba/wrench/basemodel.py#L303

    Hi, I found a minor issue: the class prior computed by this function does not sum to 1. I hope you can fix it.
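For reference, a class prior estimated from label counts sums to 1 when the counts are divided by their own total. The sketch below shows that normalization on toy labels. It is not a reconstruction of the linked function, and `class_balance` is a hypothetical name.

```python
from collections import Counter

def class_balance(labels, n_class):
    """Estimate the class prior from integer labels in range(n_class).

    Labels outside that range (for example an abstain marker) are ignored,
    and the result is normalized by the number of labels actually counted.
    """
    counts = Counter(labels)
    raw = [counts.get(c, 0) for c in range(n_class)]
    total = sum(raw)
    if total == 0:
        raise ValueError("no labels in range(n_class)")
    return [r / total for r in raw]

balance = class_balance([0, 0, 1, 2, 2, 2, -1], n_class=3)
print(balance, sum(balance))
```

Dividing by `total` rather than by `len(labels)` is what keeps the sum at 1 when some labels fall outside the expected classes.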

    opened by Gnaiqing 0
  • about COSINE endmodel

    Hi @JieyuZ2 and @yinxiangshi, I am trying to run the COSINE end model, but I am having trouble reproducing the results in the COSINE paper. Although I tried the suggested hyperparameters, I still get only a marginal benefit from Wrench, and I'm not sure what is wrong. Could you share the scripts you used when evaluating COSINE? Thanks.

    opened by Gnaiqing 0
  • Recommended parameters to use for each algorithm and dataset

    I've tried several combinations of algorithms and datasets, but I found it hard to get results similar to those in the paper. I suspect this is due to inappropriate parameter settings, so I think it would be great if this repo could provide some recommended parameters. This matters especially for the newly added algorithms, where it is hard to judge whether they produce the right results.

    opened by mrbeann 0
Releases(v1.1)
  • v1.1(Nov 9, 2021)

    What's new:

    • A bunch of new methods: WeaSEL, ImplyLoss, ASTRA, MeanTeacher, Meta-Weight-Net, Learning-to-Reweight
    • A new EndClassifierModel model which unifies all the classification backbones
    • Two new datasets on image classification
    • Support torch native amp for inference in the validation step
    • Support training on multiple GPUs via torch's DistributedDataParallel and the new parallel_fit function
    • Fixed some bugs
  • v1.0(Sep 7, 2021)

Owner
Jieyu Zhang
CS PhD