Skip to content

Commit

Permalink
Revise VGGish pipeline to accept arbitrary state dict function
Browse files Browse the repository at this point in the history
Summary: Revises VGGish pipeline to accept arbitrary state dict function to accommodate loading weights from any source.

Differential Revision: D48056390

fbshipit-source-id: a0fa46c4ca266a49db3b73c2809f68d979d2fa26
  • Loading branch information
hwangjeff authored and facebook-github-bot committed Aug 4, 2023
1 parent b645c07 commit a934f25
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions torchaudio/prototype/pipelines/_vggish/_vggish_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
from dataclasses import dataclass
from typing import Callable, Dict

import torch
import torchaudio

from ._vggish_impl import _SAMPLE_RATE, VGGish as _VGGish, VGGishInputProcessor as _VGGishInputProcessor


def _get_state_dict():
    """Fetch the pre-trained VGGish checkpoint and return its loaded contents.

    The ``models/vggish.pt`` asset is resolved to a local path with
    ``torchaudio.utils.download_asset``; that path is then handed to
    ``torch.load`` and the deserialized object is returned unchanged.
    """
    return torch.load(torchaudio.utils.download_asset("models/vggish.pt"))


@dataclass
class VGGishBundle:
"""VGGish :cite:`45611` inference pipeline ported from
Expand Down Expand Up @@ -34,7 +40,7 @@ class VGGish(_VGGish):
class VGGishInputProcessor(_VGGishInputProcessor):
__doc__ = _VGGishInputProcessor.__doc__

_weights_path: str
_state_dict_func: Callable[[], Dict]

@property
def sample_rate(self) -> int:
Expand All @@ -51,8 +57,7 @@ def get_model(self) -> VGGish:
VGGish: VGGish model with pre-trained weights loaded.
"""
model = self.VGGish()
path = torchaudio.utils.download_asset(self._weights_path)
state_dict = torch.load(path)
state_dict = self._state_dict_func()
model.load_state_dict(state_dict)
model.eval()
return model
Expand All @@ -66,7 +71,7 @@ def get_input_processor(self) -> VGGishInputProcessor:
return self.VGGishInputProcessor()


VGGISH = VGGishBundle("models/vggish.pt")
VGGISH = VGGishBundle(_get_state_dict)
VGGISH.__doc__ = """Pre-trained VGGish :cite:`45611` inference pipeline ported from
`torchvggish <https://github.com/harritaylor/torchvggish>`__
and `tensorflow-models <https://github.com/tensorflow/models/tree/master/research/audioset>`__.
Expand Down

0 comments on commit a934f25

Please sign in to comment.