Overview

Enformer - Pytorch (wip)

Implementation of Enformer, DeepMind's attention network for predicting gene expression, in PyTorch. The original TensorFlow Sonnet code can be found here.
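
A minimal usage sketch, assuming the pre-trained weights hosted on the HuggingFace hub (discussed in the issues below); the head names and output shapes follow the EnformerConfig shown further down:

    import torch
    from enformer_pytorch import Enformer

    # load pre-trained weights from the HuggingFace hub
    model = Enformer.from_pretrained("EleutherAI/enformer-official-rough")

    seq = torch.randint(0, 5, (1, 196_608))  # integer-encoded DNA, one token per base
    output = model(seq)                      # dict with one tensor per output head

    output['human']  # (1, 896, 5313)
    output['mouse']  # (1, 896, 1643)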

Citations

@article {Avsec2021.04.07.438649,
    author  = {Avsec, {\v Z}iga and Agarwal, Vikram and Visentin, Daniel and Ledsam, Joseph R. and Grabska-Barwinska, Agnieszka and Taylor, Kyle R. and Assael, Yannis and Jumper, John and Kohli, Pushmeet and Kelley, David R.},
    title   = {Effective gene expression prediction from sequence by integrating long-range interactions},
    elocation-id = {2021.04.07.438649},
    year    = {2021},
    doi     = {10.1101/2021.04.07.438649},
    publisher = {Cold Spring Harbor Laboratory},
    URL     = {https://www.biorxiv.org/content/early/2021/04/08/2021.04.07.438649},
    eprint  = {https://www.biorxiv.org/content/early/2021/04/08/2021.04.07.438649.full.pdf},
    journal = {bioRxiv}
}
Comments
  • Using EleutherAI/enformer-official-rough PyTorch implementation to just get human output head

    Hi @lucidrains,

    Thank you so much for your efforts in releasing the PyTorch version of the Enformer model! I am really excited to use it for my particular implementation.

    I was wondering if it is possible to use the pre-trained HuggingFace model to get just the human output head. The reason is that inference takes a few minutes, and since I only need human data, this would make my implementation a bit smoother. Is there a way to do this elegantly with the current codebase, or would I need to rewrite some functions to allow for it? From what I have seen so far, this modularity doesn't seem possible yet.

    The way I have set up my inference currently is as follows:

    import torch
    import pandas as pd
    from enformer_pytorch import Enformer

    class EnformerInference:
        def __init__(self, data_path: str, model_path: str = "EleutherAI/enformer-official-rough"):
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            self.device = device
            # move the pre-trained model to the target device and switch to eval mode
            self.model = Enformer.from_pretrained(model_path).to(device).eval()
            # EnformerDataLoader (defined elsewhere) returns a one-hot encoded
            # torch.Tensor representation of the sequence of interest
            self.data = EnformerDataLoader(pd.read_csv(data_path, sep="\t"))

        @torch.no_grad()
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.model(x.to(self.device))

    Any guidance on this would be greatly appreciated, thank you!
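
    One partial workaround, sketched below on the assumption that the model's forward returns a dict with one tensor per output head, is to index that dict; note this alone won't speed up inference, since both heads share the same trunk:

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # keep only the human head from the dict of per-head outputs
        return self.model(x.to(self.device))['human']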

    opened by aaronwtr 4
  • Host weights on HuggingFace hub

    Hi Phil Wang,

    I've created a little demo of how you can easily load pre-trained weights from the HuggingFace hub into your Enformer model. I basically followed this guide, which Sylvain (@sgugger) wrote recently. It's a new feature that lets you push model weights to the hub and load them into any custom PyTorch/TF/Flax model.

    From this PR, you can do (after pip install enformer-pytorch):

    from enformer_pytorch import Enformer
    
    model = Enformer.from_pretrained("nielsr/enformer-preview")
    

    If you consent, then I'll transfer all weights to the eleutherai organization on the hub, such that you can do from_pretrained("eleutherai/enformer-preview").

    The weights are hosted here: https://huggingface.co/nielsr/enformer-preview. As you can see in the "files and versions" tab, it contains a pytorch_model.bin file, which has a size of about 1GB. You can also load the other variant, as follows:

    model = Enformer.from_pretrained("nielsr/enformer-corr_coef_obj")
    

    To make it work, the only thing that is required is encapsulating all hyperparameters regarding the model architecture into a separate EnformerConfig object (which I've defined in config_enformer.py). It can be instantiated as follows:

    from enformer_pytorch import EnformerConfig
    
    config = EnformerConfig(
        dim = 1536,
        depth = 11,
        heads = 8,
        output_heads = dict(human = 5313, mouse = 1643),
        target_length = 896,
    )
    

    To initialize an Enformer model with randomly initialized weights, you can do:

    from enformer_pytorch import Enformer
    
    model = Enformer(config)
    

    There's no need for the config.yml and model_loader.py files anymore, as these are now handled by HuggingFace :)
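
    Saving and re-loading then follows the standard HuggingFace pattern (a quick sketch; the local directory name is illustrative):

    model.save_pretrained("my-enformer")             # writes config.json + pytorch_model.bin
    model = Enformer.from_pretrained("my-enformer")  # reload from the local directory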

    Let me know what you think about it :)

    Kind regards,

    Niels

    To do:

    • [x] upload remaining checkpoints to the hub
    • [x] transfer checkpoints to the eleutherai organization
    • [x] remove config.yml and model_loading.py scripts
    • [x] update README
    opened by NielsRogge 4
  • Minor potential typo in `FastaInterval` class

    Hello! First off, thanks so much for this incredible repository; it's greatly accelerating a project I am working on!

    I've been using the GenomeIntervalDataset class and noticed a minor potential typo in the FastaInterval class when I was trying to fetch a sequence with a negative start position and got an empty tensor back. There is logic for clipping the start position at 0 and padding the sequence here: https://github.com/lucidrains/enformer-pytorch/blob/ab29196d535802c8a04929534c5860fb55d06056/enformer_pytorch/data.py#L137-L143, but it wasn't being used in my case because it sits inside an if clause that I wasn't triggering: https://github.com/lucidrains/enformer-pytorch/blob/ab29196d535802c8a04929534c5860fb55d06056/enformer_pytorch/data.py#LL128C9-L128C82. If I unindent that logic, everything works fine for me.

    If it was unintentional to have the clipping inside that if clause I'd be happy to submit a trivial PR to fix the indentation.
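
    For illustration, here is a sketch of the unindented clipping/padding logic (variable names are paraphrased from the linked lines, not the exact source):

    # runs unconditionally, not only inside the `if` clause above it
    left_padding = right_padding = 0

    if start < 0:
        left_padding = -start   # remember how far we ran past the chromosome start
        start = 0               # clip the start position at 0

    if end > chromosome_length:
        right_padding = end - chromosome_length
        end = chromosome_length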

    Thanks again for all your work

    opened by sofroniewn 2
  • example data files

    Hi, in the README you mention using the sequences.bed and hg38.ml.fa files to build the GenomeIntervalDataset, but I can't find these example data files. Could you provide links to them? Thanks!

    opened by yingyuan830 2
  • Why do we need Residual here while we have residual connection inside conv block

    We wrap the conv block inside Residual here: https://github.com/lucidrains/enformer-pytorch/blob/1cbbe860bbd3ce8c26cee3de149d4fcdba508d95/enformer_pytorch/modeling_enformer.py#L318

    while there is already a residual connection inside the conv block here: https://github.com/lucidrains/enformer-pytorch/blob/1cbbe860bbd3ce8c26cee3de149d4fcdba508d95/enformer_pytorch/modeling_enformer.py#L226

    opened by inspirit 2
  • Add base_model_prefix

    This PR fixes the from_pretrained method by adding base_model_prefix, which makes sure weights are properly loaded from the hub.

    Kudos to @sgugger for finding the bug.

    opened by NielsRogge 2
  • How to load the pre-trained Enformer model?

    Hi, I encountered a problem when trying to load the pre-trained enformer model.

    from enformer_pytorch import Enformer
    model = Enformer.from_pretrained("EleutherAI/enformer-preview")

    AttributeError                            Traceback (most recent call last)
    Input In [3], in <module>
          1 from enformer_pytorch import Enformer
    ----> 2 model = Enformer.from_pretrained("EleutherAI/enformer-preview")

    AttributeError: type object 'Enformer' has no attribute 'from_pretrained'

    opened by yzJiang9 2
  • enformer TF pretrained weights

    Hello!

    Thanks for this wonderful resource! I was wondering whether you could point me to how to obtain the model weights for the original TF version of Enformer, or to the actual weights if they are stored somewhere easily accessible. I see the model on TF Hub but am not sure exactly how to extract the weights; I seem to be running into issues, potentially because the original code is Sonnet-based and the model is always loaded as a custom user object.
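
    One way to at least inspect the checkpointed weights, as a sketch (this reads the SavedModel's checkpoint directly and avoids Sonnet entirely; the path is illustrative, pointing at wherever the SavedModel was extracted):

    import tensorflow as tf

    # a SavedModel stores its weights under <dir>/variables/variables.*
    ckpt = "enformer_savedmodel/variables/variables"

    for name, shape in tf.train.list_variables(ckpt):
        print(name, shape)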

    Much appreciated!

    opened by naumanjaved 1
  • AttentionPool bug?

    Looking at the AttentionPool class, did you mean to have

    self.pool_fn = Rearrange('b d (n p) -> b d n p', p = self.pool_size)
    

    instead of

    self.pool_fn = Rearrange('b d (n p) -> b d n p', p = 2)
    

    Here's the full class

    class AttentionPool(nn.Module):
        def __init__(self, dim, pool_size = 2):
            super().__init__()
            self.pool_size = pool_size
            self.pool_fn = Rearrange('b d (n p) -> b d n p', p = 2)
            self.to_attn_logits = nn.Conv2d(dim, dim, 1, bias = False)
    
        def forward(self, x):
            b, _, n = x.shape
            remainder = n % self.pool_size
            needs_padding = remainder > 0
    
            if needs_padding:
                x = F.pad(x, (0, remainder), value = 0)
                mask = torch.zeros((b, 1, n), dtype = torch.bool, device = x.device)
                mask = F.pad(mask, (0, remainder), value = True)
    
            x = self.pool_fn(x)
            logits = self.to_attn_logits(x)
    
            if needs_padding:
                mask_value = -torch.finfo(logits.dtype).max
                logits = logits.masked_fill(self.pool_fn(mask), mask_value)
    
            attn = logits.softmax(dim = -1)
    
            return (x * attn).sum(dim = -1)
    
    opened by cmlakhan 1
  • Colab notebook for computing the correlation across different basenji2 dataset splits.

    New features:

    1. Colab notebook for computing correlations across the different basenji2 dataset splits.
    2. PyTorch metric for computing the mean of per-channel correlations, properly aggregated across a region set (see the sketch below).
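
    For reference, a simplified sketch of such a per-channel Pearson correlation (illustrative, not the PR's exact implementation):

    import torch

    def per_channel_pearson(preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        # preds, targets: (batch, positions, channels)
        preds = preds.reshape(-1, preds.shape[-1])
        targets = targets.reshape(-1, targets.shape[-1])
        preds = preds - preds.mean(dim=0)
        targets = targets - targets.mean(dim=0)
        num = (preds * targets).sum(dim=0)
        den = preds.norm(dim=0) * targets.norm(dim=0)
        return num / den.clamp(min=1e-8)  # one Pearson r per channel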
    opened by jstjohn 0
  • Computing Contribution Scores

    From the paper:

    To better understand what sequence elements Enformer is utilizing when making predictions, we computed two different gene expression contribution scores — input gradients (gradient × input) and attention weights.

    I was just wondering how to compute input gradients and fetch the attention matrix for a given input. I'm not well versed in PyTorch, so I'm sorry if this is a noob question.
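
    For the gradient × input part, a minimal PyTorch sketch (assuming the pre-trained HuggingFace checkpoint from the issues above; the selected track and the random input are purely illustrative). The attention matrices are not exposed through the public forward call, so fetching them would require forward hooks or a small modification to the attention module:

    import torch
    from enformer_pytorch import Enformer

    model = Enformer.from_pretrained("EleutherAI/enformer-official-rough").eval()

    # one-hot encoded input of shape (batch, length, 4); random input for illustration
    seq = torch.zeros(1, 196_608, 4)
    seq[0, torch.arange(196_608), torch.randint(0, 4, (196_608,))] = 1.
    seq.requires_grad_(True)

    out = model(seq)['human']       # (1, 896, 5313)
    out[0, :, 0].sum().backward()   # pick a track of interest (track 0 here)

    grad_x_input = (seq.grad * seq).sum(dim=-1)  # contribution score per position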

    opened by Prakash2403 0
  • Models in training splits

    Hey,

    Is there a way to get the models trained on each of the training splits, as mentioned in the "Model training and evaluation" paragraph of the Enformer paper?

    Thanks!

    opened by luciabarb 0
  • metric for enformer

    Hello, can I ask how you found that the human Pearson R is 0.625 for validation and 0.65 for test? I couldn't find this information in the paper. Is it recorded anywhere else?

    opened by Rachel66666 0
  • error loading enformer package

    I am trying to install the enformer package but seem to be getting the following error:

    >>> import torch
    >>> from enformer_pytorch import Enformer
    Traceback (most recent call last):
      File "/sc/arion/projects/ad-omics/clakhani/conda/envs/enformer_lightning/lib/python3.9/site-packages/transformers/utils/import_utils.py", line 905, in _get_module
        return importlib.import_module("." + module_name, self.__name__)
      File "/sc/arion/projects/ad-omics/clakhani/conda/envs/enformer_lightning/lib/python3.9/importlib/__init__.py", line 127, in import_module
        return _bootstrap._gcd_import(name[level:], package, level)
      File "<frozen importlib._bootstrap>", line 1030, in _gcd_import
      File "<frozen importlib._bootstrap>", line 1007, in _find_and_load
      File "<frozen importlib._bootstrap>", line 986, in _find_and_load_unlocked
      File "<frozen importlib._bootstrap>", line 680, in _load_unlocked
      File "<frozen importlib._bootstrap_external>", line 850, in exec_module
      File "<frozen importlib._bootstrap>", line 228, in _call_with_frames_removed
      File "/sc/arion/projects/ad-omics/clakhani/conda/envs/enformer_lightning/lib/python3.9/site-packages/transformers/modeling_utils.py", line 76, in <module>
        from accelerate import dispatch_model, infer_auto_device_map, init_empty_weights
    ImportError: cannot import name 'dispatch_model' from 'accelerate' (/sc/arion/projects/ad-omics/clakhani/conda/envs/enformer_lightning/lib/python3.9/site-packages/accelerate/__init__.py)
    
    The above exception was the direct cause of the following exception:
    
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
      File "/sc/arion/projects/ad-omics/clakhani/conda/envs/enformer_lightning/lib/python3.9/site-packages/enformer_pytorch/__init__.py", line 2, in <module>
        from enformer_pytorch.modeling_enformer import Enformer, SEQUENCE_LENGTH, AttentionPool
      File "/sc/arion/projects/ad-omics/clakhani/conda/envs/enformer_lightning/lib/python3.9/site-packages/enformer_pytorch/modeling_enformer.py", line 14, in <module>
        from transformers import PreTrainedModel
      File "<frozen importlib._bootstrap>", line 1055, in _handle_fromlist
      File "/sc/arion/projects/ad-omics/clakhani/conda/envs/enformer_lightning/lib/python3.9/site-packages/transformers/utils/import_utils.py", line 895, in __getattr__
        module = self._get_module(self._class_to_module[name])
      File "/sc/arion/projects/ad-omics/clakhani/conda/envs/enformer_lightning/lib/python3.9/site-packages/transformers/utils/import_utils.py", line 907, in _get_module
        raise RuntimeError(
    RuntimeError: Failed to import transformers.modeling_utils because of the following error (look up to see its traceback):
    cannot import name 'dispatch_model' from 'accelerate' (/sc/arion/projects/ad-omics/clakhani/conda/envs/enformer_lightning/lib/python3.9/site-packages/accelerate/__init__.py)
    

    I simply cloned an existing PyTorch environment in Conda (using CUDA 11.1 and torch 1.10) and then pip installed the Hugging Face packages and the enformer package:

    pip install transformers
    pip install datasets
    pip install accelerate
    pip install tokenizers
    pip install enformer-pytorch
    

    Any idea why I'm getting this error?
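
    (A hedged guess at the cause: the installed accelerate version predates dispatch_model, which the installed transformers version imports at startup. Upgrading both packages usually resolves this kind of mismatch:)

    pip install -U accelerate transformers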

    opened by cmlakhan 1
Releases: 0.5.6
Owner
Phil Wang
Working with Attention. It's all we need