diff --git a/docs/source/pipelines.rst b/docs/source/pipelines.rst
index fd9205a5fa..f2e8803552 100644
--- a/docs/source/pipelines.rst
+++ b/docs/source/pipelines.rst
@@ -27,7 +27,7 @@ RNN-T Streaming/Non-Streaming ASR
 ---------------------------------
 
 Interface
-^^^^^^^^^
+~~~~~~~~~
 
 ``RNNTBundle`` defines ASR pipelines and consists of three steps: feature extraction, inference, and de-tokenization.
 
@@ -47,7 +47,7 @@ Interface
 .. minigallery:: torchaudio.pipelines.RNNTBundle
 
 Pretrained Models
-^^^^^^^^^^^^^^^^^
+~~~~~~~~~~~~~~~~~
 
 .. autosummary::
    :toctree: generated
@@ -61,7 +61,7 @@ wav2vec 2.0 / HuBERT / WavLM - SSL
 ----------------------------------
 
 Interface
-^^^^^^^^^
+~~~~~~~~~
 
 ``Wav2Vec2Bundle`` instantiates models that generate acoustic features that can be used for downstream inference and fine-tuning.
 
@@ -75,7 +75,7 @@ Interface
    Wav2Vec2Bundle
 
 Pretrained Models
-^^^^^^^^^^^^^^^^^
+~~~~~~~~~~~~~~~~~
 
 .. autosummary::
    :toctree: generated
@@ -100,7 +100,7 @@ wav2vec 2.0 / HuBERT - Fine-tuned ASR
 -------------------------------------
 
 Interface
-^^^^^^^^^
+~~~~~~~~~
 
 ``Wav2Vec2ASRBundle`` instantiates models that generate probability distribution over pre-defined labels, that can be used for ASR.
 
@@ -118,7 +118,7 @@ Interface
 .. minigallery:: torchaudio.pipelines.Wav2Vec2ASRBundle
 
 Pretrained Models
-^^^^^^^^^^^^^^^^^
+~~~~~~~~~~~~~~~~~
 
 .. autosummary::
    :toctree: generated
@@ -157,7 +157,7 @@ Tacotron2 Text-To-Speech
 Similarly ``Vocoder`` can be an algorithm without learning parameters, like `Griffin-Lim`, or a neural-network-based model like `Waveglow`.
 
 Interface
-^^^^^^^^^
+~~~~~~~~~
 
 .. autosummary::
    :toctree: generated
@@ -173,7 +173,7 @@ Interface
 .. minigallery:: torchaudio.pipelines.Tacotron2TTSBundle
 
 Pretrained Models
-^^^^^^^^^^^^^^^^^
+~~~~~~~~~~~~~~~~~
 
 .. autosummary::
    :toctree: generated
@@ -189,7 +189,7 @@ Source Separation
 -----------------
 
 Interface
-^^^^^^^^^
+~~~~~~~~~
 
 ``SourceSeparationBundle`` instantiates source separation models which take single channel audio and generates multi-channel audio.
 
@@ -207,7 +207,7 @@ Interface
 .. minigallery:: torchaudio.pipelines.SourceSeparationBundle
 
 Pretrained Models
-^^^^^^^^^^^^^^^^^
+~~~~~~~~~~~~~~~~~
 
 .. autosummary::
    :toctree: generated
diff --git a/torchaudio/pipelines/_wav2vec2/impl.py b/torchaudio/pipelines/_wav2vec2/impl.py
index 87d71258bf..6a8faf1127 100644
--- a/torchaudio/pipelines/_wav2vec2/impl.py
+++ b/torchaudio/pipelines/_wav2vec2/impl.py
@@ -1,41 +1,12 @@
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, Tuple
 
-import torch
-from torch import Tensor
-from torch.nn import functional as F, Module
-from torchaudio._internal import load_state_dict_from_url
-from torchaudio.models import wav2vec2_model, Wav2Vec2Model, wavlm_model
+from torch.nn import Module
 
 from . import utils
 
 
-__all__ = []
-
-
-class _Wav2Vec2Model(Module):
-    """Wrapper class for :py:class:`~torchaudio.models.Wav2Vec2Model`.
-
-    This is used for layer normalization at the input
-    """
-
-    def __init__(self, model: Wav2Vec2Model):
-        super().__init__()
-        self.model = model
-
-    def forward(self, waveforms: Tensor, lengths: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
-        waveforms = F.layer_norm(waveforms, waveforms.shape)
-        return self.model(waveforms, lengths)
-
-    @torch.jit.export
-    def extract_features(
-        self,
-        waveforms: Tensor,
-        lengths: Optional[Tensor] = None,
-        num_layers: Optional[int] = None,
-    ) -> Tuple[List[Tensor], Optional[Tensor]]:
-        waveforms = F.layer_norm(waveforms, waveforms.shape)
-        return self.model.extract_features(waveforms, lengths, num_layers)
+__all__ = []  # type: ignore
 
 
 @dataclass
@@ -84,10 +55,8 @@ def sample_rate(self) -> float:
         return self._sample_rate
 
     def _get_state_dict(self, dl_kwargs):
-        url = f"https://download.pytorch.org/torchaudio/models/{self._path}"
-        dl_kwargs = {} if dl_kwargs is None else dl_kwargs
-        state_dict = load_state_dict_from_url(url, **dl_kwargs)
-        return state_dict
+        # Note: This method is overridden in the ASR bundle.
+        return utils._get_state_dict(self._path, dl_kwargs)
 
     def get_model(self, *, dl_kwargs=None) -> Module:
         """Construct the model and load the pretrained weight.
@@ -119,13 +88,11 @@
         - HUBERT_ASR_XLARGE
         - WAVLM_LARGE
         """
-        if self._model_type == "WavLM":
-            model = wavlm_model(**self._params)
-        else:
-            model = wav2vec2_model(**self._params)
-        model.load_state_dict(self._get_state_dict(dl_kwargs))
+        model = utils._get_model(self._model_type, self._params)
+        state_dict = self._get_state_dict(dl_kwargs)
+        model.load_state_dict(state_dict)
         if self._normalize_waveform:
-            model = _Wav2Vec2Model(model)
+            model = utils._apply_input_layer_norm(model)
         model.eval()
         return model
 
@@ -171,14 +138,14 @@ class Wav2Vec2ASRBundle(Wav2Vec2Bundle):
         >>> transcripts = ctc_decode(emissions, labels)
     """  # noqa: E501
 
-    _labels: Tuple[str]
-    _remove_aux_axis: Tuple[int] = (1, 2, 3)
+    _labels: Tuple[str, ...]
+    _remove_aux_axis: Tuple[int, ...] = (1, 2, 3)
 
     def get_labels(
         self,
         *,
         blank: str = "-",
-    ) -> Tuple[str]:
+    ) -> Tuple[str, ...]:
         """The output class labels (only applicable to fine-tuned bundles)
 
         The first is blank token, and it is customizable.
@@ -187,7 +154,7 @@
             blank (str, optional): Blank token. (default: ``'-'``)
 
         Returns:
-            Tuple[str]:
+            Tuple[str, ...]:
             For models fine-tuned on ASR, returns the tuple of strings representing
             the output class labels.
 
@@ -199,23 +166,7 @@
         return (blank, *self._labels)
 
     def _get_state_dict(self, dl_kwargs):
-        state_dict = super()._get_state_dict(dl_kwargs)
-        if self._remove_aux_axis:
-            # Remove the seemingly unnecessary axis
-            # For ASR task, the pretrained weights originated from fairseq has unrelated dimensions at index 1, 2, 3
-            # It's originated from the Dictionary implementation of fairseq, which was intended for NLP tasks,
-            # but not used during the ASR training.
-            # https://github.com/pytorch/fairseq/blob/c5ff181125c7e6126b49a85e5ebdd5f5b6a07914/fairseq/data/dictionary.py#L21-L37
-            # https://github.com/pytorch/fairseq/blob/c5ff181125c7e6126b49a85e5ebdd5f5b6a07914/fairseq/criterions/ctc.py#L126-L129
-            #
-            # Also, some pretrained weights originated from voxpopuli has an extra dimensions that almost never used and
-            # that resembles mistake.
-            # The label `1` shows up in the training dataset of German (1 out of 16M),
-            # English (1 / 28M), Spanish (1 / 9.4M), Romanian (1 / 4.7M) and Polish (6 / 5.8M)
-            for key in ["aux.weight", "aux.bias"]:
-                t = state_dict[key]
-                state_dict[key] = torch.stack([t[i] for i in range(t.size(0)) if i not in self._remove_aux_axis])
-        return state_dict
+        return utils._get_state_dict(self._path, dl_kwargs, self._remove_aux_axis)
 
 
 WAV2VEC2_BASE = Wav2Vec2Bundle(
diff --git a/torchaudio/pipelines/_wav2vec2/utils.py b/torchaudio/pipelines/_wav2vec2/utils.py
index 60f76b8007..0ab459f34a 100644
--- a/torchaudio/pipelines/_wav2vec2/utils.py
+++ b/torchaudio/pipelines/_wav2vec2/utils.py
@@ -1,3 +1,80 @@
+from typing import List, Optional, Tuple
+
+import torch
+from torch import nn, Tensor
+
+from torchaudio._internal import load_state_dict_from_url
+from torchaudio.models import wav2vec2_model, Wav2Vec2Model, wavlm_model
+
+
+def _get_model(type_, params):
+    factories = {
+        "Wav2Vec2": wav2vec2_model,
+        "WavLM": wavlm_model,
+    }
+    if type_ not in factories:
+        raise ValueError(f"Supported model types are {tuple(factories.keys())}. Found: {type_}")
+    factory = factories[type_]
+    return factory(**params)
+
+
+class _Wav2Vec2Model(nn.Module):
+    """Wrapper class for :py:class:`~torchaudio.models.Wav2Vec2Model`.
+
+    This is used for layer normalization at the input.
+    """
+
+    def __init__(self, model: Wav2Vec2Model):
+        super().__init__()
+        self.model = model
+
+    def forward(self, waveforms: Tensor, lengths: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
+        waveforms = nn.functional.layer_norm(waveforms, waveforms.shape)
+        return self.model(waveforms, lengths)
+
+    @torch.jit.export
+    def extract_features(
+        self,
+        waveforms: Tensor,
+        lengths: Optional[Tensor] = None,
+        num_layers: Optional[int] = None,
+    ) -> Tuple[List[Tensor], Optional[Tensor]]:
+        waveforms = nn.functional.layer_norm(waveforms, waveforms.shape)
+        return self.model.extract_features(waveforms, lengths, num_layers)
+
+
+def _apply_input_layer_norm(module):
+    """Wrap the model so that layer normalization is applied to the input."""
+    return _Wav2Vec2Model(module)
+
+
+def _remove_aux_axes(state_dict, axes):
+    # Remove the seemingly unnecessary axes.
+    # For the ASR task, pretrained weights that originate from fairseq have unrelated dimensions at indices 1, 2, 3.
+    # They come from the Dictionary implementation of fairseq, which was intended for NLP tasks
+    # but is not used during ASR training.
+    # https://github.com/pytorch/fairseq/blob/c5ff181125c7e6126b49a85e5ebdd5f5b6a07914/fairseq/data/dictionary.py#L21-L37
+    # https://github.com/pytorch/fairseq/blob/c5ff181125c7e6126b49a85e5ebdd5f5b6a07914/fairseq/criterions/ctc.py#L126-L129
+    #
+    # Also, some pretrained weights that originate from voxpopuli have an extra dimension that is almost never used
+    # and that resembles a mistake.
+    # The label `1` shows up in the training dataset of German (1 out of 16M),
+    # English (1 / 28M), Spanish (1 / 9.4M), Romanian (1 / 4.7M) and Polish (6 / 5.8M).
+    for key in ["aux.weight", "aux.bias"]:
+        mat = state_dict[key]
+        state_dict[key] = torch.stack([mat[i] for i in range(mat.size(0)) if i not in axes])
+
+
+def _get_state_dict(url, dl_kwargs, remove_axes=None):
+    if not url.startswith("https"):
+        url = f"https://download.pytorch.org/torchaudio/models/{url}"
+    dl_kwargs = {} if dl_kwargs is None else dl_kwargs
+    state_dict = load_state_dict_from_url(url, **dl_kwargs)
+    if remove_axes:
+        _remove_aux_axes(state_dict, remove_axes)
+    return state_dict
+
+
 def _get_en_labels():
     return (
         "|",
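Smoke test: the refactor is a pure code move, so the public bundle API must behave exactly as before. Below is a minimal sketch of the path it must preserve; the bundle choice and the dummy waveform are illustrative only (`WAV2VEC2_ASR_BASE_960H` is one of the ASR bundles defined later in `impl.py`):

    import torch
    import torchaudio

    # get_model() now goes through the helpers introduced in utils.py:
    #   utils._get_model              - dispatch on _model_type to build the
    #                                   Wav2Vec2/WavLM architecture
    #   utils._get_state_dict         - resolve a bare filename against
    #                                   download.pytorch.org and, for ASR bundles,
    #                                   strip the unused aux axes (1, 2, 3)
    #   utils._apply_input_layer_norm - wrap models whose bundle sets
    #                                   _normalize_waveform
    bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
    model = bundle.get_model()

    waveform = torch.randn(1, int(bundle.sample_rate))  # one second of dummy audio
    emissions, _ = model(waveform)

    # The emission dimension matches the CTC labels, blank token included.
    assert emissions.shape[-1] == len(bundle.get_labels())

Keeping `_get_state_dict` as a thin method on the bundles preserves the override point for `Wav2Vec2ASRBundle`, while the URL resolution and the aux-axis cleanup now live in a single place in `utils.py`.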