Move TorchAudio-Squim models to Beta #3512

Closed
wants to merge 1 commit into from
8 changes: 8 additions & 0 deletions docs/source/_templates/autosummary/model_class.rst
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,14 @@
"hdemucs_medium",
"hdemucs_high",
],
"torchaudio.models.SquimObjective": [
"squim_objective_model",
"squim_objective_base",
],
"torchaudio.models.SquimSubjective": [
"squim_subjective_model",
"squim_subjective_base",
],
}
-%}
{%- set prototype_factory = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,6 @@
}
-%}
{%- set factory={
"torchaudio.prototype.models.SquimObjective": [
"squim_objective_model",
"squim_objective_base",
],
"torchaudio.prototype.models.SquimSubjective": [
"squim_subjective_model",
"squim_subjective_base",
],
"torchaudio.prototype.models.ConformerWav2Vec2PretrainModel": [
"conformer_wav2vec2_pretrain_model",
"conformer_wav2vec2_pretrain_base",
Expand Down
2 changes: 2 additions & 0 deletions docs/source/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ For such models, factory functions are provided.
HuBERTPretrainModel
RNNT
RNNTBeamSearch
SquimObjective
SquimSubjective
Tacotron2
Wav2Letter
Wav2Vec2Model
Expand Down
50 changes: 50 additions & 0 deletions docs/source/pipelines.rst
Original file line number Diff line number Diff line change
Expand Up @@ -217,3 +217,53 @@ Pretrained Models
CONVTASNET_BASE_LIBRI2MIX
HDEMUCS_HIGH_MUSDB_PLUS
HDEMUCS_HIGH_MUSDB

Squim Objective
---------------

Interface
~~~~~~~~~

:py:class:`SquimObjectiveBundle` defines a speech quality and intelligibility measurement (SQUIM) pipeline that can predict **objective** metric scores given the input waveform.

.. autosummary::
:toctree: generated
:nosignatures:
:template: autosummary/bundle_class.rst

SquimObjectiveBundle

Pretrained Models
~~~~~~~~~~~~~~~~~

.. autosummary::
:toctree: generated
:nosignatures:
:template: autosummary/bundle_data.rst

SQUIM_OBJECTIVE

Squim Subjective
----------------

Interface
~~~~~~~~~

:py:class:`SquimSubjectiveBundle` defines a speech quality and intelligibility measurement (SQUIM) pipeline that can predict **subjective** metric scores given the input waveform.

.. autosummary::
:toctree: generated
:nosignatures:
:template: autosummary/bundle_class.rst

SquimSubjectiveBundle

Pretrained Models
~~~~~~~~~~~~~~~~~

.. autosummary::
:toctree: generated
:nosignatures:
:template: autosummary/bundle_data.rst

SQUIM_SUBJECTIVE
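
With the bundles documented under `torchaudio.pipelines`, calling them might look like the minimal sketch below. The helper names `estimate_objective_scores` and `estimate_subjective_score` are hypothetical and not part of TorchAudio. The sketch assumes the objective model returns three values in the order STOI, PESQ, SI-SDR for a (batch, time) waveform, and that the subjective model takes the waveform under test plus a reference waveform. The bundle is passed in as an argument, so the helpers do not depend on which package it was imported from.

```python
def estimate_objective_scores(bundle, waveform):
    """Run a SquimObjective-style bundle on `waveform`.

    Assumption: `bundle.get_model()` returns a callable that yields three
    values (STOI, PESQ, SI-SDR) for a (batch, time) waveform.
    """
    model = bundle.get_model()
    stoi, pesq, si_sdr = model(waveform)
    return {"stoi": stoi, "pesq": pesq, "si_sdr": si_sdr}


def estimate_subjective_score(bundle, waveform, reference):
    """Run a SquimSubjective-style bundle on `waveform`.

    Assumption: the model is called as model(waveform, reference) and
    returns the estimated subjective score.
    """
    model = bundle.get_model()
    return model(waveform, reference)
```

With a TorchAudio build that includes this change, the call would be `estimate_objective_scores(SQUIM_OBJECTIVE, waveform)` after `from torchaudio.pipelines import SQUIM_OBJECTIVE`, with `waveform` sampled at `SQUIM_OBJECTIVE.sample_rate`.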
2 changes: 0 additions & 2 deletions docs/source/prototype.models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@ For such models, factory functions are provided.
ConformerWav2Vec2PretrainModel
ConvEmformer
HiFiGANVocoder
SquimObjective
SquimSubjective

Prototype Factory Functions of Beta Models
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Expand Down
50 changes: 0 additions & 50 deletions docs/source/prototype.pipelines.rst
Original file line number Diff line number Diff line change
Expand Up @@ -45,53 +45,3 @@ Pretrained Models
:template: autosummary/bundle_data.rst

HIFIGAN_VOCODER_V3_LJSPEECH

Squim Objective
---------------

Interface
~~~~~~~~~

:py:class:`SquimObjectiveBundle` defines a speech quality and intelligibility measurement (SQUIM) pipeline that can predict **objective** metric scores given the input waveform.

.. autosummary::
:toctree: generated
:nosignatures:
:template: autosummary/bundle_class.rst

SquimObjectiveBundle

Pretrained Models
~~~~~~~~~~~~~~~~~

.. autosummary::
:toctree: generated
:nosignatures:
:template: autosummary/bundle_data.rst

SQUIM_OBJECTIVE

Squim Subjective
----------------

Interface
~~~~~~~~~

:py:class:`SquimSubjectiveBundle` defines a speech quality and intelligibility measurement (SQUIM) pipeline that can predict **subjective** metric scores given the input waveform.

.. autosummary::
:toctree: generated
:nosignatures:
:template: autosummary/bundle_class.rst

SquimSubjectiveBundle

Pretrained Models
~~~~~~~~~~~~~~~~~

.. autosummary::
:toctree: generated
:nosignatures:
:template: autosummary/bundle_data.rst

SQUIM_SUBJECTIVE
2 changes: 1 addition & 1 deletion examples/tutorials/squim_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@
try:
from pesq import pesq
from pystoi import stoi
from torchaudio.prototype.pipelines import SQUIM_OBJECTIVE, SQUIM_SUBJECTIVE
from torchaudio.pipelines import SQUIM_OBJECTIVE, SQUIM_SUBJECTIVE
except ImportError:
import google.colab # noqa: F401
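
Code that has to run against TorchAudio builds from both before and after this move could try the new location first and fall back to the prototype one. The following is a minimal sketch of that idea; `import_first_available`, `SQUIM_PIPELINE_MODULES`, and `load_squim_bundles` are hypothetical names introduced here for illustration, not part of TorchAudio or of this change.

```python
import importlib


def import_first_available(module_names, attr):
    """Return `attr` from the first module in `module_names` that provides it."""
    errors = []
    for name in module_names:
        try:
            module = importlib.import_module(name)
            return getattr(module, attr)
        except (ImportError, AttributeError) as exc:
            # Module is missing, or it exists but does not export `attr`.
            errors.append(f"{name}: {exc}")
    raise ImportError(f"{attr} not found; tried: " + "; ".join(errors))


# Candidate locations, new one first (illustrative constant).
SQUIM_PIPELINE_MODULES = ["torchaudio.pipelines", "torchaudio.prototype.pipelines"]


def load_squim_bundles():
    """Look up both Squim bundles from whichever location is available."""
    objective = import_first_available(SQUIM_PIPELINE_MODULES, "SQUIM_OBJECTIVE")
    subjective = import_first_available(SQUIM_PIPELINE_MODULES, "SQUIM_SUBJECTIVE")
    return objective, subjective
```

The lookup catches `AttributeError` as well as `ImportError` because `torchaudio.pipelines` exists in older builds too; it simply does not export the Squim names there.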

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest
import torchaudio
from torchaudio.prototype.pipelines import SQUIM_OBJECTIVE, SQUIM_SUBJECTIVE
from torchaudio.pipelines import SQUIM_OBJECTIVE, SQUIM_SUBJECTIVE


@pytest.mark.parametrize(
Expand Down
Empty file.
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch
from parameterized import parameterized
from torchaudio.prototype.models import squim_objective_base, squim_subjective_base
from torchaudio.models import squim_objective_base, squim_subjective_base
from torchaudio_unittest.common_utils import skipIfNoCuda, torch_script, TorchaudioTestCase
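
The import changes in the tutorial and tests above follow one pattern: `torchaudio.prototype.models` or `torchaudio.prototype.pipelines` becomes `torchaudio.models` or `torchaudio.pipelines` when only Squim names are imported. Downstream code could apply the same rewrite mechanically. Below is a rough sketch under that assumption; `migrate_line` and `migrate_source` are hypothetical helpers, and the regex only handles single-line `from ... import ...` statements without trailing comments.

```python
import re

# Names moved out of the prototype packages by this change.
SQUIM_NAMES = {
    "squim_objective_base", "squim_objective_model",
    "squim_subjective_base", "squim_subjective_model",
    "SquimObjective", "SquimSubjective",
    "SQUIM_OBJECTIVE", "SQUIM_SUBJECTIVE",
    "SquimObjectiveBundle", "SquimSubjectiveBundle",
}

_IMPORT_RE = re.compile(
    r"^(\s*from\s+)torchaudio\.prototype\.(models|pipelines)(\s+import\s+)(.+)$"
)


def migrate_line(line):
    """Rewrite one import line (no trailing newline) if it imports only Squim names."""
    match = _IMPORT_RE.match(line)
    if not match:
        return line
    prefix, submodule, middle, names = match.groups()
    # Drop any "as alias" part before comparing against the moved names.
    imported = {part.split(" as ")[0].strip() for part in names.split(",")}
    if imported <= SQUIM_NAMES:
        return f"{prefix}torchaudio.{submodule}{middle}{names}"
    # Mixed or unrelated imports are left untouched on purpose.
    return line


def migrate_source(text):
    """Apply `migrate_line` to every line of a source string."""
    return "\n".join(migrate_line(line) for line in text.split("\n"))
```

Lines that mix Squim names with names still in the prototype packages are deliberately left unchanged, since those need to be split by hand.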


Expand Down
14 changes: 14 additions & 0 deletions torchaudio/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,14 @@
from .emformer import Emformer
from .rnnt import emformer_rnnt_base, emformer_rnnt_model, RNNT
from .rnnt_decoder import Hypothesis, RNNTBeamSearch
from .squim import (
squim_objective_base,
squim_objective_model,
squim_subjective_base,
squim_subjective_model,
SquimObjective,
SquimSubjective,
)
from .tacotron2 import Tacotron2
from .wav2letter import Wav2Letter
from .wav2vec2 import (
Expand Down Expand Up @@ -68,4 +76,10 @@
"hdemucs_low",
"hdemucs_medium",
"hdemucs_high",
"squim_objective_base",
"squim_objective_model",
"squim_subjective_base",
"squim_subjective_model",
"SquimObjective",
"SquimSubjective",
]
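
Moving names between packages touches both the import statements and `__all__`, and the two can drift apart. A small consistency check along the following lines could catch that. This is a sketch only: `missing_exports` is a hypothetical helper, and the `demo_models` module is a stand-in built for illustration rather than the real `torchaudio.models`.

```python
import types


def missing_exports(module):
    """Return names listed in `module.__all__` that the module does not define."""
    exported = getattr(module, "__all__", [])
    return sorted(name for name in exported if not hasattr(module, name))


# Stand-in module: one name is listed in __all__ but never defined.
demo = types.ModuleType("demo_models")
demo.SquimObjective = object
demo.__all__ = ["SquimObjective", "SquimSubjective"]

print(missing_exports(demo))  # ['SquimSubjective']
```

Run against a real package, an empty result would indicate that every name in `__all__` resolves to an attribute.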
5 changes: 5 additions & 0 deletions torchaudio/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
HDEMUCS_HIGH_MUSDB_PLUS,
SourceSeparationBundle,
)
from ._squim_pipeline import SQUIM_OBJECTIVE, SQUIM_SUBJECTIVE, SquimObjectiveBundle, SquimSubjectiveBundle
from ._tts import (
TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH,
TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH,
Expand Down Expand Up @@ -90,4 +91,8 @@
"CONVTASNET_BASE_LIBRI2MIX",
"HDEMUCS_HIGH_MUSDB_PLUS",
"HDEMUCS_HIGH_MUSDB",
"SQUIM_OBJECTIVE",
"SQUIM_SUBJECTIVE",
"SquimObjectiveBundle",
"SquimSubjectiveBundle",
]
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@

from torchaudio._internal import load_state_dict_from_url

from torchaudio.prototype.models import squim_objective_base, squim_subjective_base, SquimObjective, SquimSubjective
from torchaudio.models import squim_objective_base, squim_subjective_base, SquimObjective, SquimSubjective


@dataclass
class SquimObjectiveBundle:
"""Data class that bundles associated information to use pretrained
:py:class:`~torchaudio.prototype.models.SquimObjective` model.
:py:class:`~torchaudio.models.SquimObjective` model.

This class provides interfaces for instantiating the pretrained model along with
the information necessary to retrieve pretrained weights and additional data
Expand All @@ -24,8 +24,7 @@ class SquimObjectiveBundle:
Example: Estimate the objective metric scores for the input waveform.
>>> import torch
>>> import torchaudio
>>> # Since SquimObjective bundle is in prototypes, it needs to be exported explicitly
>>> from torchaudio.prototype.pipelines import SQUIM_OBJECTIVE as bundle
>>> from torchaudio.pipelines import SQUIM_OBJECTIVE as bundle
>>>
>>> # Load the SquimObjective bundle
>>> model = bundle.get_model()
Expand Down Expand Up @@ -59,7 +58,7 @@ def get_model(self, *, dl_kwargs=None) -> SquimObjective:
dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`.

Returns:
Variation of :py:class:`~torchaudio.prototype.models.SquimObjective`.
Variation of :py:class:`~torchaudio.models.SquimObjective`.
"""
model = squim_objective_base()
model.load_state_dict(self._get_state_dict(dl_kwargs))
Expand All @@ -82,7 +81,7 @@ def sample_rate(self):
SQUIM_OBJECTIVE.__doc__ = """SquimObjective pipeline trained using approach described in
:cite:`kumar2023torchaudio` on the *DNS 2020 Dataset* :cite:`reddy2020interspeech`.

The underlying model is constructed by :py:func:`torchaudio.prototype.models.squim_objective_base`.
The underlying model is constructed by :py:func:`torchaudio.models.squim_objective_base`.
The weights are under `Creative Commons Attribution 4.0 International License
<https://github.com/microsoft/DNS-Challenge/blob/interspeech2020/master/LICENSE>`__.

Expand All @@ -93,7 +92,7 @@ def sample_rate(self):
@dataclass
class SquimSubjectiveBundle:
"""Data class that bundles associated information to use pretrained
:py:class:`~torchaudio.prototype.models.SquimSubjective` model.
:py:class:`~torchaudio.models.SquimSubjective` model.

This class provides interfaces for instantiating the pretrained model along with
the information necessary to retrieve pretrained weights and additional data
Expand All @@ -109,8 +108,7 @@ class SquimSubjectiveBundle:
Example: Estimate the subjective metric scores for the input waveform.
>>> import torch
>>> import torchaudio
>>> # Since SquimSubjective bundle is in prototypes, it needs to be exported explicitly
>>> from torchaudio.prototype.pipelines import SQUIM_SUBJECTIVE as bundle
>>> from torchaudio.pipelines import SQUIM_SUBJECTIVE as bundle
>>>
>>> # Load the SquimSubjective bundle
>>> model = bundle.get_model()
Expand Down Expand Up @@ -146,7 +144,7 @@ def get_model(self, *, dl_kwargs=None) -> SquimSubjective:
dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`.

Returns:
Variation of :py:class:`~torchaudio.prototype.models.SquimObjective`.
Variation of :py:class:`~torchaudio.models.SquimSubjective`.
"""
model = squim_subjective_base()
model.load_state_dict(self._get_state_dict(dl_kwargs))
Expand All @@ -170,7 +168,7 @@ def sample_rate(self):
as described in :cite:`manocha2022speech` and :cite:`kumar2023torchaudio`
on the *BVCC* :cite:`cooper2021voices` and *DAPS* :cite:`mysore2014can` datasets.

The underlying model is constructed by :py:func:`torchaudio.prototype.models.squim_subjective_base`.
The underlying model is constructed by :py:func:`torchaudio.models.squim_subjective_base`.
The weights are under `Creative Commons Attribution Non Commercial 4.0 International
<https://zenodo.org/record/4660670#.ZBtWPOxuerN>`__.

Expand Down
14 changes: 0 additions & 14 deletions torchaudio/prototype/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,6 @@
from .hifi_gan import hifigan_vocoder, hifigan_vocoder_v1, hifigan_vocoder_v2, hifigan_vocoder_v3, HiFiGANVocoder
from .rnnt import conformer_rnnt_base, conformer_rnnt_biasing, conformer_rnnt_biasing_base, conformer_rnnt_model
from .rnnt_decoder import Hypothesis, RNNTBeamSearchBiasing
from .squim import (
squim_objective_base,
squim_objective_model,
squim_subjective_base,
squim_subjective_model,
SquimObjective,
SquimSubjective,
)

__all__ = [
"conformer_rnnt_base",
Expand All @@ -42,10 +34,4 @@
"hifigan_vocoder_v2",
"hifigan_vocoder_v3",
"hifigan_vocoder",
"squim_objective_base",
"squim_objective_model",
"squim_subjective_base",
"squim_subjective_model",
"SquimObjective",
"SquimSubjective",
]
5 changes: 0 additions & 5 deletions torchaudio/prototype/pipelines/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,9 @@
from .hifigan_pipeline import HIFIGAN_VOCODER_V3_LJSPEECH, HiFiGANVocoderBundle
from .rnnt_pipeline import EMFORMER_RNNT_BASE_MUSTC, EMFORMER_RNNT_BASE_TEDLIUM3
from .squim_pipeline import SQUIM_OBJECTIVE, SQUIM_SUBJECTIVE, SquimObjectiveBundle, SquimSubjectiveBundle

__all__ = [
"EMFORMER_RNNT_BASE_MUSTC",
"EMFORMER_RNNT_BASE_TEDLIUM3",
"HIFIGAN_VOCODER_V3_LJSPEECH",
"HiFiGANVocoderBundle",
"SQUIM_OBJECTIVE",
"SQUIM_SUBJECTIVE",
"SquimObjectiveBundle",
"SquimSubjectiveBundle",
]