diff --git a/examples/avsr/data_prep/data/data_module.py b/examples/avsr/data_prep/data/data_module.py index 72bc2e69e7..542e26147a 100644 --- a/examples/avsr/data_prep/data/data_module.py +++ b/examples/avsr/data_prep/data/data_module.py @@ -32,8 +32,8 @@ def load_data(self, data_filename, transform=True): audio = self.audio_process(audio, sample_rate) return audio if self.modality == "video": - landmarks = self.landmarks_detector(data_filename) video = self.load_video(data_filename) + landmarks = self.landmarks_detector(video) video = self.video_process(video, landmarks) video = torch.tensor(video) return video diff --git a/examples/avsr/data_prep/detectors/mediapipe/detector.py b/examples/avsr/data_prep/detectors/mediapipe/detector.py index 9971dde2b5..e3875c2f2f 100644 --- a/examples/avsr/data_prep/detectors/mediapipe/detector.py +++ b/examples/avsr/data_prep/detectors/mediapipe/detector.py @@ -9,7 +9,6 @@ import mediapipe as mp import numpy as np -import torchvision warnings.filterwarnings("ignore") @@ -29,8 +28,7 @@ def __call__(self, video_frames): assert any(l is not None for l in landmarks), "Cannot detect any frames in the video" return landmarks - def detect(self, filename, detector): - video_frames = torchvision.io.read_video(filename, pts_unit="sec")[0].numpy() + def detect(self, video_frames, detector): landmarks = [] for frame in video_frames: results = detector.process(frame) diff --git a/examples/avsr/data_prep/detectors/retinaface/detector.py b/examples/avsr/data_prep/detectors/retinaface/detector.py index 2044627045..f35fdf97d2 100644 --- a/examples/avsr/data_prep/detectors/retinaface/detector.py +++ b/examples/avsr/data_prep/detectors/retinaface/detector.py @@ -7,7 +7,6 @@ import warnings import numpy as np -import torchvision from ibug.face_detection import RetinaFacePredictor warnings.filterwarnings("ignore") @@ -19,8 +18,7 @@ def __init__(self, device="cuda:0", model_name="resnet50"): device=device, threshold=0.8, model=RetinaFacePredictor.get_model(model_name) ) - def __call__(self, filename): - video_frames = torchvision.io.read_video(filename, pts_unit="sec")[0].numpy() + def __call__(self, video_frames): landmarks = [] for frame in video_frames: detected_faces = self.face_detector(frame, rgb=False)