Tez is a super-simple and lightweight Trainer for PyTorch. It also comes with many utils that you can use to tackle over 90% of deep learning projects in PyTorch.

Overview

Tez: a simple pytorch trainer

NOTE: Currently, we are not accepting any pull requests! All PRs will be closed. If you want a feature or something doesn't work, please create an issue.

tez (तेज़ / تیز) means sharp, fast & active. This is a simple, to-the-point library to make your pytorch training easy.

This library is currently at a very early stage, so there might be breaking changes.

The idea behind tez is simple:

  • keep things as simple as possible
  • make it as customizable as possible
  • clean code
  • faster prototyping
  • production ready

Currently, tez supports cpu and gpu training. More coming soon!

Using tez is super-easy. We don't want you to be far away from pytorch. So, you do everything on your own and just use tez to make a few things simpler.

Training using Tez:

  • To train a model, define a dataset and model. The dataset class is the same old class you would write when writing pytorch models.

  • Create your model class. Instead of inheriting from nn.Module, import tez and inherit from tez.Model as shown in the following example.

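import tez
import torch
import torch.nn as nn
from sklearn import metrics

class MyModel(tez.Model):
    def __init__(self):
        super().__init__()
        # ... define your layers here (forward below uses
        # self.bert, self.bert_drop and self.out) ...

        # tell tez when to step the scheduler
        self.step_scheduler_after = "batch"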

    def monitor_metrics(self, outputs, targets):
        if targets is None:
            return {}
        outputs = torch.sigmoid(outputs).cpu().detach().numpy() >= 0.5
        targets = targets.cpu().detach().numpy()
        accuracy = metrics.accuracy_score(targets, outputs)
        return {"accuracy": accuracy}

    def fetch_scheduler(self):
        # create your own scheduler and return it
        raise NotImplementedError

    def fetch_optimizer(self):
        # create your own optimizer and return it
        raise NotImplementedError

    def forward(self, ids, mask, token_type_ids, targets=None):
        _, o_2 = self.bert(ids, attention_mask=mask, token_type_ids=token_type_ids)
        b_o = self.bert_drop(o_2)
        output = self.out(b_o)

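        # without targets (e.g. at inference) there is no loss or metric
        if targets is None:
            return output, None, {}

        # calculate loss here
        loss = nn.BCEWithLogitsLoss()(output, targets)

        # calculate the metric dictionary here
        metric_dict = self.monitor_metrics(output, targets)
        return output, loss, metric_dict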

Everything is super-intuitive!

  • Now you can train your model!
# init datasets
train_dataset = SomeTrainDataset()
valid_dataset = SomeValidDataset()

# init model
model = MyModel()


# init callbacks, you can also write your own callback
es = tez.callbacks.EarlyStopping(monitor="valid_loss", model_path="model.bin")

# train model. a familiar api!
model.fit(
    train_dataset,
    valid_dataset=valid_dataset,
    train_bs=32,
    device="cuda",
    epochs=50,
    callbacks=[es],
    fp16=True,
)

# save model (with optimizer and scheduler for future!)
model.save("model.bin")

You can check out the examples in examples/.

Comments
  • ValueError: operands could not be broadcast together with shapes (256,256,4) (3,) (256,256,4)

    I am trying to use this package, and it is throwing the error below. I am using the same pipeline as for the cassava leaf detection problem, but on a different dataset where the image size is (256, 256).

    Could you please help here?

    Downloading: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth" to /root/.cache/torch/hub/checkpoints/efficientnet-b4-6ed6700e.pth 100% 74.4M/74.4M [00:00<00:00, 107MB/s]

    Loaded pretrained weights for efficientnet-b4 0%| | 0/51 [00:00<?, ?it/s]

    ValueError Traceback (most recent call last) in ()
         11     epochs=10,
         12     callbacks=[es],
    ---> 13     fp16=True,
         14 )
         15 model.save("model.bin")

    6 frames
    /usr/local/lib/python3.6/dist-packages/tez/model/model.py in fit(self, train_dataset, valid_dataset, train_sampler, valid_sampler, device, epochs, train_bs, valid_bs, n_jobs, callbacks, fp16)
        295 self.train_state = enums.TrainingState.EPOCH_START
        296 self.train_state = enums.TrainingState.TRAIN_EPOCH_START
    --> 297 train_loss = self.train_one_epoch(self.train_loader, device)
        298 self.train_state = enums.TrainingState.TRAIN_EPOCH_END
        299 if self.valid_loader:

    /usr/local/lib/python3.6/dist-packages/tez/model/model.py in train_one_epoch(self, data_loader, device)
        176 losses = AverageMeter()
        177 tk0 = tqdm(data_loader, total=len(data_loader))
    --> 178 for b_idx, data in enumerate(tk0):
        179 self.train_state = enums.TrainingState.TRAIN_STEP_START
        180 loss, metrics = self.train_one_step(data, device)

    /usr/local/lib/python3.6/dist-packages/tqdm/std.py in __iter__(self)
       1102 fp_write=getattr(self.fp, 'write', sys.stderr.write))
       1103
    -> 1104 for obj in iterable:
       1105 yield obj
       1106 # Update and possibly print the progressbar.

    /usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in __next__(self)
        433 if self._sampler_iter is None:
        434 self._reset()
    --> 435 data = self._next_data()
        436 self._num_yielded += 1
        437 if self._dataset_kind == _DatasetKind.Iterable and \

    /usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in _next_data(self)
       1083 else:
       1084 del self._task_info[idx]
    -> 1085 return self._process_data(data)
       1086
       1087 def _try_put_index(self):

    /usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in _process_data(self, data)
       1109 self._try_put_index()
       1110 if isinstance(data, ExceptionWrapper):
    -> 1111 data.reraise()
       1112 return data
       1113

    /usr/local/lib/python3.6/dist-packages/torch/_utils.py in reraise(self)
        426 # have message field
        427 raise self.exc_type(message=msg)
    --> 428 raise self.exc_type(msg)
        429
        430

    ValueError: Caught ValueError in DataLoader worker process 0.
    Original Traceback (most recent call last):
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/worker.py", line 198, in _worker_loop
        data = fetcher.fetch(index)
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
        data = [self.dataset[idx] for idx in possibly_batched_index]
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
        data = [self.dataset[idx] for idx in possibly_batched_index]
      File "/usr/local/lib/python3.6/dist-packages/tez/datasets/image_classification.py", line 48, in __getitem__
        augmented = self.augmentations(image=image)
      File "/usr/local/lib/python3.6/dist-packages/albumentations/core/composition.py", line 171, in __call__
        data = t(**data)
      File "/usr/local/lib/python3.6/dist-packages/albumentations/core/transforms_interface.py", line 38, in __call__
        res[key] = target_function(arg, **dict(params, **target_dependencies))
      File "/usr/local/lib/python3.6/dist-packages/albumentations/augmentations/transforms.py", line 808, in apply
        return F.normalize(image, self.mean, self.std, self.max_pixel_value)
      File "/usr/local/lib/python3.6/dist-packages/albumentations/augmentations/functional.py", line 93, in normalize
        img -= mean
    ValueError: operands could not be broadcast together with shapes (256,256,4) (3,) (256,256,4)

    opened by nvnvashisth 10
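
The shapes in the error suggest that the images carry four channels (for example RGBA PNGs) while the normalization mean has three entries. The sketch below reproduces the broadcast failure with plain NumPy and shows one possible fix, dropping the alpha channel before augmentation. `drop_alpha` is a hypothetical helper, not part of tez or albumentations, and the mean values are only an assumed 3-channel example.

```python
import numpy as np


def drop_alpha(image: np.ndarray) -> np.ndarray:
    """Return only the first three channels of an H x W x 4 image."""
    if image.ndim == 3 and image.shape[2] == 4:
        return image[:, :, :3]
    return image


mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)  # assumed 3-channel mean

# A 4-channel image cannot be normalized in place with a 3-element mean.
rgba = np.zeros((256, 256, 4), dtype=np.float32)
try:
    rgba -= mean
    broadcast_failed = False
except ValueError:
    broadcast_failed = True

# After dropping the alpha channel the subtraction broadcasts as intended.
rgb = drop_alpha(np.zeros((256, 256, 4), dtype=np.float32)).copy()
rgb -= mean
```

When images are loaded with Pillow, `Image.open(path).convert("RGB")` is another common way to end up with three channels.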
  • zero_grad for accumulation_steps = 1 not working as expected

    As far as I know, in the normal execution flow we call zero_grad for each batch and then do the forward pass. Looking at the code, this is not what happens when accumulation_steps = 1 and batch = 1: the forward pass executes first, without a preceding zero_grad.

    I tried to reproduce it, and it behaves as I explained above.

    Also, I think we can fix this by removing the condition in the tez.py file on lines 330 and 331.

    opened by abdurrehman11 9
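
For reference, the usual ordering in a gradient-accumulation loop is sketched below in plain Python, with a hand-written gradient standing in for `loss.backward()`. It does not reproduce tez's code; the toy model, the data and all names are assumptions made for illustration.

```python
def grad_of_squared_error(w: float, x: float, y: float) -> float:
    """d/dw of (w * x - y) ** 2, a stand-in for loss.backward()."""
    return 2.0 * (w * x - y) * x


def train(batches, w=0.0, lr=0.1, accumulation_steps=1):
    events = []
    grad = 0.0                      # zero_grad before the first forward pass
    events.append("zero_grad")
    for step, (x, y) in enumerate(batches, start=1):
        events.append("forward+backward")
        # scale so the accumulated gradient is the mean over micro-batches
        grad += grad_of_squared_error(w, x, y) / accumulation_steps
        if step % accumulation_steps == 0:
            w -= lr * grad          # optimizer.step()
            events.append("step")
            grad = 0.0              # zero_grad right after the step
            events.append("zero_grad")
    return w, events


w1, events1 = train([(1.0, 2.0)], accumulation_steps=1)
w2, events2 = train([(1.0, 2.0), (1.0, 2.0)], accumulation_steps=2)
```

In this ordering every forward pass starts from cleared gradients, whether accumulation_steps is 1 or larger, and two identical micro-batches with accumulation_steps = 2 give the same update as one batch with accumulation_steps = 1.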
  • Can it work without CUDA

    I am getting an error when I execute the code with a CPU configuration.

    Traceback (most recent call last):
      File "c:\Users\Hemanth\Desktop\Data Analytics analyticvidya\recommender system\recommender.py", line 88, in <module>
        train()
      File "c:\Users\Hemanth\Desktop\Data Analytics analyticvidya\recommender system\recommender.py", line 82, in train
        model.fit(
      File "C:\Users\Hemanth\Desktop\Data Analytics analyticvidya\recommender system\venv\lib\site-packages\tez\model\model.py", line 309, in fit
        self._init_model(
      File "C:\Users\Hemanth\Desktop\Data Analytics analyticvidya\recommender system\venv\lib\site-packages\tez\model\model.py", line 93, in _init_model
        self.to(self.device)
      File "C:\Users\Hemanth\Desktop\Data Analytics analyticvidya\recommender system\venv\lib\site-packages\torch\nn\modules\module.py", line 852, in to
        return self._apply(convert)
      File "C:\Users\Hemanth\Desktop\Data Analytics analyticvidya\recommender system\venv\lib\site-packages\torch\nn\modules\module.py", line 530, in _apply
        module._apply(fn)
      File "C:\Users\Hemanth\Desktop\Data Analytics analyticvidya\recommender system\venv\lib\site-packages\torch\nn\modules\module.py", line 552, in _apply
        param_applied = fn(param)
      File "C:\Users\Hemanth\Desktop\Data Analytics analyticvidya\recommender system\venv\lib\site-packages\torch\nn\modules\module.py", line 850, in convert
        return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)
      File "C:\Users\Hemanth\Desktop\Data Analytics analyticvidya\recommender system\venv\lib\site-packages\torch\cuda\__init__.py", line 166, in _lazy_init
        raise AssertionError("Torch not compiled with CUDA enabled")
    AssertionError: Torch not compiled with CUDA enabled

    opened by hemanthh17 7
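
The traceback ends in `self.to(self.device)` initializing CUDA, which suggests the device resolved to "cuda" on a CPU-only PyTorch build. A minimal sketch of choosing the `device` and `fp16` arguments of `model.fit` from CUDA availability follows. `fit_kwargs` is a hypothetical helper, and the availability flag is passed in so that the snippet runs without PyTorch installed.

```python
def fit_kwargs(cuda_available: bool) -> dict:
    """Pick `device` and `fp16` for model.fit based on CUDA availability."""
    return {
        "device": "cuda" if cuda_available else "cpu",
        # mixed precision through torch.cuda.amp only has an effect on CUDA
        "fp16": cuda_available,
    }


cpu_args = fit_kwargs(False)
gpu_args = fit_kwargs(True)
```

In a training script this could be used as `model.fit(train_dataset, valid_dataset=valid_dataset, **fit_kwargs(torch.cuda.is_available()))`.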
  • Documentation improvement - How is tez faster?

    Great to see a nice Pytorch training library.

    I think it would help users to show what kind of performance improvements come out of the box with Tez. For example, comparing how fp16 is enabled in tez vs vanilla pytorch could be informative, or just a quick list of optimisations that are easy to do with Tez, such as fp16.

    opened by swartchris8 5
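
For comparison, mixed precision in a hand-written PyTorch loop typically looks like the sketch below, which uses `torch.cuda.amp.autocast` and `torch.cuda.amp.GradScaler`. The toy model and data are assumptions. In tez the feature is requested with the single `fp16=True` argument of `model.fit`, and the warnings quoted in a later issue suggest it relies on the same two utilities. No speed-up figures are implied here.

```python
import torch
import torch.nn as nn

use_amp = torch.cuda.is_available()           # AMP is disabled without CUDA
device = "cuda" if use_amp else "cpu"

model = nn.Linear(8, 1).to(device)            # toy model for illustration
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

inputs = torch.randn(16, 8, device=device)
targets = torch.randint(0, 2, (16, 1), device=device).float()

for _ in range(3):
    optimizer.zero_grad()
    # run the forward pass in reduced precision where it is safe to do so
    with torch.cuda.amp.autocast(enabled=use_amp):
        loss = nn.BCEWithLogitsLoss()(model(inputs), targets)
    # scale the loss so small fp16 gradients do not underflow
    scaler.scale(loss).backward()
    scaler.step(optimizer)                    # unscales, then optimizer.step()
    scaler.update()

final_loss = loss.item()
```

Recent PyTorch releases may emit a deprecation warning for the `torch.cuda.amp` names and point to `torch.amp` instead; the structure of the loop stays the same.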
  • Is it possible to set variable Lr per epoch

    @abhishekkrthakur I am finding this framework great and easy to use. But as I am fairly new to it, I was wondering if there is a way to pass a variable LR for training, say a different one for every epoch.

    Also, is there a way to continue training from a particular epoch if, say, the local system crashed or was disturbed during the training process?

    opened by gauravbrills 3
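
One common way to vary the learning rate per epoch in PyTorch is a `LambdaLR` scheduler, whose `lr_lambda` returns a multiplier for the base learning rate. The sketch below only defines and evaluates such a multiplier. The schedule values are made up, and wiring it into `fetch_scheduler` together with `self.step_scheduler_after = "epoch"` is an assumption about how tez would consume it.

```python
def lr_factor(epoch: int) -> float:
    """Multiplier applied to the base learning rate at a given epoch."""
    if epoch < 2:            # short warm-up at half the base rate
        return 0.5
    if epoch < 10:           # full rate for the main phase
        return 1.0
    return 0.1               # decay for the remaining epochs


base_lr = 3e-4               # hypothetical base learning rate
schedule = [base_lr * lr_factor(epoch) for epoch in range(12)]

# Assumed usage inside a tez.Model subclass (not verified against tez):
#     def fetch_scheduler(self):
#         return torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lr_factor)
```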
  • Applying metrics after the epoch

    Dears, I am using tez to classify melanoma images (kaggle SIIM binary classification). With wtfml it is possible to get AUC ~ 0.85. With tez, I am only getting AUC ~ 0.6. I saw that this happens in tez when using metrics.roc_auc_score(...) inside the monitor_metrics method. This gives some ValueError exceptions, which must be handled by returning auc = 0.5 (this error occurs when the data have only 1 class).

    In wtfml, the metrics.roc_auc_score(...) method is used only after Engine.evaluate. In this case, the data always have two classes (because the stratified K-fold split guarantees that).

    I am wondering if it is possible, in tez, to apply metrics.roc_auc_score(...) only after the epoch, and not on each batch of size train_bs. That way the data will always have two classes, avoiding the ValueError exceptions.

    PS.

    1. In the class definition __init__ I am using:
           self.step_scheduler_after = "epoch"
           self.step_scheduler_metric = "valid_auc"
    2. In the monitor_metrics method:
           try:
               auc = metrics.roc_auc_score(targets, outputs.ravel())
           except ValueError:
               auc = 0.5
           return {"auc": auc}
    3. My model.fit is defined as:
           model.fit(train_dataset, valid_dataset=valid_dataset, train_bs=32, valid_bs=16, device="cuda", epochs=50, callbacks=[es], fp16=False, n_jobs=2)
    opened by waldcarl 2
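
The idea of scoring once per epoch rather than per batch can be sketched as follows: targets and scores are collected across batches, and the AUC is computed on the pooled data. The pairwise `roc_auc` function is a small pure-Python stand-in for `metrics.roc_auc_score`, and `EpochAUC` is a hypothetical helper. How to hook it into tez's epoch loop is not covered here.

```python
def roc_auc(targets, scores):
    """AUC as the fraction of (positive, negative) pairs ranked correctly."""
    pos = [s for t, s in zip(targets, scores) if t == 1]
    neg = [s for t, s in zip(targets, scores) if t == 0]
    if not pos or not neg:
        raise ValueError("Only one class present in targets.")
    # ties count as half a correctly ranked pair
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


class EpochAUC:
    """Collects per-batch targets and scores, then scores the whole epoch."""

    def __init__(self):
        self.targets, self.scores = [], []

    def update(self, targets, scores):
        self.targets.extend(targets)
        self.scores.extend(scores)

    def compute(self):
        return roc_auc(self.targets, self.scores)


# toy batches in which each batch contains a single class
batches = [([0, 0], [0.1, 0.4]), ([1, 1], [0.35, 0.8])]

meter = EpochAUC()
per_batch_errors = 0
for batch_targets, batch_scores in batches:
    try:
        roc_auc(batch_targets, batch_scores)
    except ValueError:
        per_batch_errors += 1
    meter.update(batch_targets, batch_scores)

epoch_auc = meter.compute()
```

Both toy batches fail when scored individually, while the pooled epoch gives a well-defined value, which also avoids averaging in the 0.5 placeholder.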
  • Issue while using Auc metric on imbalanced dataset like melanoma(ValueError: Only one class present in y_true. ROC AUC score is not defined in that case)

    This problem occurs because the metric is calculated on each batch as training runs.

    I found this explanation on Stack Overflow:

    You cannot have an ROC curve without both positive and negative examples in your dataset. With only one class in the dataset, you cannot measure your false-positive rate, and therefore cannot plot an ROC curve. This is why you get this error message.

    How to handle this problem?

    opened by IamSantoshKumar 2
  • Error in Multiclass TypeError: dropout(): argument 'input' (position 1) must be Tensor, not str

    /usr/local/lib/python3.7/dist-packages/torch/cuda/amp/grad_scaler.py:116: UserWarning: torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling.
      warnings.warn("torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling.")
    0%| | 0/2939 [00:00<?, ?it/s]
    /usr/local/lib/python3.7/dist-packages/torch/cuda/amp/autocast_mode.py:118: UserWarning: torch.cuda.amp.autocast only affects CUDA ops, but CUDA is not available. Disabling.
      warnings.warn("torch.cuda.amp.autocast only affects CUDA ops, but CUDA is not available. Disabling.")

    TypeError Traceback (most recent call last) in ()
        143     epochs=3,
        144     callbacks=[tb_logger, es],
    --> 145     fp16=True,
        146 )
        147 model.save("model.bin")

    8 frames
    /usr/local/lib/python3.7/dist-packages/torch/nn/functional.py in dropout(input, p, training, inplace)
       1074     if p < 0.0 or p > 1.0:
       1075         raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p))
    -> 1076     return _VF.dropout_(input, p, training) if inplace else _VF.dropout(input, p, training)
       1077
       1078

    TypeError: dropout(): argument 'input' (position 1) must be Tensor, not str

    opened by gokulguptanew 1
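
The traceback is truncated, so the cause cannot be confirmed. One plausible explanation is that the model unpacks the BERT output as a tuple, as in `_, o_2 = self.bert(...)`, while newer versions of transformers return a dict-like output object by default. Unpacking such an object yields its keys, which are strings. The stand-in below uses a plain dict with made-up values to show the effect.

```python
# Stand-in for a dict-like model output; the real object comes from transformers.
outputs = {"last_hidden_state": [[0.1, 0.2]], "pooler_output": [0.3, 0.4]}

# Tuple-unpacking a mapping iterates over its keys, not its values.
_, o_2 = outputs
unpacked_type = type(o_2).__name__

# Selecting the field by name gives the actual data.
pooled = outputs["pooler_output"]
```

If this is the cause, passing `return_dict=False` to the model call, or reading `pooler_output` from the returned object, are two common remedies.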
  • Text classification examples - Tokenizer is defined twice

    The tokenizer is defined both in the model and the dataset in the BERT text classification examples.

    multi_class.py, line 50:
        self.tokenizer = transformers.BertTokenizer.from_pretrained(
            "bert-base-uncased", do_lower_case=True
        )

    opened by obesp 1
  • Small error in image_classification.py

    If augmentations is None, we get an error because the variable augmented is referenced before assignment: UnboundLocalError: local variable 'augmented' referenced before assignment

    elif self.backend == "cv2":
        image = cv2.imread(self.image_paths[item])
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self.resize is not None:
            image = cv2.resize(
                image,
                (self.resize[1], self.resize[0]),
                interpolation=cv2.INTER_CUBIC,
            )
        if self.augmentations is not None:
            augmented = self.augmentations(image=image)
            image = augmented["image"]
    

    If the indentation is fixed, this error is solved.

    opened by VpkPrasanna 1
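
The failure mode described here can be reproduced in isolation. In the sketch below, `load_buggy` reads `augmented` outside the `if` block (an assumed reconstruction of the mis-indented code, not a copy of the library source), while `load_fixed` keeps the assignment inside it, as in the snippet quoted above. `identity_aug` is a toy augmentation added for illustration.

```python
def load_buggy(image, augmentations=None):
    if augmentations is not None:
        augmented = augmentations(image=image)
    # runs even when no augmentation was applied, so `augmented` may not exist
    image = augmented["image"]
    return image


def load_fixed(image, augmentations=None):
    if augmentations is not None:
        augmented = augmentations(image=image)
        image = augmented["image"]
    return image


def identity_aug(image):
    """Toy augmentation that mimics the dict returned by albumentations."""
    return {"image": image}


try:
    load_buggy("pixels")
    raised = False
except UnboundLocalError:
    raised = True

unchanged = load_fixed("pixels")
augmented_result = load_fixed("pixels", identity_aug)
```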
  • Small error in model.py

    Hi! Love this library.
    In tez/model/model.py there is probably a mistake in line 90:

    self.train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=train_bs,
        num_workers=n_jobs,
        sampler=valid_sampler,
        shuffle=True,
    )
    

    I guess train_sampler is meant to be used here, not valid_sampler.

    opened by hocop 1
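
A sketch of the corrected argument handling follows. Besides using `train_sampler`, it enables shuffling only when no sampler is given, since `torch.utils.data.DataLoader` rejects a sampler combined with `shuffle=True`. `train_loader_kwargs` is a hypothetical helper, not tez's implementation.

```python
def train_loader_kwargs(train_bs, n_jobs, train_sampler=None):
    """Keyword arguments for the training DataLoader."""
    return {
        "batch_size": train_bs,
        "num_workers": n_jobs,
        "sampler": train_sampler,          # the training sampler, not valid_sampler
        "shuffle": train_sampler is None,  # sampler and shuffle are mutually exclusive
    }


no_sampler = train_loader_kwargs(32, 4)
with_sampler = train_loader_kwargs(32, 4, train_sampler=[2, 0, 1])
```

The result could then be passed on as `torch.utils.data.DataLoader(train_dataset, **train_loader_kwargs(train_bs, n_jobs, train_sampler))`.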
  • run example code error

    When I run the example code:

    accelerate launch   imdb_sentiment_classification.py
    

    after running for some epochs I get this error:

    INFO:tez.callbacks.early_stopping:EarlyStopping counter: 4/5
    [train] accuracy=0.9915, loss=0.0269 [valid] accuracy=0.8953, loss=0.4287 [e=5 steps=2112]                                                                                                 
     30%|████████████████████████████████▍                                                                           | 2112/7040 [05:45<06:40, 12.32it/s, accuracy=0.991, epoch=5, loss=0.0269]2022-09-17 07:55:02,832 INFO EarlyStopping counter: 5/5
    INFO:tez.callbacks.early_stopping:EarlyStopping counter: 5/5
     30%|████████████████████████████████▍                                                                           | 2112/7040 [05:47<13:31,  6.07it/s, accuracy=0.991, epoch=5, loss=0.0269]
    
    
    
    
    [E ProcessGroupNCCL.cpp:719] [Rank 3] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=34532, OpType=BROADCAST, Timeout(ms)=1800000) ran for 1808970 milliseconds before timing out.
    [E ProcessGroupNCCL.cpp:719] [Rank 1] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=34532, OpType=BROADCAST, Timeout(ms)=1800000) ran for 1808984 milliseconds before timing out.
    [E ProcessGroupNCCL.cpp:719] [Rank 2] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=34532, OpType=BROADCAST, Timeout(ms)=1800000) ran for 1809275 milliseconds before timing out.
    [E ProcessGroupNCCL.cpp:406] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data. To avoid this inconsistency, we are taking the entire process down.
    terminate called after throwing an instance of 'std::runtime_error'
      what():  [Rank 2] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=34532, OpType=BROADCAST, Timeout(ms)=1800000) ran for 1809275 milliseconds before timing out.
    [E ProcessGroupNCCL.cpp:406] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data. To avoid this inconsistency, we are taking the entire process down.
    terminate called after throwing an instance of 'std::runtime_error'
      what():  [Rank 3] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=34532, OpType=BROADCAST, Timeout(ms)=1800000) ran for 1808970 milliseconds before timing out.
    [E ProcessGroupNCCL.cpp:406] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data. To avoid this inconsistency, we are taking the entire process down.
    terminate called after throwing an instance of 'std::runtime_error'
      what():  [Rank 1] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=34532, OpType=BROADCAST, Timeout(ms)=1800000) ran for 1808984 milliseconds before timing out.
    Traceback (most recent call last):
      File "/root/miniconda3/envs/lightning/lib/python3.9/multiprocessing/resource_sharer.py", line 138, in _serve
        with self._listener.accept() as conn:
      File "/root/miniconda3/envs/lightning/lib/python3.9/multiprocessing/connection.py", line 470, in accept
        deliver_challenge(c, self._authkey)
      File "/root/miniconda3/envs/lightning/lib/python3.9/multiprocessing/connection.py", line 745, in deliver_challenge
        response = connection.recv_bytes(256)        # reject large message
      File "/root/miniconda3/envs/lightning/lib/python3.9/multiprocessing/connection.py", line 221, in recv_bytes
        buf = self._recv_bytes(maxlength)
      File "/root/miniconda3/envs/lightning/lib/python3.9/multiprocessing/connection.py", line 419, in _recv_bytes
        buf = self._recv(4)
      File "/root/miniconda3/envs/lightning/lib/python3.9/multiprocessing/connection.py", line 384, in _recv
        chunk = read(handle, remaining)
    ConnectionResetError: [Errno 104] Connection reset by peer
    WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 113654 closing signal SIGTERM
    Traceback (most recent call last):
      File "/root/miniconda3/envs/lightning/lib/python3.9/multiprocessing/resource_sharer.py", line 138, in _serve
        with self._listener.accept() as conn:
      File "/root/miniconda3/envs/lightning/lib/python3.9/multiprocessing/connection.py", line 470, in accept
        deliver_challenge(c, self._authkey)
      File "/root/miniconda3/envs/lightning/lib/python3.9/multiprocessing/connection.py", line 745, in deliver_challenge
        response = connection.recv_bytes(256)        # reject large message
      File "/root/miniconda3/envs/lightning/lib/python3.9/multiprocessing/connection.py", line 221, in recv_bytes
        buf = self._recv_bytes(maxlength)
      File "/root/miniconda3/envs/lightning/lib/python3.9/multiprocessing/connection.py", line 419, in _recv_bytes
        buf = self._recv(4)
      File "/root/miniconda3/envs/lightning/lib/python3.9/multiprocessing/connection.py", line 384, in _recv
        chunk = read(handle, remaining)
    ConnectionResetError: [Errno 104] Connection reset by peer
    ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: -6) local_rank: 2 (pid: 113655) of binary: /root/miniconda3/envs/lightning/bin/python
    Traceback (most recent call last):
      File "/root/miniconda3/envs/lightning/bin/torchrun", line 33, in <module>
        sys.exit(load_entry_point('torch==1.11.0', 'console_scripts', 'torchrun')())
      File "/root/miniconda3/envs/lightning/lib/python3.9/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 345, in wrapper
        return f(*args, **kwargs)
      File "/root/miniconda3/envs/lightning/lib/python3.9/site-packages/torch/distributed/run.py", line 724, in main
        run(args)
      File "/root/miniconda3/envs/lightning/lib/python3.9/site-packages/torch/distributed/run.py", line 715, in run
        elastic_launch(
      File "/root/miniconda3/envs/lightning/lib/python3.9/site-packages/torch/distributed/launcher/api.py", line 131, in __call__
        return launch_agent(self._config, self._entrypoint, list(args))
      File "/root/miniconda3/envs/lightning/lib/python3.9/site-packages/torch/distributed/launcher/api.py", line 245, in launch_agent
        raise ChildFailedError(
    torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
    =======================================================
    imdb_sentiment_classification.py FAILED
    -------------------------------------------------------
    Failures:
    [1]:
      time      : 2022-09-17_08:25:22
      host      : dy-a100-779-tlzrv
      rank      : 3 (local_rank: 3)
      exitcode  : -6 (pid: 113656)
      error_file: <N/A>
      traceback : Signal 6 (SIGABRT) received by PID 113656
    -------------------------------------------------------
    Root Cause (first observed failure):
    [0]:
      time      : 2022-09-17_08:25:22
      host      : dy-a100-779-tlzrv
      rank      : 2 (local_rank: 2)
      exitcode  : -6 (pid: 113655)
      error_file: <N/A>
      traceback : Signal 6 (SIGABRT) received by PID 113655
    =======================================================
    Traceback (most recent call last):
      File "/root/miniconda3/envs/lightning/bin/accelerate", line 33, in <module>
        sys.exit(load_entry_point('accelerate==0.12.0.dev0', 'console_scripts', 'accelerate')())
      File "/root/miniconda3/envs/lightning/lib/python3.9/site-packages/accelerate-0.12.0.dev0-py3.9.egg/accelerate/commands/accelerate_cli.py", line 43, in main
        args.func(args)
      File "/root/miniconda3/envs/lightning/lib/python3.9/site-packages/accelerate-0.12.0.dev0-py3.9.egg/accelerate/commands/launch.py", line 734, in launch_command
        multi_gpu_launcher(args)
      File "/root/miniconda3/envs/lightning/lib/python3.9/site-packages/accelerate-0.12.0.dev0-py3.9.egg/accelerate/commands/launch.py", line 374, in multi_gpu_launcher
        raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd)
    subprocess.CalledProcessError: Command '['torchrun', '--nproc_per_node', '4', 'imdb_sentiment_classification.py']' returned non-zero exit status 1.
    
    opened by bestpredicts 0
  • Getting error while importing enums from tez.

    Traceback (most recent call last):
      File "/content/tez/tez/model/model.py", line 12, in <module>
        from tez import enums
      File "/content/tez/tez/model/tez.py", line 11, in <module>
        from tez import enums
    ImportError: cannot import name 'enums' from 'tez' (/content/tez/tez/model/tez.py)

    Waiting for positive reply.

    opened by VikasRathod314 3
  • Saving validation score

    Is it possible to somehow save a list of the validation scores (per epoch or per batch) after training? I have some problems with output on my server: it usually gets deleted. But I really need the validation scores to compare models, so it would be really convenient if I could get them in one file, for example.

    opened by 25icecreamflavors 0
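
Independently of tez, validation scores can be appended to a CSV file with the standard library, one row per epoch. The sketch below assumes the scores are available as a dict after each epoch. `append_scores`, the file name and the score values are hypothetical, and the tez callback hook from which to call it is not covered here.

```python
import csv
import os
import tempfile


def append_scores(path, epoch, scores):
    """Append one row of metrics, writing the header on first use."""
    fieldnames = ["epoch"] + sorted(scores)
    write_header = not os.path.exists(path)
    with open(path, "a", newline="") as handle:
        writer = csv.DictWriter(handle, fieldnames=fieldnames)
        if write_header:
            writer.writeheader()
        writer.writerow({"epoch": epoch, **scores})


# made-up values, written to a temporary directory for the demonstration
log_path = os.path.join(tempfile.mkdtemp(), "valid_scores.csv")
append_scores(log_path, 1, {"valid_loss": 0.43, "accuracy": 0.89})
append_scores(log_path, 2, {"valid_loss": 0.39, "accuracy": 0.91})

with open(log_path, newline="") as handle:
    rows = list(csv.DictReader(handle))
```

Because the file is opened in append mode, rows written before a crash or a cleared console remain on disk.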
  • Saving after training an epoch

    How do I save the model after each training epoch? I use the fit method for 5 epochs and do not really understand how to save after each one, not only after the last one.

    opened by 25icecreamflavors 2
  • How can we access the input_ids/attention mask in each train batch loop?

    I tried using a train step callback but I am not sure how to get access to the dataloader input_ids and attention mask during each train step. Is this possible?

    BTW Thanks for the library!

    opened by tkmaker 0
Releases(v0.1.8)
Owner
abhishek thakur
Kaggle: www.kaggle.com/abhishek