Commit

Refactor wav2vec2 pipeline misc helper functions
mthrok committed Aug 2, 2023
1 parent 732c94a commit 385ce5b
Showing 3 changed files with 101 additions and 73 deletions.
20 changes: 10 additions & 10 deletions docs/source/pipelines.rst
@@ -27,7 +27,7 @@ RNN-T Streaming/Non-Streaming ASR
---------------------------------

Interface
-^^^^^^^^^
+~~~~~~~~~

``RNNTBundle`` defines ASR pipelines and consists of three steps: feature extraction, inference, and de-tokenization.

@@ -47,7 +47,7 @@ Interface
.. minigallery:: torchaudio.pipelines.RNNTBundle

Pretrained Models
-^^^^^^^^^^^^^^^^^
+~~~~~~~~~~~~~~~~~

.. autosummary::
:toctree: generated
@@ -61,7 +61,7 @@ wav2vec 2.0 / HuBERT / WavLM - SSL
----------------------------------

Interface
-^^^^^^^^^
+~~~~~~~~~

``Wav2Vec2Bundle`` instantiates models that generate acoustic features that can be used for downstream inference and fine-tuning.

@@ -75,7 +75,7 @@ Interface
Wav2Vec2Bundle

Pretrained Models
-^^^^^^^^^^^^^^^^^
+~~~~~~~~~~~~~~~~~

.. autosummary::
:toctree: generated
@@ -100,7 +100,7 @@ wav2vec 2.0 / HuBERT - Fine-tuned ASR
-------------------------------------

Interface
-^^^^^^^^^
+~~~~~~~~~

``Wav2Vec2ASRBundle`` instantiates models that generate a probability distribution over pre-defined labels, which can be used for ASR.

@@ -118,7 +118,7 @@ Interface
.. minigallery:: torchaudio.pipelines.Wav2Vec2ASRBundle

Pretrained Models
-^^^^^^^^^^^^^^^^^
+~~~~~~~~~~~~~~~~~

.. autosummary::
:toctree: generated
@@ -157,7 +157,7 @@ Tacotron2 Text-To-Speech
Similarly, ``Vocoder`` can be an algorithm without learned parameters, like `Griffin-Lim`, or a neural-network-based model like `Waveglow`.

Interface
-^^^^^^^^^
+~~~~~~~~~

.. autosummary::
:toctree: generated
@@ -173,7 +173,7 @@ Interface
.. minigallery:: torchaudio.pipelines.Tacotron2TTSBundle

Pretrained Models
-^^^^^^^^^^^^^^^^^
+~~~~~~~~~~~~~~~~~

.. autosummary::
:toctree: generated
@@ -189,7 +189,7 @@ Source Separation
-----------------

Interface
-^^^^^^^^^
+~~~~~~~~~

``SourceSeparationBundle`` instantiates source separation models which take single-channel audio and generate multi-channel audio.

@@ -207,7 +207,7 @@ Interface
.. minigallery:: torchaudio.pipelines.SourceSeparationBundle

Pretrained Models
-^^^^^^^^^^^^^^^^^
+~~~~~~~~~~~~~~~~~

.. autosummary::
:toctree: generated
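For orientation, the bundle interface documented above is used roughly as follows. This is a minimal sketch, assuming `torchaudio` is installed and using `WAV2VEC2_ASR_BASE_960H`, one of the pretrained bundles listed in these docs; `speech.wav` is a placeholder path.

```python
import torch
import torchaudio

# One of the documented pretrained ASR bundles.
bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H

# Downloads the weights on first use and returns a Module in eval mode.
model = bundle.get_model()

# The bundle also carries the metadata needed to prepare inputs.
waveform, sample_rate = torchaudio.load("speech.wav")  # placeholder path
if sample_rate != bundle.sample_rate:
    waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)

with torch.inference_mode():
    emissions, _ = model(waveform)  # (batch, frame, num_labels)

labels = bundle.get_labels()  # CTC labels; the first entry is the blank token
```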
77 changes: 14 additions & 63 deletions torchaudio/pipelines/_wav2vec2/impl.py
@@ -1,41 +1,12 @@
from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, Tuple

-import torch
-from torch import Tensor
-from torch.nn import functional as F, Module
-from torchaudio._internal import load_state_dict_from_url
-from torchaudio.models import wav2vec2_model, Wav2Vec2Model, wavlm_model
+from torch.nn import Module

from . import utils


-__all__ = []
-
-
-class _Wav2Vec2Model(Module):
-    """Wrapper class for :py:class:`~torchaudio.models.Wav2Vec2Model`.
-
-    This is used for layer normalization at the input
-    """
-
-    def __init__(self, model: Wav2Vec2Model):
-        super().__init__()
-        self.model = model
-
-    def forward(self, waveforms: Tensor, lengths: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
-        waveforms = F.layer_norm(waveforms, waveforms.shape)
-        return self.model(waveforms, lengths)
-
-    @torch.jit.export
-    def extract_features(
-        self,
-        waveforms: Tensor,
-        lengths: Optional[Tensor] = None,
-        num_layers: Optional[int] = None,
-    ) -> Tuple[List[Tensor], Optional[Tensor]]:
-        waveforms = F.layer_norm(waveforms, waveforms.shape)
-        return self.model.extract_features(waveforms, lengths, num_layers)
+__all__ = []  # type: ignore


@dataclass
@@ -84,10 +55,8 @@ def sample_rate(self) -> float:
        return self._sample_rate

    def _get_state_dict(self, dl_kwargs):
-        url = f"https://download.pytorch.org/torchaudio/models/{self._path}"
-        dl_kwargs = {} if dl_kwargs is None else dl_kwargs
-        state_dict = load_state_dict_from_url(url, **dl_kwargs)
-        return state_dict
+        # Note: This method is overridden in ASR bundle
+        return utils._get_state_dict(self._path, dl_kwargs)

    def get_model(self, *, dl_kwargs=None) -> Module:
        """Construct the model and load the pretrained weight.
@@ -119,13 +88,11 @@ def get_model(self, *, dl_kwargs=None) -> Module:
            - HUBERT_ASR_XLARGE
            - WAVLM_LARGE
        """
-        if self._model_type == "WavLM":
-            model = wavlm_model(**self._params)
-        else:
-            model = wav2vec2_model(**self._params)
-        model.load_state_dict(self._get_state_dict(dl_kwargs))
+        model = utils._get_model(self._model_type, self._params)
+        state_dict = self._get_state_dict(dl_kwargs)
+        model.load_state_dict(state_dict)
        if self._normalize_waveform:
-            model = _Wav2Vec2Model(model)
+            model = utils._apply_input_layer_norm(model)
        model.eval()
        return model

@@ -171,14 +138,14 @@ class Wav2Vec2ASRBundle(Wav2Vec2Bundle):
        >>> transcripts = ctc_decode(emissions, labels)
    """  # noqa: E501

-    _labels: Tuple[str]
-    _remove_aux_axis: Tuple[int] = (1, 2, 3)
+    _labels: Tuple[str, ...]
+    _remove_aux_axis: Tuple[int, ...] = (1, 2, 3)

    def get_labels(
        self,
        *,
        blank: str = "-",
-    ) -> Tuple[str]:
+    ) -> Tuple[str, ...]:
        """The output class labels (only applicable to fine-tuned bundles)

        The first is the blank token, and it is customizable.
@@ -187,7 +154,7 @@ def get_labels(
            blank (str, optional): Blank token. (default: ``'-'``)

        Returns:
-            Tuple[str]:
+            Tuple[str, ...]:
                For models fine-tuned on ASR, returns the tuple of strings representing
                the output class labels.
@@ -199,23 +166,7 @@ def get_labels(
        return (blank, *self._labels)

    def _get_state_dict(self, dl_kwargs):
-        state_dict = super()._get_state_dict(dl_kwargs)
-        if self._remove_aux_axis:
-            # Remove the seemingly unnecessary axis
-            # For ASR task, the pretrained weights originated from fairseq has unrelated dimensions at index 1, 2, 3
-            # It's originated from the Dictionary implementation of fairseq, which was intended for NLP tasks,
-            # but not used during the ASR training.
-            # https://github.com/pytorch/fairseq/blob/c5ff181125c7e6126b49a85e5ebdd5f5b6a07914/fairseq/data/dictionary.py#L21-L37
-            # https://github.com/pytorch/fairseq/blob/c5ff181125c7e6126b49a85e5ebdd5f5b6a07914/fairseq/criterions/ctc.py#L126-L129
-            #
-            # Also, some pretrained weights originated from voxpopuli has an extra dimensions that almost never used and
-            # that resembles mistake.
-            # The label `1` shows up in the training dataset of German (1 out of 16M),
-            # English (1 / 28M), Spanish (1 / 9.4M), Romanian (1 / 4.7M) and Polish (6 / 5.8M)
-            for key in ["aux.weight", "aux.bias"]:
-                t = state_dict[key]
-                state_dict[key] = torch.stack([t[i] for i in range(t.size(0)) if i not in self._remove_aux_axis])
-        return state_dict
+        return utils._get_state_dict(self._path, dl_kwargs, self._remove_aux_axis)


WAV2VEC2_BASE = Wav2Vec2Bundle(
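An aside on the `Tuple[str]` → `Tuple[str, ...]` change above: in Python's `typing` module, `Tuple[str]` denotes a tuple of exactly one string, while `Tuple[str, ...]` denotes a homogeneous tuple of arbitrary length, which is what the label tuples actually are. A minimal standalone illustration (not part of the commit):

```python
from typing import Tuple

labels: Tuple[str, ...] = ("-", "|", "E", "T")  # OK: variable-length tuple of str
single: Tuple[str] = ("a",)                     # a 1-tuple, exactly one element

# A type checker such as mypy flags the following,
# because Tuple[str] only admits 1-tuples:
# bad: Tuple[str] = ("a", "b")
```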
77 changes: 77 additions & 0 deletions torchaudio/pipelines/_wav2vec2/utils.py
@@ -1,3 +1,80 @@
+from typing import List, Optional, Tuple
+
+import torch
+from torch import nn, Tensor
+
+from torchaudio._internal import load_state_dict_from_url
+from torchaudio.models import wav2vec2_model, Wav2Vec2Model, wavlm_model
+
+
+def _get_model(type_, params):
+    factories = {
+        "Wav2Vec2": wav2vec2_model,
+        "WavLM": wavlm_model,
+    }
+    if type_ not in factories:
+        raise ValueError(f"Supported model types are {tuple(factories.keys())}. Found: {type_}")
+    factory = factories[type_]
+    return factory(**params)
+
+
+class _Wav2Vec2Model(nn.Module):
+    """Wrapper class for :py:class:`~torchaudio.models.Wav2Vec2Model`.
+
+    This is used for layer normalization at the input
+    """
+
+    def __init__(self, model: Wav2Vec2Model):
+        super().__init__()
+        self.model = model
+
+    def forward(self, waveforms: Tensor, lengths: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
+        waveforms = nn.functional.layer_norm(waveforms, waveforms.shape)
+        return self.model(waveforms, lengths)
+
+    @torch.jit.export
+    def extract_features(
+        self,
+        waveforms: Tensor,
+        lengths: Optional[Tensor] = None,
+        num_layers: Optional[int] = None,
+    ) -> Tuple[List[Tensor], Optional[Tensor]]:
+        waveforms = nn.functional.layer_norm(waveforms, waveforms.shape)
+        return self.model.extract_features(waveforms, lengths, num_layers)
+
+
+def _apply_input_layer_norm(module):
+    """Add an extra layer_norm to the input of the model"""
+    return _Wav2Vec2Model(module)
+
+
+def _remove_aux_axes(state_dict, axes):
+    # Remove the seemingly unnecessary axes.
+    # For the ASR task, the pretrained weights that originate from fairseq have unrelated dimensions
+    # at indices 1, 2 and 3. They come from the Dictionary implementation of fairseq, which was
+    # intended for NLP tasks but is not used during ASR training.
+    # https://github.com/pytorch/fairseq/blob/c5ff181125c7e6126b49a85e5ebdd5f5b6a07914/fairseq/data/dictionary.py#L21-L37
+    # https://github.com/pytorch/fairseq/blob/c5ff181125c7e6126b49a85e5ebdd5f5b6a07914/fairseq/criterions/ctc.py#L126-L129
+    #
+    # Also, some pretrained weights that originate from voxpopuli have an extra dimension that is
+    # almost never used and looks like a mistake.
+    # The label `1` shows up in the training dataset of German (1 out of 16M),
+    # English (1 / 28M), Spanish (1 / 9.4M), Romanian (1 / 4.7M) and Polish (6 / 5.8M).
+    for key in ["aux.weight", "aux.bias"]:
+        mat = state_dict[key]
+        state_dict[key] = torch.stack([mat[i] for i in range(mat.size(0)) if i not in axes])
+
+
+def _get_state_dict(url, dl_kwargs, remove_axes=None):
+    if not url.startswith("https"):
+        url = f"https://download.pytorch.org/torchaudio/models/{url}"
+    dl_kwargs = {} if dl_kwargs is None else dl_kwargs
+    state_dict = load_state_dict_from_url(url, **dl_kwargs)
+    if remove_axes:
+        _remove_aux_axes(state_dict, remove_axes)
+    return state_dict
+

def _get_en_labels():
    return (
        "|",
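For intuition, the row filtering performed by the new `_remove_aux_axes` helper can be reproduced in isolation. The shape below is made up for illustration; in the real checkpoints, `aux.weight` has one row per output label:

```python
import torch

# Hypothetical aux head: 8 output labels, hidden size 4.
aux_weight = torch.arange(32.0).reshape(8, 4)
axes = (1, 2, 3)  # the unused rows inherited from fairseq's Dictionary

# Same filtering as _remove_aux_axes: keep every row whose index is not in `axes`.
kept = torch.stack([aux_weight[i] for i in range(aux_weight.size(0)) if i not in axes])
print(kept.shape)  # torch.Size([5, 4]) -- rows 1, 2 and 3 are gone
```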

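Taken together, after this refactor `Wav2Vec2Bundle.get_model()` delegates to the new helpers: `utils._get_model` selects the factory function, `utils._get_state_dict` downloads (and, for ASR bundles, prunes) the weights, and `utils._apply_input_layer_norm` wraps the model when `_normalize_waveform` is set. The wrapper's effect can be checked standalone; the sketch below reproduces the `layer_norm` call made in `_Wav2Vec2Model.forward`, not the bundle API:

```python
import torch
from torch.nn import functional as F

waveforms = torch.randn(2, 16000)  # (batch, time); arbitrary example shape

# Passing the full shape as normalized_shape normalizes jointly over
# all elements of the tensor, which is what _Wav2Vec2Model.forward does.
normalized = F.layer_norm(waveforms, waveforms.shape)
print(normalized.mean().item(), normalized.std().item())  # ~0.0 and ~1.0
```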