Add MMS_FA bundle
The new bundle Wav2Vec2FABundle and its instance MMS_FA are added.
They are specialized for forced alignment, and the usage is explained in
the multilingual FA tutorial.
mthrok committed Aug 2, 2023
1 parent 385ce5b commit 3b04c09
Showing 5 changed files with 186 additions and 83 deletions.
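For orientation, here is a minimal usage sketch of the new bundle, pieced together from the changes in this commit; the audio file path and the resampling step are illustrative assumptions and are not part of the diff.

# Illustrative sketch only, mirroring the tutorial changes below.
# "speech.wav" is a placeholder path.
import torch
import torchaudio
from torchaudio.pipelines import MMS_FA as bundle

model = bundle.get_model(with_star=False)   # aux head without the <star> row
dictionary = bundle.get_dict(star=None)     # e.g. {'<blank>': 0, 'a': 1, ...}

waveform, sr = torchaudio.load("speech.wav")
waveform = torchaudio.functional.resample(waveform, sr, bundle.sample_rate)

with torch.inference_mode():
    emission, _ = model(waveform)           # input layer norm is applied inside the model
    emission = torch.log_softmax(emission, dim=-1)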
9 changes: 9 additions & 0 deletions docs/source/refs.bib
@@ -570,3 +570,12 @@ @incollection{45611
URL = {https://arxiv.org/abs/1609.09430},
booktitle = {International Conference on Acoustics, Speech and Signal Processing (ICASSP)}
}

@misc{pratap2023scaling,
title={Scaling Speech Technology to 1,000+ Languages},
author={Vineel Pratap and Andros Tjandra and Bowen Shi and Paden Tomasello and Arun Babu and Sayani Kundu and Ali Elkahky and Zhaoheng Ni and Apoorv Vyas and Maryam Fazel-Zarandi and Alexei Baevski and Yossi Adi and Xiaohui Zhang and Wei-Ning Hsu and Alexis Conneau and Michael Auli},
year={2023},
eprint={2305.13516},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
@@ -207,92 +207,19 @@ def display_segment(i, waveform, word_segments, num_frames, sample_rate=SAMPLE_R
# order to verify the alignment quality. Here we first load the model and dictionary.
#

from torchaudio.models import wav2vec2_model

model = wav2vec2_model(
extractor_mode="layer_norm",
extractor_conv_layer_config=[
(512, 10, 5),
(512, 3, 2),
(512, 3, 2),
(512, 3, 2),
(512, 3, 2),
(512, 2, 2),
(512, 2, 2),
],
extractor_conv_bias=True,
encoder_embed_dim=1024,
encoder_projection_dropout=0.0,
encoder_pos_conv_kernel=128,
encoder_pos_conv_groups=16,
encoder_num_layers=24,
encoder_num_heads=16,
encoder_attention_dropout=0.0,
encoder_ff_interm_features=4096,
encoder_ff_interm_dropout=0.1,
encoder_dropout=0.0,
encoder_layer_norm_first=True,
encoder_layer_drop=0.1,
aux_num_out=31,
)


model.load_state_dict(
torch.hub.load_state_dict_from_url(
"https://dl.fbaipublicfiles.com/mms/torchaudio/ctc_alignment_mling_uroman/model.pt"
)
)
model.eval()
model.to(device)
from torchaudio.pipelines import MMS_FA

bundle = MMS_FA
model = bundle.get_model(with_star=False)
dictionary = bundle.get_dict(star=None)


def get_emission(waveform):
with torch.inference_mode():
# NOTE: this step is essential
waveform = torch.nn.functional.layer_norm(waveform, waveform.shape)
emission, _ = model(waveform)
return torch.log_softmax(emission, dim=-1)


# Construct the dictionary
# '@' represents the OOV token
# <pad> and </s> are fairseq's legacy tokens, which are not used.
# <star> token is omitted as we do not use it in this tutorial
dictionary = {
"<blank>": 0,
"<pad>": 1,
"</s>": 2,
"@": 3,
"a": 4,
"i": 5,
"e": 6,
"n": 7,
"o": 8,
"u": 9,
"t": 10,
"s": 11,
"r": 12,
"m": 13,
"k": 14,
"l": 15,
"d": 16,
"g": 17,
"h": 18,
"y": 19,
"b": 20,
"p": 21,
"w": 22,
"c": 23,
"v": 24,
"j": 25,
"z": 26,
"f": 27,
"'": 28,
"q": 29,
"x": 30,
}


######################################################################
# Before aligning the speech with transcripts, we need to make sure
# the transcripts are already romanized. Here are the BASH commands
2 changes: 2 additions & 0 deletions torchaudio/pipelines/__init__.py
@@ -18,6 +18,7 @@
HUBERT_BASE,
HUBERT_LARGE,
HUBERT_XLARGE,
MMS_FA,
VOXPOPULI_ASR_BASE_10K_DE,
VOXPOPULI_ASR_BASE_10K_EN,
VOXPOPULI_ASR_BASE_10K_ES,
@@ -77,6 +78,7 @@
"HUBERT_XLARGE",
"HUBERT_ASR_LARGE",
"HUBERT_ASR_XLARGE",
"MMS_FA",
"WAVLM_BASE",
"WAVLM_BASE_PLUS",
"WAVLM_LARGE",
131 changes: 127 additions & 4 deletions torchaudio/pipelines/_wav2vec2/impl.py
@@ -1,5 +1,6 @@
import copy
from dataclasses import dataclass
from typing import Any, Dict, Tuple
from typing import Any, Dict, Optional, Tuple

from torch.nn import Module

@@ -146,7 +147,7 @@ def get_labels(
*,
blank: str = "-",
) -> Tuple[str, ...]:
"""The output class labels (only applicable to fine-tuned bundles)
"""The output class labels.
The first is the blank token, and it is customizable.
@@ -159,8 +160,8 @@ def get_labels(
the output class labels.
Example
>>> import torchaudio
>>> torchaudio.models.HUBERT_ASR_LARGE.get_labels()
>>> from torchaudio.pipelines import HUBERT_ASR_LARGE as bundle
>>> bundle.get_labels()
('-', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z')
""" # noqa: E501
return (blank, *self._labels)
@@ -1518,3 +1519,125 @@ def _get_state_dict(self, dl_kwargs):
Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2Bundle` for usage details.
""" # noqa: E501


@dataclass
class Wav2Vec2FABundle(Wav2Vec2ASRBundle):
def get_labels(self, star: Optional[str] = "<star>", blank: str = "<blank>") -> Tuple[str, ...]:
"""Get the labels corresponding to the feature dimension of emission.
The first is the blank token, and it is customizable.
Args:
star (str or None, optional): Change or disable star token. (default: ``"<star>"``)
blank (str, optional): Change the blank token. (default: ``"<blank>"``)
Returns:
Tuple[str, ...]:
The tuple of strings representing the output class labels.
Example
>>> from torchaudio.pipelines import MMS_FA as bundle
>>> bundle.get_labels()
('<blank>', 'a', 'i', 'e', 'n', 'o', 'u', 't', 's', 'r', 'm', 'k', 'l', 'd', 'g', 'h', 'y', 'b', 'p', 'w', 'c', 'v', 'j', 'z', 'f', "'", 'q', 'x', '<star>')
>>> bundle.get_labels(star=None)
('<blank>', 'a', 'i', 'e', 'n', 'o', 'u', 't', 's', 'r', 'm', 'k', 'l', 'd', 'g', 'h', 'y', 'b', 'p', 'w', 'c', 'v', 'j', 'z', 'f', "'", 'q', 'x')
""" # noqa: E501
labels = super().get_labels(blank=blank)
return labels if star is None else (*labels, star)

def _get_params_with_star(self):
params = copy.deepcopy(self._params)
params["aux_num_out"] += 1
return params

def get_model(self, with_star: bool = True, *, dl_kwargs=None) -> Module:
"""Construct the model and load the pretrained weight.
The weight file is downloaded from the internet and cached with
:func:`torch.hub.load_state_dict_from_url`
Args:
with_star (bool, optional): If enabled, the last dimension of the output layer is
extended by one, which corresponds to the `star` token.
dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`.
Returns:
Variation of :py:class:`~torchaudio.models.Wav2Vec2Model`.
"""
params = self._get_params_with_star() if with_star else self._params
model = utils._get_model(self._model_type, params)
state_dict = utils._get_state_dict(self._path, dl_kwargs, self._remove_aux_axis, with_star)
model.load_state_dict(state_dict)
if self._normalize_waveform:
model = utils._apply_input_layer_norm(model)
model.eval()
return model

def get_dict(self, star: Optional[str] = "<star>", blank: str = "<blank>") -> Dict[str, int]:
"""Get the mapping from token to index (in emission feature dim)
Args:
star (str or None, optional): Change or disable star token. (default: ``"<star>"``)
blank (str, optional): Change the blank token. (default: ``"<blank>"``)
Returns:
Dict[str, int]:
The mapping from each token to its index in the emission feature dimension.
Example
>>> from torchaudio.pipelines import MMS_FA as bundle
>>> bundle.get_dict()
{'<blank>': 0, 'a': 1, 'i': 2, 'e': 3, 'n': 4, 'o': 5, 'u': 6, 't': 7, 's': 8, 'r': 9, 'm': 10, 'k': 11, 'l': 12, 'd': 13, 'g': 14, 'h': 15, 'y': 16, 'b': 17, 'p': 18, 'w': 19, 'c': 20, 'v': 21, 'j': 22, 'z': 23, 'f': 24, "'": 25, 'q': 26, 'x': 27, '<star>': 28}
>>> bundle.get_dict(star=None)
{'<blank>': 0, 'a': 1, 'i': 2, 'e': 3, 'n': 4, 'o': 5, 'u': 6, 't': 7, 's': 8, 'r': 9, 'm': 10, 'k': 11, 'l': 12, 'd': 13, 'g': 14, 'h': 15, 'y': 16, 'b': 17, 'p': 18, 'w': 19, 'c': 20, 'v': 21, 'j': 22, 'z': 23, 'f': 24, "'": 25, 'q': 26, 'x': 27}
""" # noqa: E501
return {k: i for i, k in enumerate(self.get_labels(star=star, blank=blank))}


MMS_FA = Wav2Vec2FABundle(
"https://dl.fbaipublicfiles.com/mms/torchaudio/ctc_alignment_mling_uroman/model.pt",
{
"extractor_mode": "layer_norm",
"extractor_conv_layer_config": [
(512, 10, 5),
(512, 3, 2),
(512, 3, 2),
(512, 3, 2),
(512, 3, 2),
(512, 2, 2),
(512, 2, 2),
],
"extractor_conv_bias": True,
"encoder_embed_dim": 1024,
"encoder_projection_dropout": 0.0,
"encoder_pos_conv_kernel": 128,
"encoder_pos_conv_groups": 16,
"encoder_num_layers": 24,
"encoder_num_heads": 16,
"encoder_attention_dropout": 0.0,
"encoder_ff_interm_features": 4096,
"encoder_ff_interm_dropout": 0.1,
"encoder_dropout": 0.0,
"encoder_layer_norm_first": True,
"encoder_layer_drop": 0.1,
"aux_num_out": 28,
},
_labels=utils._get_mms_labels(),
_sample_rate=16000,
_normalize_waveform=True,
_model_type="Wav2Vec2",
)
MMS_FA.__doc__ = """
Trained on 31K hours of data in 1,130 languages from *Scaling Speech Technology to 1,000+ Languages* :cite:`pratap2023scaling`.
Published by the authors of the above paper under [`CC-BY-NC 4.0 License <https://github.com/facebookresearch/fairseq/tree/100cd91db19bb27277a06a25eb4154c805b10189/examples/mms#license>`__].
Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2FABundle` for usage details.
.. note::
Unlike other Wav2Vec2 bundles, this model does not have a token for word boundary (like `|`). This makes the post-processing of alignments slightly different.
""" # noqa: E501
44 changes: 43 additions & 1 deletion torchaudio/pipelines/_wav2vec2/utils.py
@@ -65,13 +65,23 @@ def _remove_aux_axes(state_dict, axes):
state_dict[key] = torch.stack([mat[i] for i in range(mat.size(0)) if i not in axes])


def _get_state_dict(url, dl_kwargs, remove_axes=None):
def _add_star_dim(state_dict):
w, b = state_dict["aux.weight"], state_dict["aux.bias"]
zeros = torch.zeros((1, w.size(1)), device=w.device, dtype=w.dtype)
state_dict["aux.weight"] = torch.cat((zeros, w), dim=0)
ones = torch.ones((1,), device=b.device, dtype=b.dtype)
state_dict["aux.bias"] = torch.cat((b, ones), dim=0)


def _get_state_dict(url, dl_kwargs, remove_axes=None, add_star=False):
if not url.startswith("https"):
url = f"https://download.pytorch.org/torchaudio/models/{url}"
dl_kwargs = {} if dl_kwargs is None else dl_kwargs
state_dict = load_state_dict_from_url(url, **dl_kwargs)
if remove_axes:
_remove_aux_axes(state_dict, remove_axes)
if add_star:
_add_star_dim(state_dict)
return state_dict
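
For context, a small sketch of what the _add_star_dim helper above does to the aux projection: the appended row has all-zero weights and a bias of one, so the <star> logit is the same constant for every frame, independent of the encoder output. The tensor shapes match MMS_FA, but the weights here are random stand-ins.

# Illustration only (not part of the diff).
import torch

w = torch.randn(28, 1024)          # stand-in for aux.weight (28 labels x 1024 features)
b = torch.randn(28)                # stand-in for aux.bias
w_star = torch.cat((w, torch.zeros(1, 1024)), dim=0)
b_star = torch.cat((b, torch.ones(1)), dim=0)

feature = torch.randn(1024)        # any single encoder output frame
logits = feature @ w_star.T + b_star
assert torch.isclose(logits[-1], torch.tensor(1.0))   # the star logit is always 1.0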


@@ -301,3 +311,35 @@ def _get_it_labels():
"í",
"ï",
)


def _get_mms_labels():
return (
"a",
"i",
"e",
"n",
"o",
"u",
"t",
"s",
"r",
"m",
"k",
"l",
"d",
"g",
"h",
"y",
"b",
"p",
"w",
"c",
"v",
"j",
"z",
"f",
"'",
"q",
"x",
)
