From 2d8aae4b7b372ee2a19a5afed700ad7aae228adb Mon Sep 17 00:00:00 2001
From: Hossein Rashidi
Date: Thu, 31 Oct 2024 12:14:18 +0000
Subject: [PATCH] PR suggestion 1

---
 aana_chat_with_video/alembic/env.py           |   7 +-
 .../endpoints/delete_video.py                 |   4 -
 .../endpoints/get_video_status.py             |   9 +-
 aana_chat_with_video/endpoints/index_video.py | 203 +++++++++---------
 .../endpoints/load_video_metadata.py          |   4 -
 5 files changed, 109 insertions(+), 118 deletions(-)

diff --git a/aana_chat_with_video/alembic/env.py b/aana_chat_with_video/alembic/env.py
index 9d7ffc7..e9f3bc1 100644
--- a/aana_chat_with_video/alembic/env.py
+++ b/aana_chat_with_video/alembic/env.py
@@ -5,6 +5,9 @@
 from aana.configs.settings import settings
 from aana.storage.models.base import BaseEntity
 
+# Import all models to be included in the migration
+import aana.storage.models  # noqa: F401
+import aana_chat_with_video.storage.models  # noqa: F401
 
 # this is the Alembic Config object, which provides
 # access to the values within the .ini file in use.
@@ -20,10 +23,6 @@
 # from myapp import mymodel
 # target_metadata = mymodel.Base.metadata
 
-# Import all models to be included in the migration
-import aana.storage.models  # noqa: F401
-import aana_chat_with_video.storage.models  # noqa: F401
-
 target_metadata = BaseEntity.metadata
 
 # other values from the config, defined by the needs of env.py,
diff --git a/aana_chat_with_video/endpoints/delete_video.py b/aana_chat_with_video/endpoints/delete_video.py
index 5e19239..1b4fcb1 100644
--- a/aana_chat_with_video/endpoints/delete_video.py
+++ b/aana_chat_with_video/endpoints/delete_video.py
@@ -17,10 +17,6 @@ class DeleteVideoOutput(TypedDict):
 class DeleteVideoEndpoint(Endpoint):
     """Delete video endpoint."""
 
-    async def initialize(self):
-        """Initialize the endpoint."""
-        await super().initialize()
-
     async def run(self, media_id: MediaId) -> DeleteVideoOutput:
         """Delete video."""
         with get_session() as session:
diff --git a/aana_chat_with_video/endpoints/get_video_status.py b/aana_chat_with_video/endpoints/get_video_status.py
index 01ff10d..f1a0ca2 100644
--- a/aana_chat_with_video/endpoints/get_video_status.py
+++ b/aana_chat_with_video/endpoints/get_video_status.py
@@ -2,6 +2,7 @@
 from aana.api.api_generation import Endpoint
 from aana.core.models.media import MediaId
+from aana.storage.session import get_session
 from aana_chat_with_video.core.models.video_status import VideoStatus
 from aana_chat_with_video.storage.repository.extended_video import (
     ExtendedVideoRepository,
 )
@@ -17,12 +18,8 @@ class VideoStatusOutput(TypedDict):
 class GetVideoStatusEndpoint(Endpoint):
     """Get video status endpoint."""
 
-    async def initialize(self):
-        """Initialize the endpoint."""
-        await super().initialize()
-        self.video_repo = ExtendedVideoRepository(self.session)
-
     async def run(self, media_id: MediaId) -> VideoStatusOutput:
         """Load video metadata."""
-        video_status = self.video_repo.get_status(media_id)
+        with get_session() as session:
+            video_status = ExtendedVideoRepository(session).get_status(media_id)
         return VideoStatusOutput(status=video_status)
diff --git a/aana_chat_with_video/endpoints/index_video.py b/aana_chat_with_video/endpoints/index_video.py
index b21bd2d..8c96826 100644
--- a/aana_chat_with_video/endpoints/index_video.py
+++ b/aana_chat_with_video/endpoints/index_video.py
@@ -80,105 +80,106 @@ async def run(  # noqa: C901
         """Transcribe video in chunks."""
         media_id = video.media_id
         with get_session() as session:
-            extended_video_repo = ExtendedVideoRepository(session)
-            transcript_repo = ExtendedVideoTranscriptRepository(session)
-            caption_repo = ExtendedVideoCaptionRepository(session)
-            if extended_video_repo.check_media_exists(media_id):
+            if ExtendedVideoRepository(session).check_media_exists(media_id):
                 raise MediaIdAlreadyExistsException(table_name="media", media_id=video)
+
+        video_duration = None
+        if video.url is not None:
+            video_metadata = get_video_metadata(video.url)
+            video_duration = video_metadata.duration
+
+            # precheck for max video length before actually downloading the video if possible
+            if video_duration and video_duration > settings.max_video_len:
+                raise VideoTooLongException(
+                    video=video,
+                    video_len=video_duration,
+                    max_len=settings.max_video_len,
+                )
+
+        video_obj: Video = await run_remote(download_video)(video_input=video)
+        if video_duration is None:
+            video_duration = await run_remote(get_video_duration)(video=video_obj)
+
+        if video_duration > settings.max_video_len:
+            raise VideoTooLongException(
+                video=video_obj,
+                video_len=video_duration,
+                max_len=settings.max_video_len,
+            )
 
-            video_duration = None
-            if video.url is not None:
-                video_metadata = get_video_metadata(video.url)
-                video_duration = video_metadata.duration
-
-                # precheck for max video length before actually download the video if possible
-                if video_duration and video_duration > settings.max_video_len:
-                    raise VideoTooLongException(
-                        video=video,
-                        video_len=video_duration,
-                        max_len=settings.max_video_len,
-                    )
 
+        with get_session() as session:
+            ExtendedVideoRepository(session).save(video=video_obj, duration=video_duration)
+
+        yield {
+            "media_id": media_id,
+            "metadata": VideoMetadata(
+                title=video_obj.title,
+                description=video_obj.description,
+                duration=video_duration,
+            ),
+        }
+
+        try:
+            with get_session() as session:
+                ExtendedVideoRepository(session).update_status(
+                    media_id, VideoProcessingStatus.RUNNING
+                )
+            audio: Audio = extract_audio(video=video_obj)
+
+            # TODO: Update once batched whisper PR is merged
+            # vad_output = await self.vad_handle.asr_preprocess_vad(
+            #     audio=audio, params=vad_params
+            # )
+            # vad_segments = vad_output["segments"]
+
+            transcription_list = []
+            segments_list = []
+            transcription_info_list = []
+            async for whisper_output in self.asr_handle.transcribe_stream(
+                audio=audio, params=whisper_params
+            ):
+                transcription_list.append(whisper_output["transcription"])
+                segments_list.append(whisper_output["segments"])
+                transcription_info_list.append(whisper_output["transcription_info"])
+                yield {
+                    "transcription": whisper_output["transcription"],
+                    "segments": whisper_output["segments"],
+                    "transcription_info": whisper_output["transcription_info"],
+                }
+            transcription = sum(transcription_list, AsrTranscription())
+            segments = sum(segments_list, AsrSegments())
+            transcription_info = sum(transcription_info_list, AsrTranscriptionInfo())
+
+            captions = []
+            timestamps = []
+            frame_ids = []
+
+            async for frames_dict in run_remote(generate_frames)(
+                video=video_obj, params=video_params
+            ):
+                if len(frames_dict["frames"]) == 0:
+                    break
+
+                timestamps.extend(frames_dict["timestamps"])
+                frame_ids.extend(frames_dict["frame_ids"])
+                chat_prompt = "Describe the content of the following image in a single sentence:"
+                dialogs = [
+                    ImageChatDialog.from_prompt(prompt=chat_prompt, images=[frame]) for frame in frames_dict["frames"]
+                ]
+
+                # Collect the tasks to run concurrently and wait for them to finish
+                tasks = [self.captioning_handle.chat(dialog) for dialog in dialogs]
+                captioning_output = await asyncio.gather(*tasks)
+                captioning_output = [caption["message"].content for caption in captioning_output]
[caption["message"].content for caption in captioning_output] + captions.extend(captioning_output) - video_obj: Video = await run_remote(download_video)(video_input=video) - if video_duration is None: - video_duration = await run_remote(get_video_duration)(video=video_obj) - - if video_duration > settings.max_video_len: - raise VideoTooLongException( - video=video_obj, - video_len=video_duration, - max_len=settings.max_video_len, - ) + yield { + "captions": captioning_output, + "timestamps": frames_dict["timestamps"], + } - extended_video_repo.save(video=video_obj, duration=video_duration) - yield { - "media_id": media_id, - "metadata": VideoMetadata( - title=video_obj.title, - description=video_obj.description, - duration=video_duration, - ), - } - - try: - extended_video_repo.update_status( - media_id, VideoProcessingStatus.RUNNING - ) - audio: Audio = extract_audio(video=video_obj) - - # TODO: Update once batched whisper PR is merged - # vad_output = await self.vad_handle.asr_preprocess_vad( - # audio=audio, params=vad_params - # ) - # vad_segments = vad_output["segments"] - - transcription_list = [] - segments_list = [] - transcription_info_list = [] - async for whisper_output in self.asr_handle.transcribe_stream( - audio=audio, params=whisper_params - ): - transcription_list.append(whisper_output["transcription"]) - segments_list.append(whisper_output["segments"]) - transcription_info_list.append(whisper_output["transcription_info"]) - yield { - "transcription": whisper_output["transcription"], - "segments": whisper_output["segments"], - "transcription_info": whisper_output["transcription_info"], - } - transcription = sum(transcription_list, AsrTranscription()) - segments = sum(segments_list, AsrSegments()) - transcription_info = sum(transcription_info_list, AsrTranscriptionInfo()) - - captions = [] - timestamps = [] - frame_ids = [] - - async for frames_dict in run_remote(generate_frames)( - video=video_obj, params=video_params - ): - if len(frames_dict["frames"]) == 0: - break - - timestamps.extend(frames_dict["timestamps"]) - frame_ids.extend(frames_dict["frame_ids"]) - chat_prompt = "Describe the content of the following image in a single sentence:" - dialogs = [ - ImageChatDialog.from_prompt(prompt=chat_prompt, images=[frame]) for frame in frames_dict["frames"] - ] - - # Collect the tasks to run concurrently and wait for them to finish - tasks = [self.captioning_handle.chat(dialog) for dialog in dialogs] - captioning_output = await asyncio.gather(*tasks) - captioning_output = [caption["message"].content for caption in captioning_output] - captions.extend(captioning_output) - - yield { - "captions": captioning_output, - "timestamps": frames_dict["timestamps"], - } - - transcription_entity = transcript_repo.save( + with get_session() as session: + transcription_entity = ExtendedVideoTranscriptRepository(session).save( model_name=settings.asr_model_name, media_id=video_obj.media_id, transcription=transcription, @@ -186,7 +187,7 @@ async def run( # noqa: C901 transcription_info=transcription_info, ) - caption_entities = caption_repo.save_all( + caption_entities = ExtendedVideoCaptionRepository(session).save_all( model_name=settings.captioning_model_name, media_id=video_obj.media_id, captions=captions, @@ -198,12 +199,14 @@ async def run( # noqa: C901 "transcription_id": transcription_entity.id, "caption_ids": [c.id for c in caption_entities], } - except BaseException: - extended_video_repo.update_status( + except BaseException: + with get_session() as session: + 
                     media_id, VideoProcessingStatus.FAILED
                 )
-                raise
-            else:
-                extended_video_repo.update_status(
+            raise
+        else:
+            with get_session() as session:
+                ExtendedVideoRepository(session).update_status(
                     media_id, VideoProcessingStatus.COMPLETED
                 )
diff --git a/aana_chat_with_video/endpoints/load_video_metadata.py b/aana_chat_with_video/endpoints/load_video_metadata.py
index a66b161..d7b3fac 100644
--- a/aana_chat_with_video/endpoints/load_video_metadata.py
+++ b/aana_chat_with_video/endpoints/load_video_metadata.py
@@ -18,10 +18,6 @@ class LoadVideoMetadataOutput(TypedDict):
 class LoadVideoMetadataEndpoint(Endpoint):
     """Load video metadata endpoint."""
 
-    async def initialize(self):
-        """Initialize the endpoint."""
-        await super().initialize()
-
     async def run(self, media_id: MediaId) -> LoadVideoMetadataOutput:
         """Load video metadata."""
         with get_session() as session:
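
Note for reviewers: every change in this patch is the same pattern applied across the endpoints. Instead of opening a session in initialize() and caching it on the endpoint, each unit of work opens its own short-lived session via get_session(), so no database session is held open across the long-running awaits and streaming yields in index_video.py. The sketch below shows the pattern in a runnable, self-contained form. It is not the project's actual implementation: get_session() here is a hand-rolled commit-on-success/rollback-on-error context manager, and StatusRepository is a hypothetical stand-in for ExtendedVideoRepository.

    import asyncio
    from contextlib import contextmanager

    from sqlalchemy import create_engine
    from sqlalchemy.orm import Session, sessionmaker

    # Stand-in engine; the real app configures this from its settings.
    engine = create_engine("sqlite:///:memory:")
    SessionFactory = sessionmaker(bind=engine)


    @contextmanager
    def get_session():
        """Yield a short-lived session; commit on success, roll back on error."""
        session = SessionFactory()
        try:
            yield session
            session.commit()
        except BaseException:
            session.rollback()
            raise
        finally:
            session.close()


    class StatusRepository:
        """Hypothetical stand-in for ExtendedVideoRepository."""

        def __init__(self, session: Session):
            self.session = session

        def get_status(self, media_id: str) -> str:
            # A real repository would query the session here.
            return "completed"


    async def get_video_status(media_id: str) -> dict:
        # The session lives only as long as the database call needs it;
        # it is not held across awaits or streaming yields.
        with get_session() as session:
            status = StatusRepository(session).get_status(media_id)
        return {"status": status}


    if __name__ == "__main__":
        print(asyncio.run(get_video_status("demo-media-id")))

Opening a fresh session per operation costs a connection checkout each time, but it guarantees that a failed or slow request never pins a connection or leaks an open transaction; that is why the try/except/else in index_video.py opens a new session for each status update instead of reusing one.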