Skip to content

Commit

Permalink
adds decord video decoder (#122)
Browse files Browse the repository at this point in the history
Summary:
## Motivation and Context

Adds Decord Video decoder (https://github.com/dmlc/decord) as one of the backends for video decoding.

## How Has This Been Tested

Added appropriate unit tests.

## Types of changes

- [x] Docs change / refactoring / dependency upgrade
- [ ] Bug fix (non-breaking change which fixes an issue)
- [x] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to change)

Pull Request resolved: #122

Test Plan:
buck test mode/dev-nosan //vision/fair/pytorchvideo/...

f305615820

Reviewed By: digdoug

Differential Revision: D31158836

Pulled By: kalyanvasudev

fbshipit-source-id: 4fac324d7b845051da5c2a91cac892b807828e3d
  • Loading branch information
kalyanvasudev authored and facebook-github-bot committed Oct 28, 2021
1 parent a94179a commit 35b1ca5
Show file tree
Hide file tree
Showing 5 changed files with 202 additions and 2 deletions.
1 change: 1 addition & 0 deletions pytorchvideo/data/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@
class DecoderType(Enum):
PYAV = "pyav"
TORCHVISION = "torchvision"
DECORD = "decord"
4 changes: 4 additions & 0 deletions pytorchvideo/data/encoded_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ def select_video_class(decoder: str) -> Video:
from .encoded_video_torchvision import EncodedVideoTorchVision

video_cls = EncodedVideoTorchVision
elif DecoderType(decoder) == DecoderType.DECORD:
from .encoded_video_decord import EncodedVideoDecord

video_cls = EncodedVideoDecord
else:
raise NotImplementedError(f"Unknown decoder type {decoder}")

Expand Down
194 changes: 194 additions & 0 deletions pytorchvideo/data/encoded_video_decord.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

import logging
import math
from typing import BinaryIO, Dict, Optional, TypeVar

import torch

from .utils import thwc_to_cthw
from .video import Video


logger = logging.getLogger(__name__)

try:
import decord
except ImportError:
_HAS_DECORD = False
else:
_HAS_DECORD = True

if _HAS_DECORD:
decord.bridge.set_bridge("torch")

DecordDevice = TypeVar("DecordDevice")


class EncodedVideoDecord(Video):
"""
Accessing clips from an encoded video using Decord video reading API
as the decoding backend. For more details, please refer to -
`Decord <https://github.com/dmlc/decord>`
"""

def __init__(
self,
file: BinaryIO,
video_name: Optional[str] = None,
decode_audio: bool = True,
sample_rate: int = 44100,
mono: bool = True,
width: int = -1,
height: int = -1,
num_threads: int = 0,
fault_tol: int = -1,
) -> None:
"""
Args:
file (BinaryIO): a file-like object (e.g. io.BytesIO or io.StringIO) that
contains the encoded video.
video_name (str): An optional name assigned to the video.
decode_audio (bool): If disabled, audio is not decoded.
sample_rate: int, default is -1
Desired output sample rate of the audio, unchanged if `-1` is specified.
mono: bool, default is True
Desired output channel layout of the audio. `True` is mono layout. `False`
is unchanged.
width : int, default is -1
Desired output width of the video, unchanged if `-1` is specified.
height : int, default is -1
Desired output height of the video, unchanged if `-1` is specified.
num_threads : int, default is 0
Number of decoding thread, auto if `0` is specified.
fault_tol : int, default is -1
The threshold of corupted and recovered frames. This is to prevent silent fault
tolerance when for example 50% frames of a video cannot be decoded and duplicate
frames are returned. You may find the fault tolerant feature sweet in many
cases, but not for training models. Say `N = # recovered frames`
If `fault_tol` < 0, nothing will happen.
If 0 < `fault_tol` < 1.0, if N > `fault_tol * len(video)`,
raise `DECORDLimitReachedError`.
If 1 < `fault_tol`, if N > `fault_tol`, raise `DECORDLimitReachedError`.
"""
self._decode_audio = decode_audio
self._video_name = video_name
if not _HAS_DECORD:
raise ImportError(
"decord is required to use EncodedVideoDecord decoder. Please "
"install with 'pip install decord' for CPU-only version and refer to"
"'https://github.com/dmlc/decord' for GPU-supported version"
)
try:
if self._decode_audio:
self._av_reader = decord.AVReader(
uri=file,
ctx=decord.cpu(0),
sample_rate=sample_rate,
mono=mono,
width=width,
height=height,
num_threads=num_threads,
fault_tol=fault_tol,
)
else:
self._av_reader = decord.VideoReader(
uri=file,
ctx=decord.cpu(0),
width=width,
height=height,
num_threads=num_threads,
fault_tol=fault_tol,
)
except Exception as e:
raise RuntimeError(f"Failed to open video {video_name} with Decord. {e}")

if self._decode_audio:
self._fps = self._av_reader._AVReader__video_reader.get_avg_fps()
else:
self._fps = self._av_reader.get_avg_fps()

self._duration = float(len(self._av_reader)) / float(self._fps)

@property
def name(self) -> Optional[str]:
"""
Returns:
name: the name of the stored video if set.
"""
return self._video_name

@property
def duration(self) -> float:
"""
Returns:
duration: the video's duration/end-time in seconds.
"""
return self._duration

def close(self):
if self._av_reader is not None:
del self._av_reader
self._av_reader = None

def get_clip(
self, start_sec: float, end_sec: float
) -> Dict[str, Optional[torch.Tensor]]:
"""
Retrieves frames from the encoded video at the specified start and end times
in seconds (the video always starts at 0 seconds).
Args:
start_sec (float): the clip start time in seconds
end_sec (float): the clip end time in seconds
Returns:
clip_data:
A dictionary mapping the entries at "video" and "audio" to a tensors.
"video": A tensor of the clip's RGB frames with shape:
(channel, time, height, width). The frames are of type torch.float32 and
in the range [0 - 255].
"audio": A tensor of the clip's audio samples with shape:
(samples). The samples are of type torch.float32 and
in the range [0 - 255].
Returns None if no video or audio found within time range.
"""
if start_sec > end_sec or start_sec > self._duration:
raise RuntimeError(
f"Incorrect time window for Decord decoding for video: {self._video_name}."
)

start_idx = math.ceil(self._fps * start_sec)
end_idx = math.ceil(self._fps * end_sec)
end_idx = min(end_idx, len(self._av_reader))
frame_idxs = list(range(start_idx, end_idx))
audio = None

try:
outputs = self._av_reader.get_batch(frame_idxs)
except Exception as e:
logger.debug(f"Failed to decode video with Decord: {self._video_name}. {e}")
raise e

if self._decode_audio:
audio, video = outputs
if audio is not None:
audio = list(audio)
audio = torch.cat(audio, dim=1)
audio = torch.flatten(audio)
audio = audio.to(torch.float32)
else:
video = outputs

if video is not None:
video = video.to(torch.float32)
video = thwc_to_cthw(video)

return {
"video": video,
"audio": audio,
}
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,10 @@ def get_name():
"networkx",
],
extras_require={
"test": ["coverage", "pytest", "opencv-python"],
"test": ["coverage", "pytest", "opencv-python", "decord"],
"dev": [
"opencv-python",
"decord",
"black==20.8b1",
"sphinx",
"isort==4.3.21",
Expand Down
2 changes: 1 addition & 1 deletion tests/test_data_labeled_video_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
)


DECODER_LIST = [("pyav",), ("torchvision",)]
DECODER_LIST = [("pyav",), ("torchvision",), ("decord",)]


class TestLabeledVideoDataset(unittest.TestCase):
Expand Down

0 comments on commit 35b1ca5

Please sign in to comment.