Skip to content

Commit

Permalink
Refactor compat (#3518)
Browse files Browse the repository at this point in the history
Summary:
The I/O functions in the `_compat` module were introduced there so that
everything related to FFmpeg lives in `torchaudio.io` and the FFmpeg library
initialization can be carried out in `torchaudio.io.__init__`.

Now that this constraint is removed (all the initialization happens
in `torchaudio._extension.__init__`) and `_compat` is only used by the
FFmpeg dispatcher backend, we move the module to `torchaudio._backend`
for better locality.

Pull Request resolved: #3518

Reviewed By: huangruizhe

Differential Revision: D47877412

Pulled By: mthrok

fbshipit-source-id: aa18c8cb6e5d5360950df5158c33c653e37c565f
  • Loading branch information
mthrok authored and facebook-github-bot committed Jul 29, 2023
1 parent 61cbf79 commit 8497ee9
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 15 deletions.
5 changes: 2 additions & 3 deletions examples/tutorials/device_avsr.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,12 @@
"""

import numpy as np
import sentencepiece as spm
import torch
import torchaudio
import torchvision

import numpy as np
import sentencepiece as spm

######################################################################
# Overview
# --------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
from functools import partial

from parameterized import parameterized
from torchaudio._backend.ffmpeg import _parse_save_args
from torchaudio._backend.utils import get_load_func
from torchaudio._internal import module_utils as _mod_utils
from torchaudio.io._compat import _parse_save_args

from torchaudio_unittest.backend.dispatcher.sox.common import name_func
from torchaudio_unittest.common_utils import (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@

import torch
from parameterized import parameterized
from torchaudio._backend.ffmpeg import _parse_save_args
from torchaudio._backend.utils import get_save_func
from torchaudio.io._compat import _parse_save_args

from torchaudio_unittest.backend.dispatcher.sox.common import get_enc_params, name_func
from torchaudio_unittest.common_utils import (
Expand Down
4 changes: 2 additions & 2 deletions torchaudio/io/_compat.py → torchaudio/_backend/ffmpeg.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@

if torchaudio._extension._FFMPEG_EXT is not None:
StreamReaderFileObj = torchaudio._extension._FFMPEG_EXT.StreamReaderFileObj
else:
StreamReaderFileObj = object


# Note: need to comply with TorchScript syntax -- need annotation and no f-string nor global
def info_audio(
src: str,
format: Optional[str],
Expand Down Expand Up @@ -241,7 +242,6 @@ def _type(spec):
return muxer, encoder, sample_fmt


# NOTE: in contrast to load_audio* and info_audio*, this function is NOT compatible with TorchScript.
def save_audio(
uri: Union[BinaryIO, str, os.PathLike],
src: torch.Tensor,
Expand Down
16 changes: 8 additions & 8 deletions torchaudio/_backend/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
from typing import BinaryIO, Dict, Optional, Tuple, Union

import torch
import torchaudio.backend.soundfile_backend as soundfile_backend

from torchaudio._extension import _FFMPEG_EXT, _SOX_INITIALIZED
from torchaudio.backend import soundfile_backend
from torchaudio.backend.common import AudioMetaData

if _FFMPEG_EXT is not None:
from torchaudio.io._compat import info_audio, info_audio_fileobj, load_audio, load_audio_fileobj, save_audio
from . import ffmpeg


class Backend(ABC):
Expand Down Expand Up @@ -80,9 +80,9 @@ class FFmpegBackend(Backend):
@staticmethod
def info(uri: Union[BinaryIO, str, os.PathLike], format: Optional[str], buffer_size: int = 4096) -> AudioMetaData:
if hasattr(uri, "read"):
metadata = info_audio_fileobj(uri, format, buffer_size=buffer_size)
metadata = ffmpeg.info_audio_fileobj(uri, format, buffer_size=buffer_size)
else:
metadata = info_audio(os.path.normpath(uri), format)
metadata = ffmpeg.info_audio(os.path.normpath(uri), format)
metadata.bits_per_sample = _get_bits_per_sample(metadata.encoding, metadata.bits_per_sample)
metadata.encoding = _map_encoding(metadata.encoding)
return metadata
Expand All @@ -98,7 +98,7 @@ def load(
buffer_size: int = 4096,
) -> Tuple[torch.Tensor, int]:
if hasattr(uri, "read"):
return load_audio_fileobj(
return ffmpeg.load_audio_fileobj(
uri,
frame_offset,
num_frames,
Expand All @@ -108,7 +108,7 @@ def load(
buffer_size,
)
else:
return load_audio(os.path.normpath(uri), frame_offset, num_frames, normalize, channels_first, format)
return ffmpeg.load_audio(os.path.normpath(uri), frame_offset, num_frames, normalize, channels_first, format)

@staticmethod
def save(
Expand All @@ -121,7 +121,7 @@ def save(
bits_per_sample: Optional[int] = None,
buffer_size: int = 4096,
) -> None:
save_audio(
ffmpeg.save_audio(
uri,
src,
sample_rate,
Expand Down

0 comments on commit 8497ee9

Please sign in to comment.