diff --git a/docs/source/_templates/autosummary/model_class.rst b/docs/source/_templates/autosummary/model_class.rst index c843f503c1..283f904151 100644 --- a/docs/source/_templates/autosummary/model_class.rst +++ b/docs/source/_templates/autosummary/model_class.rst @@ -46,6 +46,14 @@ "hdemucs_medium", "hdemucs_high", ], + "torchaudio.models.SquimObjective": [ + "squim_objective_model", + "squim_objective_base", + ], + "torchaudio.models.SquimSubjective": [ + "squim_subjective_model", + "squim_subjective_base", + ], } -%} {%- set prototype_factory = { diff --git a/docs/source/_templates/autosummary/prototype_model_class.rst b/docs/source/_templates/autosummary/prototype_model_class.rst index 0c1ebe9757..9dc4bab50d 100644 --- a/docs/source/_templates/autosummary/prototype_model_class.rst +++ b/docs/source/_templates/autosummary/prototype_model_class.rst @@ -13,14 +13,6 @@ } -%} {%- set factory={ - "torchaudio.prototype.models.SquimObjective": [ - "squim_objective_model", - "squim_objective_base", - ], - "torchaudio.prototype.models.SquimSubjective": [ - "squim_subjective_model", - "squim_subjective_base", - ], "torchaudio.prototype.models.ConformerWav2Vec2PretrainModel": [ "conformer_wav2vec2_pretrain_model", "conformer_wav2vec2_pretrain_base", diff --git a/docs/source/models.rst b/docs/source/models.rst index 9c7b7fc494..cd7bfd2660 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -28,6 +28,8 @@ For such models, factory functions are provided. 
HuBERTPretrainModel RNNT RNNTBeamSearch + SquimObjective + SquimSubjective Tacotron2 Wav2Letter Wav2Vec2Model diff --git a/docs/source/pipelines.rst b/docs/source/pipelines.rst index fb4c76db2b..fd9205a5fa 100644 --- a/docs/source/pipelines.rst +++ b/docs/source/pipelines.rst @@ -217,3 +217,53 @@ Pretrained Models CONVTASNET_BASE_LIBRI2MIX HDEMUCS_HIGH_MUSDB_PLUS HDEMUCS_HIGH_MUSDB + +Squim Objective +--------------- + +Interface +~~~~~~~~~ + +:py:class:`SquimObjectiveBundle` defines speech quality and intelligibility measurement (SQUIM) pipeline that can predict **objective** metric scores given the input waveform. + +.. autosummary:: + :toctree: generated + :nosignatures: + :template: autosummary/bundle_class.rst + + SquimObjectiveBundle + +Pretrained Models +~~~~~~~~~~~~~~~~~ + +.. autosummary:: + :toctree: generated + :nosignatures: + :template: autosummary/bundle_data.rst + + SQUIM_OBJECTIVE + +Squim Subjective +---------------- + +Interface +~~~~~~~~~ + +:py:class:`SquimSubjectiveBundle` defines speech quality and intelligibility measurement (SQUIM) pipeline that can predict **subjective** metric scores given the input waveform. + +.. autosummary:: + :toctree: generated + :nosignatures: + :template: autosummary/bundle_class.rst + + SquimSubjectiveBundle + +Pretrained Models +~~~~~~~~~~~~~~~~~ + +.. autosummary:: + :toctree: generated + :nosignatures: + :template: autosummary/bundle_data.rst + + SQUIM_SUBJECTIVE diff --git a/docs/source/prototype.models.rst b/docs/source/prototype.models.rst index fd71794be5..913b019fcb 100644 --- a/docs/source/prototype.models.rst +++ b/docs/source/prototype.models.rst @@ -23,8 +23,6 @@ For such models, factory functions are provided. 
ConformerWav2Vec2PretrainModel ConvEmformer HiFiGANVocoder - SquimObjective - SquimSubjective Prototype Factory Functions of Beta Models ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/prototype.pipelines.rst b/docs/source/prototype.pipelines.rst index d713da582b..e9804337c0 100644 --- a/docs/source/prototype.pipelines.rst +++ b/docs/source/prototype.pipelines.rst @@ -45,53 +45,3 @@ Pretrained Models :template: autosummary/bundle_data.rst HIFIGAN_VOCODER_V3_LJSPEECH - -Squim Objective ---------------- - -Interface -~~~~~~~~~ - -:py:class:`SquimObjectiveBundle` defines speech quality and intelligibility measurement (SQUIM) pipeline that can predict **objecive** metric scores given the input waveform. - -.. autosummary:: - :toctree: generated - :nosignatures: - :template: autosummary/bundle_class.rst - - SquimObjectiveBundle - -Pretrained Models -~~~~~~~~~~~~~~~~~ - -.. autosummary:: - :toctree: generated - :nosignatures: - :template: autosummary/bundle_data.rst - - SQUIM_OBJECTIVE - -Squim Subjective ----------------- - -Interface -~~~~~~~~~ - -:py:class:`SquimSubjectiveBundle` defines speech quality and intelligibility measurement (SQUIM) pipeline that can predict **subjective** metric scores given the input waveform. - -.. autosummary:: - :toctree: generated - :nosignatures: - :template: autosummary/bundle_class.rst - - SquimSubjectiveBundle - -Pretrained Models -~~~~~~~~~~~~~~~~~ - -.. 
autosummary:: - :toctree: generated - :nosignatures: - :template: autosummary/bundle_data.rst - - SQUIM_SUBJECTIVE diff --git a/examples/tutorials/squim_tutorial.py b/examples/tutorials/squim_tutorial.py index 5314915554..640e2e79b8 100644 --- a/examples/tutorials/squim_tutorial.py +++ b/examples/tutorials/squim_tutorial.py @@ -80,7 +80,7 @@ try: from pesq import pesq from pystoi import stoi - from torchaudio.prototype.pipelines import SQUIM_OBJECTIVE, SQUIM_SUBJECTIVE + from torchaudio.pipelines import SQUIM_OBJECTIVE, SQUIM_SUBJECTIVE except ImportError: import google.colab # noqa: F401 diff --git a/test/integration_tests/prototype/squim_pipeline_test.py b/test/integration_tests/squim_pipeline_test.py similarity index 94% rename from test/integration_tests/prototype/squim_pipeline_test.py rename to test/integration_tests/squim_pipeline_test.py index acec9e4ea4..9f78bba4d4 100644 --- a/test/integration_tests/prototype/squim_pipeline_test.py +++ b/test/integration_tests/squim_pipeline_test.py @@ -1,6 +1,6 @@ import pytest import torchaudio -from torchaudio.prototype.pipelines import SQUIM_OBJECTIVE, SQUIM_SUBJECTIVE +from torchaudio.pipelines import SQUIM_OBJECTIVE, SQUIM_SUBJECTIVE @pytest.mark.parametrize( diff --git a/test/torchaudio_unittest/models/squim/__init__.py b/test/torchaudio_unittest/models/squim/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/test/torchaudio_unittest/prototype/squim_test.py b/test/torchaudio_unittest/models/squim/squim_test.py similarity index 97% rename from test/torchaudio_unittest/prototype/squim_test.py rename to test/torchaudio_unittest/models/squim/squim_test.py index 2d32399c90..7d3d2573e0 100644 --- a/test/torchaudio_unittest/prototype/squim_test.py +++ b/test/torchaudio_unittest/models/squim/squim_test.py @@ -1,6 +1,6 @@ import torch from parameterized import parameterized -from torchaudio.prototype.models import squim_objective_base, squim_subjective_base +from torchaudio.models import 
squim_objective_base, squim_subjective_base from torchaudio_unittest.common_utils import skipIfNoCuda, torch_script, TorchaudioTestCase diff --git a/torchaudio/models/__init__.py b/torchaudio/models/__init__.py index f6458b1f86..5d344400d3 100644 --- a/torchaudio/models/__init__.py +++ b/torchaudio/models/__init__.py @@ -5,6 +5,14 @@ from .emformer import Emformer from .rnnt import emformer_rnnt_base, emformer_rnnt_model, RNNT from .rnnt_decoder import Hypothesis, RNNTBeamSearch +from .squim import ( + squim_objective_base, + squim_objective_model, + squim_subjective_base, + squim_subjective_model, + SquimObjective, + SquimSubjective, +) from .tacotron2 import Tacotron2 from .wav2letter import Wav2Letter from .wav2vec2 import ( @@ -68,4 +76,10 @@ "hdemucs_low", "hdemucs_medium", "hdemucs_high", + "squim_objective_base", + "squim_objective_model", + "squim_subjective_base", + "squim_subjective_model", + "SquimObjective", + "SquimSubjective", ] diff --git a/torchaudio/prototype/models/squim/__init__.py b/torchaudio/models/squim/__init__.py similarity index 100% rename from torchaudio/prototype/models/squim/__init__.py rename to torchaudio/models/squim/__init__.py diff --git a/torchaudio/prototype/models/squim/objective.py b/torchaudio/models/squim/objective.py similarity index 100% rename from torchaudio/prototype/models/squim/objective.py rename to torchaudio/models/squim/objective.py diff --git a/torchaudio/prototype/models/squim/subjective.py b/torchaudio/models/squim/subjective.py similarity index 100% rename from torchaudio/prototype/models/squim/subjective.py rename to torchaudio/models/squim/subjective.py diff --git a/torchaudio/pipelines/__init__.py b/torchaudio/pipelines/__init__.py index 586073c188..267526d446 100644 --- a/torchaudio/pipelines/__init__.py +++ b/torchaudio/pipelines/__init__.py @@ -4,6 +4,7 @@ HDEMUCS_HIGH_MUSDB_PLUS, SourceSeparationBundle, ) +from ._squim_pipeline import SQUIM_OBJECTIVE, SQUIM_SUBJECTIVE, SquimObjectiveBundle, 
SquimSubjectiveBundle from ._tts import ( TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH, TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH, @@ -90,4 +91,8 @@ "CONVTASNET_BASE_LIBRI2MIX", "HDEMUCS_HIGH_MUSDB_PLUS", "HDEMUCS_HIGH_MUSDB", + "SQUIM_OBJECTIVE", + "SQUIM_SUBJECTIVE", + "SquimObjectiveBundle", + "SquimSubjectiveBundle", ] diff --git a/torchaudio/prototype/pipelines/squim_pipeline.py b/torchaudio/pipelines/_squim_pipeline.py similarity index 88% rename from torchaudio/prototype/pipelines/squim_pipeline.py rename to torchaudio/pipelines/_squim_pipeline.py index e5cec3d878..f1d11bff53 100644 --- a/torchaudio/prototype/pipelines/squim_pipeline.py +++ b/torchaudio/pipelines/_squim_pipeline.py @@ -2,13 +2,13 @@ from torchaudio._internal import load_state_dict_from_url -from torchaudio.prototype.models import squim_objective_base, squim_subjective_base, SquimObjective, SquimSubjective +from torchaudio.models import squim_objective_base, squim_subjective_base, SquimObjective, SquimSubjective @dataclass class SquimObjectiveBundle: """Data class that bundles associated information to use pretrained - :py:class:`~torchaudio.prototype.models.SquimObjective` model. + :py:class:`~torchaudio.models.SquimObjective` model. This class provides interfaces for instantiating the pretrained model along with the information necessary to retrieve pretrained weights and additional data @@ -24,8 +24,7 @@ class SquimObjectiveBundle: Example: Estimate the objective metric scores for the input waveform. >>> import torch >>> import torchaudio - >>> # Since SquimObjective bundle is in prototypes, it needs to be exported explicitly - >>> from torchaudio.prototype.pipelines import SQUIM_OBJECTIVE as bundle + >>> from torchaudio.pipelines import SQUIM_OBJECTIVE as bundle >>> >>> # Load the SquimObjective bundle >>> model = bundle.get_model() @@ -59,7 +58,7 @@ def get_model(self, *, dl_kwargs=None) -> SquimObjective: dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`. 
Returns: - Variation of :py:class:`~torchaudio.prototype.models.SquimObjective`. + Variation of :py:class:`~torchaudio.models.SquimObjective`. """ model = squim_objective_base() model.load_state_dict(self._get_state_dict(dl_kwargs)) @@ -82,7 +81,7 @@ def sample_rate(self): SQUIM_OBJECTIVE.__doc__ = """SquimObjective pipeline trained using approach described in :cite:`kumar2023torchaudio` on the *DNS 2020 Dataset* :cite:`reddy2020interspeech`. - The underlying model is constructed by :py:func:`torchaudio.prototype.models.squim_objective_base`. + The underlying model is constructed by :py:func:`torchaudio.models.squim_objective_base`. The weights are under `Creative Commons Attribution 4.0 International License `__. @@ -93,7 +92,7 @@ def sample_rate(self): @dataclass class SquimSubjectiveBundle: """Data class that bundles associated information to use pretrained - :py:class:`~torchaudio.prototype.models.SquimSubjective` model. + :py:class:`~torchaudio.models.SquimSubjective` model. This class provides interfaces for instantiating the pretrained model along with the information necessary to retrieve pretrained weights and additional data @@ -109,8 +108,7 @@ class SquimSubjectiveBundle: Example: Estimate the subjective metric scores for the input waveform. >>> import torch >>> import torchaudio - >>> # Since SquimSubjective bundle is in prototypes, it needs to be exported explicitly - >>> from torchaudio.prototype.pipelines import SQUIM_SUBJECTIVE as bundle + >>> from torchaudio.pipelines import SQUIM_SUBJECTIVE as bundle >>> >>> # Load the SquimSubjective bundle >>> model = bundle.get_model() @@ -146,7 +144,7 @@ def get_model(self, *, dl_kwargs=None) -> SquimSubjective: dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`. Returns: - Variation of :py:class:`~torchaudio.prototype.models.SquimObjective`. + Variation of :py:class:`~torchaudio.models.SquimSubjective`. 
""" model = squim_subjective_base() model.load_state_dict(self._get_state_dict(dl_kwargs)) @@ -170,7 +168,7 @@ def sample_rate(self): as described in :cite:`manocha2022speech` and :cite:`kumar2023torchaudio` on the *BVCC* :cite:`cooper2021voices` and *DAPS* :cite:`mysore2014can` datasets. - The underlying model is constructed by :py:func:`torchaudio.prototype.models.squim_subjective_base`. + The underlying model is constructed by :py:func:`torchaudio.models.squim_subjective_base`. The weights are under `Creative Commons Attribution Non Commercial 4.0 International `__. diff --git a/torchaudio/prototype/models/__init__.py b/torchaudio/prototype/models/__init__.py index fcc8809f20..71134100dc 100644 --- a/torchaudio/prototype/models/__init__.py +++ b/torchaudio/prototype/models/__init__.py @@ -11,14 +11,6 @@ from .hifi_gan import hifigan_vocoder, hifigan_vocoder_v1, hifigan_vocoder_v2, hifigan_vocoder_v3, HiFiGANVocoder from .rnnt import conformer_rnnt_base, conformer_rnnt_biasing, conformer_rnnt_biasing_base, conformer_rnnt_model from .rnnt_decoder import Hypothesis, RNNTBeamSearchBiasing -from .squim import ( - squim_objective_base, - squim_objective_model, - squim_subjective_base, - squim_subjective_model, - SquimObjective, - SquimSubjective, -) __all__ = [ "conformer_rnnt_base", @@ -42,10 +34,4 @@ "hifigan_vocoder_v2", "hifigan_vocoder_v3", "hifigan_vocoder", - "squim_objective_base", - "squim_objective_model", - "squim_subjective_base", - "squim_subjective_model", - "SquimObjective", - "SquimSubjective", ] diff --git a/torchaudio/prototype/pipelines/__init__.py b/torchaudio/prototype/pipelines/__init__.py index 8cd6227eaa..5cfc48757c 100644 --- a/torchaudio/prototype/pipelines/__init__.py +++ b/torchaudio/prototype/pipelines/__init__.py @@ -1,14 +1,9 @@ from .hifigan_pipeline import HIFIGAN_VOCODER_V3_LJSPEECH, HiFiGANVocoderBundle from .rnnt_pipeline import EMFORMER_RNNT_BASE_MUSTC, EMFORMER_RNNT_BASE_TEDLIUM3 -from .squim_pipeline import SQUIM_OBJECTIVE, 
SQUIM_SUBJECTIVE, SquimObjectiveBundle, SquimSubjectiveBundle __all__ = [ "EMFORMER_RNNT_BASE_MUSTC", "EMFORMER_RNNT_BASE_TEDLIUM3", "HIFIGAN_VOCODER_V3_LJSPEECH", "HiFiGANVocoderBundle", - "SQUIM_OBJECTIVE", - "SQUIM_SUBJECTIVE", - "SquimObjectiveBundle", - "SquimSubjectiveBundle", ]