George0828Zhang/torch_cif

torch-cif

A fast, parallel, pure PyTorch implementation of "CIF: Continuous Integrate-and-Fire for End-to-End Speech Recognition" https://arxiv.org/abs/1905.11235.

Installation

PyPI

pip install torch-cif

Locally

git clone https://github.com/George0828Zhang/torch_cif
cd torch_cif
python setup.py install

Usage

def cif_function(
    inputs: Tensor,
    alpha: Tensor,
    beta: float = 1.0,
    tail_thres: float = 0.5,
    padding_mask: Optional[Tensor] = None,
    target_lengths: Optional[Tensor] = None,
    eps: float = 1e-4,
    unbound_alpha: bool = False
) -> Dict[str, List[Tensor]]:
    r""" A fast parallel implementation of continuous integrate-and-fire (CIF)
    https://arxiv.org/abs/1905.11235

    Shapes:
        N: batch size
        S: source (encoder) sequence length
        C: source feature dimension
        T: target sequence length

    Args:
        inputs (Tensor): (N, S, C) Input features to be integrated.
        alpha (Tensor): (N, S) Weights corresponding to each element in the
            inputs. Expected to be the output of a sigmoid function.
        beta (float): the threshold used to determine firing.
        tail_thres (float): the threshold used to determine firing for tail handling.
        padding_mask (Tensor, optional): (N, S) A binary mask representing
            padded elements in the inputs. 1 is padding, 0 is not.
        target_lengths (Tensor, optional): (N,) Desired length of the targets
            for each sample in the minibatch.
        eps (float, optional): Epsilon to prevent underflow for divisions.
            Default: 1e-4
        unbound_alpha (bool, optional): Whether to check if 0 <= alpha <= 1.

    Returns -> Dict[str, List[Tensor]]: Key/values described below.
        cif_out: (N, T, C) The output integrated from the source.
        cif_lengths: (N,) The output length for each element in batch.
        alpha_sum: (N,) The sum of alpha for each element in batch.
            Can be used to compute the quantity loss.
        delays: (N, T) The expected delay (in terms of source tokens) for
            each target token in the batch.
        tail_weights: (N,) During inference, return the tail.
        scaled_alpha: (N, S) alpha after applying weight scaling.
        cumsum_alpha: (N, S) cumsum of alpha after scaling.
        right_indices: (N, S) right scatter indices, or floor(cumsum(alpha)).
        right_weights: (N, S) right scatter weights.
        left_indices: (N, S) left scatter indices.
        left_weights: (N, S) left scatter weights.
    """
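As a rough illustration of how alpha_sum relates to target_lengths, the sketch below works through a single sample in plain Python. It assumes the quantity loss and the scaling strategy as described in the CIF paper (the absolute difference between the sum of alpha and the target length, and alpha rescaled so that it sums to the target length), with beta = 1. The helpers quantity_loss and scale_alpha are hypothetical names used only here; they are not part of torch-cif, whose internal scaling may differ in detail.

```python
import math

def quantity_loss(alpha, target_length):
    """|sum(alpha) - target_length| for one sample (paper's quantity loss)."""
    return abs(math.fsum(alpha) - target_length)

def scale_alpha(alpha, target_length):
    """Rescale alpha so that it sums to target_length (assumes beta = 1)."""
    factor = target_length / math.fsum(alpha)
    return [a * factor for a in alpha]

alpha = [0.3, 0.9, 0.6, 0.2]   # per-frame weights after sigmoid, S = 4
target_length = 3              # desired number of fired outputs, T = 3

loss = quantity_loss(alpha, target_length)   # gap between sum(alpha) and T
scaled = scale_alpha(alpha, target_length)   # weights that now sum to T
print(loss, math.fsum(scaled))
```

With scaled weights the number of firings matches the target length during training, while the quantity loss pushes the unscaled sum toward the same value for inference.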

Note

  • This implementation uses cumsum and floor to determine the firing positions, and uses scatter to merge the weighted source features. The figure below demonstrates this concept using the scaled weight sequence (0.4, 1.8, 1.2, 1.2, 1.4).

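[Figure: the cumsum, floor, and scatter steps applied to the scaled weight sequence (0.4, 1.8, 1.2, 1.2, 1.4)]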
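The cumsum-and-floor step can be sketched in plain Python using the same weight sequence and beta = 1. Fraction keeps the running sums exact, so floor is not affected by floating-point rounding. Taking the left index to be the floor of the previous running sum is an assumption made for illustration; the variable names mirror the return keys above, but this is not the library's code.

```python
from fractions import Fraction
from itertools import accumulate
from math import floor

# Scaled weight sequence from the note above, parsed exactly
weights = [Fraction(w) for w in ("0.4", "1.8", "1.2", "1.2", "1.4")]

cumsum = list(accumulate(weights))           # running total of the weights
right_indices = [floor(c) for c in cumsum]   # floor(cumsum(alpha))
left_indices = [0] + right_indices[:-1]      # assumed: floor of the previous total
fires = [r - l for l, r in zip(left_indices, right_indices)]

print(right_indices)  # [0, 2, 3, 4, 6]
print(left_indices)   # [0, 0, 2, 3, 4]
print(fires)          # [0, 2, 1, 1, 2]
```

Each entry of fires counts how many thresholds a source element crosses, so the second and last elements each complete two outputs, and the total of 6 equals the sum of the weights.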

  • Running the tests requires pip install hypothesis expecttest.
  • If beta != 1, our implementation differs slightly from Algorithm 1 in the paper [1]:
    • When a boundary is located, the original algorithm adds the last feature to the current integration with weight 1 - accumulation (line 11 in Algorithm 1), which causes a negative weight in the next integration when alpha < 1 - accumulation.
    • We use beta - accumulation instead, so the weight carried into the next integration, alpha - (beta - accumulation), is never negative.
  • Feel free to contact me if there are bugs in the code.
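To make the beta != 1 note above concrete, here is a small plain-Python sketch of the weight split at a single firing boundary. The function split_weight and the chosen numbers are hypothetical and only illustrate the arithmetic; they are not taken from the library.

```python
from fractions import Fraction

def split_weight(accumulation, alpha, fill_level):
    """Split alpha at a firing boundary.

    Returns (weight added to the current integration,
             weight carried into the next integration).
    """
    current = fill_level - accumulation
    carried = alpha - current
    return current, carried

beta = Fraction(1, 2)          # firing threshold, deliberately != 1
accumulation = Fraction(1, 5)  # weight accumulated before this frame
alpha = Fraction(2, 5)         # this frame's weight; 1/5 + 2/5 >= beta, so it fires

paper_rule = split_weight(accumulation, alpha, fill_level=1)    # 1 - accumulation
beta_rule = split_weight(accumulation, alpha, fill_level=beta)  # beta - accumulation

print(paper_rule)  # (Fraction(4, 5), Fraction(-2, 5))
print(beta_rule)   # (Fraction(3, 10), Fraction(1, 10))
```

Here alpha < 1 - accumulation, so filling up to 1 leaves a carried weight of -2/5, while filling up to beta leaves 1/10.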

References

  1. CIF: Continuous Integrate-and-Fire for End-to-End Speech Recognition
  2. Exploring Continuous Integrate-and-Fire for Adaptive Simultaneous Speech Translation