From 3b04c090c3aec409bb67ed72ae683b7fa847fceb Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Sun, 30 Jul 2023 18:05:10 -0400 Subject: [PATCH] Add MMS_FA bundle The new bundle Wav2Vec2FABundle and its instance MMS_FA are added. They are specialized for forced alignment, and the usage is explained in the multilingual FA tutorial. --- docs/source/refs.bib | 9 ++ ...lignment_for_multilingual_data_tutorial.py | 83 +---------- torchaudio/pipelines/__init__.py | 2 + torchaudio/pipelines/_wav2vec2/impl.py | 131 +++++++++++++++++- torchaudio/pipelines/_wav2vec2/utils.py | 44 +++++- 5 files changed, 186 insertions(+), 83 deletions(-) diff --git a/docs/source/refs.bib b/docs/source/refs.bib index bac17ee6285..3853bfa919a 100644 --- a/docs/source/refs.bib +++ b/docs/source/refs.bib @@ -570,3 +570,12 @@ @incollection{45611 URL = {https://arxiv.org/abs/1609.09430}, booktitle = {International Conference on Acoustics, Speech and Signal Processing (ICASSP)} } + +@misc{pratap2023scaling, + title={Scaling Speech Technology to 1,000+ Languages}, + author={Vineel Pratap and Andros Tjandra and Bowen Shi and Paden Tomasello and Arun Babu and Sayani Kundu and Ali Elkahky and Zhaoheng Ni and Apoorv Vyas and Maryam Fazel-Zarandi and Alexei Baevski and Yossi Adi and Xiaohui Zhang and Wei-Ning Hsu and Alexis Conneau and Michael Auli}, + year={2023}, + eprint={2305.13516}, + archivePrefix={arXiv}, + primaryClass={cs.CL} +} diff --git a/examples/tutorials/forced_alignment_for_multilingual_data_tutorial.py b/examples/tutorials/forced_alignment_for_multilingual_data_tutorial.py index 6f78b0e5d36..52c33901342 100644 --- a/examples/tutorials/forced_alignment_for_multilingual_data_tutorial.py +++ b/examples/tutorials/forced_alignment_for_multilingual_data_tutorial.py @@ -207,92 +207,19 @@ def display_segment(i, waveform, word_segments, num_frames, sample_rate=SAMPLE_R # order to verify the alignment quality. Here we first load the model and dictionary. # -from torchaudio.models import wav2vec2_model - -model = wav2vec2_model( - extractor_mode="layer_norm", - extractor_conv_layer_config=[ - (512, 10, 5), - (512, 3, 2), - (512, 3, 2), - (512, 3, 2), - (512, 3, 2), - (512, 2, 2), - (512, 2, 2), - ], - extractor_conv_bias=True, - encoder_embed_dim=1024, - encoder_projection_dropout=0.0, - encoder_pos_conv_kernel=128, - encoder_pos_conv_groups=16, - encoder_num_layers=24, - encoder_num_heads=16, - encoder_attention_dropout=0.0, - encoder_ff_interm_features=4096, - encoder_ff_interm_dropout=0.1, - encoder_dropout=0.0, - encoder_layer_norm_first=True, - encoder_layer_drop=0.1, - aux_num_out=31, -) - - -model.load_state_dict( - torch.hub.load_state_dict_from_url( - "https://dl.fbaipublicfiles.com/mms/torchaudio/ctc_alignment_mling_uroman/model.pt" - ) -) -model.eval() -model.to(device) +from torchaudio.pipelines import MMS_FA + +bundle = MMS_FA +model = bundle.get_model(with_star=False) +dictionary = bundle.get_dict(star=None) def get_emission(waveform): with torch.inference_mode(): - # NOTE: this step is essential - waveform = torch.nn.functional.layer_norm(waveform, waveform.shape) emission, _ = model(waveform) return torch.log_softmax(emission, dim=-1) -# Construct the dictionary -# '@' represents the OOV token -# and are fairseq's legacy tokens, which're not used. -# token is omitted as we do not use it in this tutorial -dictionary = { - "": 0, - "": 1, - "": 2, - "@": 3, - "a": 4, - "i": 5, - "e": 6, - "n": 7, - "o": 8, - "u": 9, - "t": 10, - "s": 11, - "r": 12, - "m": 13, - "k": 14, - "l": 15, - "d": 16, - "g": 17, - "h": 18, - "y": 19, - "b": 20, - "p": 21, - "w": 22, - "c": 23, - "v": 24, - "j": 25, - "z": 26, - "f": 27, - "'": 28, - "q": 29, - "x": 30, -} - - ###################################################################### # Before aligning the speech with transcripts, we need to make sure # the transcripts are already romanized. Here are the BASH commands diff --git a/torchaudio/pipelines/__init__.py b/torchaudio/pipelines/__init__.py index 267526d4464..f7f9b0412d7 100644 --- a/torchaudio/pipelines/__init__.py +++ b/torchaudio/pipelines/__init__.py @@ -18,6 +18,7 @@ HUBERT_BASE, HUBERT_LARGE, HUBERT_XLARGE, + MMS_FA, VOXPOPULI_ASR_BASE_10K_DE, VOXPOPULI_ASR_BASE_10K_EN, VOXPOPULI_ASR_BASE_10K_ES, @@ -77,6 +78,7 @@ "HUBERT_XLARGE", "HUBERT_ASR_LARGE", "HUBERT_ASR_XLARGE", + "MMS_FA", "WAVLM_BASE", "WAVLM_BASE_PLUS", "WAVLM_LARGE", diff --git a/torchaudio/pipelines/_wav2vec2/impl.py b/torchaudio/pipelines/_wav2vec2/impl.py index 6a8faf1127a..d3288411ed3 100644 --- a/torchaudio/pipelines/_wav2vec2/impl.py +++ b/torchaudio/pipelines/_wav2vec2/impl.py @@ -1,5 +1,6 @@ +import copy from dataclasses import dataclass -from typing import Any, Dict, Tuple +from typing import Any, Dict, Optional, Tuple from torch.nn import Module @@ -146,7 +147,7 @@ def get_labels( *, blank: str = "-", ) -> Tuple[str, ...]: - """The output class labels (only applicable to fine-tuned bundles) + """The output class labels. The first is blank token, and it is customizable. @@ -159,8 +160,8 @@ def get_labels( the output class labels. Example - >>> import torchaudio - >>> torchaudio.models.HUBERT_ASR_LARGE.get_labels() + >>> from torchaudio.pipelines import HUBERT_ASR_LARGE as bundle + >>> bundle.get_labels() ('-', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z') """ # noqa: E501 return (blank, *self._labels) @@ -1518,3 +1519,125 @@ def _get_state_dict(self, dl_kwargs): Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2Bundle` for usage details. """ # noqa: E501 + + +@dataclass +class Wav2Vec2FABundle(Wav2Vec2ASRBundle): + def get_labels(self, star: Optional[str] = "", blank: str = "") -> Tuple[str, ...]: + """Get the labels corresponding to the feature dimension of emission. + + The first is blank token, and it is customizable. + + Args: + star (str or None, optional): Change or disable star token. (default: ``""``) + blank (str, optional): Change the blank token. (default: ``'-'``) + + Returns: + Tuple[str, ...]: + For models fine-tuned on ASR, returns the tuple of strings representing + the output class labels. + + Example + >>> from torchaudio.pipelines import MMS_FA as bundle + >>> bundle.get_labels() + ('', 'a', 'i', 'e', 'n', 'o', 'u', 't', 's', 'r', 'm', 'k', 'l', 'd', 'g', 'h', 'y', 'b', 'p', 'w', 'c', 'v', 'j', 'z', 'f', "'", 'q', 'x', '') + >>> bundle.get_labels(star=None) + ('', 'a', 'i', 'e', 'n', 'o', 'u', 't', 's', 'r', 'm', 'k', 'l', 'd', 'g', 'h', 'y', 'b', 'p', 'w', 'c', 'v', 'j', 'z', 'f', "'", 'q', 'x') + """ # noqa: E501 + labels = super().get_labels(blank=blank) + return labels if star is None else (*labels, star) + + def _get_params_with_star(self): + params = copy.deepcopy(self._params) + params["aux_num_out"] += 1 + return params + + def get_model(self, with_star: bool = True, *, dl_kwargs=None) -> Module: + """Construct the model and load the pretrained weight. + + The weight file is downloaded from the internet and cached with + :func:`torch.hub.load_state_dict_from_url` + + Args: + with_star (bool, optional): If enabled, the last dimension of output layer is + extended by one, which corresponds to `star` token. + dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`. + + Returns: + Variation of :py:class:`~torchaudio.models.Wav2Vec2Model`. + """ + params = self._get_params_with_star() if with_star else self._params + model = utils._get_model(self._model_type, params) + state_dict = utils._get_state_dict(self._path, dl_kwargs, self._remove_aux_axis, with_star) + model.load_state_dict(state_dict) + if self._normalize_waveform: + model = utils._apply_input_layer_norm(model) + model.eval() + return model + + def get_dict(self, star: Optional[str] = "", blank: str = "") -> Dict[str, int]: + """Get the mapping from token to index (in emission feature dim) + + Args: + star (str or None, optional): Change or disable star token. (default: ``""``) + blank (str, optional): Change the blank token. (default: ``'-'``) + + Returns: + Tuple[str, ...]: + For models fine-tuned on ASR, returns the tuple of strings representing + the output class labels. + + Example + >>> from torchaudio.pipelines import MMS_FA as bundle + >>> bundle.get_dict() + {'': 0, 'a': 1, 'i': 2, 'e': 3, 'n': 4, 'o': 5, 'u': 6, 't': 7, 's': 8, 'r': 9, 'm': 10, 'k': 11, 'l': 12, 'd': 13, 'g': 14, 'h': 15, 'y': 16, 'b': 17, 'p': 18, 'w': 19, 'c': 20, 'v': 21, 'j': 22, 'z': 23, 'f': 24, "'": 25, 'q': 26, 'x': 27, '': 28} + >>> bundle.get_dict(star=None) + {'': 0, 'a': 1, 'i': 2, 'e': 3, 'n': 4, 'o': 5, 'u': 6, 't': 7, 's': 8, 'r': 9, 'm': 10, 'k': 11, 'l': 12, 'd': 13, 'g': 14, 'h': 15, 'y': 16, 'b': 17, 'p': 18, 'w': 19, 'c': 20, 'v': 21, 'j': 22, 'z': 23, 'f': 24, "'": 25, 'q': 26, 'x': 27} + """ # noqa: E501 + return {k: i for i, k in enumerate(self.get_labels(star=star, blank=blank))} + + +MMS_FA = Wav2Vec2FABundle( + "https://dl.fbaipublicfiles.com/mms/torchaudio/ctc_alignment_mling_uroman/model.pt", + { + "extractor_mode": "layer_norm", + "extractor_conv_layer_config": [ + (512, 10, 5), + (512, 3, 2), + (512, 3, 2), + (512, 3, 2), + (512, 3, 2), + (512, 2, 2), + (512, 2, 2), + ], + "extractor_conv_bias": True, + "encoder_embed_dim": 1024, + "encoder_projection_dropout": 0.0, + "encoder_pos_conv_kernel": 128, + "encoder_pos_conv_groups": 16, + "encoder_num_layers": 24, + "encoder_num_heads": 16, + "encoder_attention_dropout": 0.0, + "encoder_ff_interm_features": 4096, + "encoder_ff_interm_dropout": 0.1, + "encoder_dropout": 0.0, + "encoder_layer_norm_first": True, + "encoder_layer_drop": 0.1, + "aux_num_out": 28, + }, + _labels=utils._get_mms_labels(), + _sample_rate=16000, + _normalize_waveform=True, + _model_type="Wav2Vec2", +) +MMS_FA.__doc__ = """ +Trained on 31K hours of data in 1,130 languages from *Scaling Speech Technology to 1,000+ Languages* :cite:`pratap2023scaling`. + +Published by the authors of *Scaling Speech Technology to 1,000+ Languages* :cite:`pratap2023scaling` under [`CC-BY-NC 4.0 License `__]. + +Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2FABundle` for usage details. + +.. note:: + + Unlike other Wav2Vec2 bundles, this model does not have a token for word boundary (like `|`). This makes the post-processing of alignments slightly different. +""" # noqa: E501 diff --git a/torchaudio/pipelines/_wav2vec2/utils.py b/torchaudio/pipelines/_wav2vec2/utils.py index 0ab459f34ae..69e869208b2 100644 --- a/torchaudio/pipelines/_wav2vec2/utils.py +++ b/torchaudio/pipelines/_wav2vec2/utils.py @@ -65,13 +65,23 @@ def _remove_aux_axes(state_dict, axes): state_dict[key] = torch.stack([mat[i] for i in range(mat.size(0)) if i not in axes]) -def _get_state_dict(url, dl_kwargs, remove_axes=None): +def _add_star_dim(state_dict): + w, b = state_dict["aux.weight"], state_dict["aux.bias"] + zeros = torch.zeros((1, w.size(1)), device=w.device, dtype=w.dtype) + state_dict["aux.weight"] = torch.cat((zeros, w), dim=0) + ones = torch.ones((1,), device=b.device, dtype=b.dtype) + state_dict["aux.bias"] = torch.cat((b, ones), dim=0) + + +def _get_state_dict(url, dl_kwargs, remove_axes=None, add_star=False): if not url.startswith("https"): url = f"https://download.pytorch.org/torchaudio/models/{url}" dl_kwargs = {} if dl_kwargs is None else dl_kwargs state_dict = load_state_dict_from_url(url, **dl_kwargs) if remove_axes: _remove_aux_axes(state_dict, remove_axes) + if add_star: + _add_star_dim(state_dict) return state_dict @@ -301,3 +311,35 @@ def _get_it_labels(): "í", "ï", ) + + +def _get_mms_labels(): + return ( + "a", + "i", + "e", + "n", + "o", + "u", + "t", + "s", + "r", + "m", + "k", + "l", + "d", + "g", + "h", + "y", + "b", + "p", + "w", + "c", + "v", + "j", + "z", + "f", + "'", + "q", + "x", + )