diff --git a/torchaudio/prototype/pipelines/_vggish/_vggish_pipeline.py b/torchaudio/prototype/pipelines/_vggish/_vggish_pipeline.py index 8114e0d4af..7b0f527e73 100644 --- a/torchaudio/prototype/pipelines/_vggish/_vggish_pipeline.py +++ b/torchaudio/prototype/pipelines/_vggish/_vggish_pipeline.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +from typing import Callable, Dict import torch import torchaudio @@ -6,6 +7,11 @@ from ._vggish_impl import _SAMPLE_RATE, VGGish as _VGGish, VGGishInputProcessor as _VGGishInputProcessor +def _get_state_dict(): + path = torchaudio.utils.download_asset("models/vggish.pt") + return torch.load(path) + + @dataclass class VGGishBundle: """VGGish :cite:`45611` inference pipeline ported from @@ -34,7 +40,7 @@ class VGGish(_VGGish): class VGGishInputProcessor(_VGGishInputProcessor): __doc__ = _VGGishInputProcessor.__doc__ - _weights_path: str + _state_dict_func: Callable[[], Dict] @property def sample_rate(self) -> int: @@ -51,8 +57,7 @@ def get_model(self) -> VGGish: VGGish: VGGish model with pre-trained weights loaded. """ model = self.VGGish() - path = torchaudio.utils.download_asset(self._weights_path) - state_dict = torch.load(path) + state_dict = self._state_dict_func() model.load_state_dict(state_dict) model.eval() return model @@ -66,7 +71,7 @@ def get_input_processor(self) -> VGGishInputProcessor: return self.VGGishInputProcessor() -VGGISH = VGGishBundle("models/vggish.pt") +VGGISH = VGGishBundle(_get_state_dict) VGGISH.__doc__ = """Pre-trained VGGish :cite:`45611` inference pipeline ported from `torchvggish `__ and `tensorflow-models `__.