Skip to content

Commit

Permalink
PR suggestion 1
Browse files Browse the repository at this point in the history
  • Loading branch information
HRashidi committed Oct 31, 2024
1 parent f29850a commit 2d8aae4
Show file tree
Hide file tree
Showing 5 changed files with 109 additions and 118 deletions.
7 changes: 3 additions & 4 deletions aana_chat_with_video/alembic/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@

from aana.configs.settings import settings
from aana.storage.models.base import BaseEntity
# Import all models to be included in the migration
import aana.storage.models # noqa: F401
import aana_chat_with_video.storage.models # noqa: F401

# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
Expand All @@ -20,10 +23,6 @@
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata

# Import all models to be included in the migration
import aana.storage.models # noqa: F401
import aana_chat_with_video.storage.models # noqa: F401

target_metadata = BaseEntity.metadata

# other values from the config, defined by the needs of env.py,
Expand Down
4 changes: 0 additions & 4 deletions aana_chat_with_video/endpoints/delete_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,6 @@ class DeleteVideoOutput(TypedDict):
class DeleteVideoEndpoint(Endpoint):
"""Delete video endpoint."""

async def initialize(self):
"""Initialize the endpoint."""
await super().initialize()

async def run(self, media_id: MediaId) -> DeleteVideoOutput:
"""Delete video."""
with get_session() as session:
Expand Down
9 changes: 3 additions & 6 deletions aana_chat_with_video/endpoints/get_video_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from aana.api.api_generation import Endpoint
from aana.core.models.media import MediaId
from aana.storage.session import get_session
from aana_chat_with_video.core.models.video_status import VideoStatus
from aana_chat_with_video.storage.repository.extended_video import (
ExtendedVideoRepository,
Expand All @@ -17,12 +18,8 @@ class VideoStatusOutput(TypedDict):
class GetVideoStatusEndpoint(Endpoint):
"""Get video status endpoint."""

async def initialize(self):
"""Initialize the endpoint."""
await super().initialize()
self.video_repo = ExtendedVideoRepository(self.session)

async def run(self, media_id: MediaId) -> VideoStatusOutput:
"""Get video status."""
video_status = self.video_repo.get_status(media_id)
with get_session() as session:
video_status = ExtendedVideoRepository(session).get_status(media_id)
return VideoStatusOutput(status=video_status)
203 changes: 103 additions & 100 deletions aana_chat_with_video/endpoints/index_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,113 +80,114 @@ async def run( # noqa: C901
"""Transcribe video in chunks."""
media_id = video.media_id
with get_session() as session:
extended_video_repo = ExtendedVideoRepository(session)
transcript_repo = ExtendedVideoTranscriptRepository(session)
caption_repo = ExtendedVideoCaptionRepository(session)
if extended_video_repo.check_media_exists(media_id):
if ExtendedVideoRepository(session).check_media_exists(media_id):
raise MediaIdAlreadyExistsException(table_name="media", media_id=video)

video_duration = None
if video.url is not None:
video_metadata = get_video_metadata(video.url)
video_duration = video_metadata.duration

# Pre-check the max video length before actually downloading the video, if possible
if video_duration and video_duration > settings.max_video_len:
raise VideoTooLongException(
video=video,
video_len=video_duration,
max_len=settings.max_video_len,
)

video_obj: Video = await run_remote(download_video)(video_input=video)
if video_duration is None:
video_duration = await run_remote(get_video_duration)(video=video_obj)

if video_duration > settings.max_video_len:
raise VideoTooLongException(
video=video_obj,
video_len=video_duration,
max_len=settings.max_video_len,
)

video_duration = None
if video.url is not None:
video_metadata = get_video_metadata(video.url)
video_duration = video_metadata.duration

# Pre-check the max video length before actually downloading the video, if possible
if video_duration and video_duration > settings.max_video_len:
raise VideoTooLongException(
video=video,
video_len=video_duration,
max_len=settings.max_video_len,
with get_session() as session:
ExtendedVideoRepository(session).save(video=video_obj, duration=video_duration)

yield {
"media_id": media_id,
"metadata": VideoMetadata(
title=video_obj.title,
description=video_obj.description,
duration=video_duration,
),
}

try:
with get_session() as session:
ExtendedVideoRepository(session).update_status(
media_id, VideoProcessingStatus.RUNNING
)
audio: Audio = extract_audio(video=video_obj)

# TODO: Update once batched whisper PR is merged
# vad_output = await self.vad_handle.asr_preprocess_vad(
# audio=audio, params=vad_params
# )
# vad_segments = vad_output["segments"]

transcription_list = []
segments_list = []
transcription_info_list = []
async for whisper_output in self.asr_handle.transcribe_stream(
audio=audio, params=whisper_params
):
transcription_list.append(whisper_output["transcription"])
segments_list.append(whisper_output["segments"])
transcription_info_list.append(whisper_output["transcription_info"])
yield {
"transcription": whisper_output["transcription"],
"segments": whisper_output["segments"],
"transcription_info": whisper_output["transcription_info"],
}
transcription = sum(transcription_list, AsrTranscription())
segments = sum(segments_list, AsrSegments())
transcription_info = sum(transcription_info_list, AsrTranscriptionInfo())

captions = []
timestamps = []
frame_ids = []

async for frames_dict in run_remote(generate_frames)(
video=video_obj, params=video_params
):
if len(frames_dict["frames"]) == 0:
break

timestamps.extend(frames_dict["timestamps"])
frame_ids.extend(frames_dict["frame_ids"])
chat_prompt = "Describe the content of the following image in a single sentence:"
dialogs = [
ImageChatDialog.from_prompt(prompt=chat_prompt, images=[frame]) for frame in frames_dict["frames"]
]

# Collect the tasks to run concurrently and wait for them to finish
tasks = [self.captioning_handle.chat(dialog) for dialog in dialogs]
captioning_output = await asyncio.gather(*tasks)
captioning_output = [caption["message"].content for caption in captioning_output]
captions.extend(captioning_output)

video_obj: Video = await run_remote(download_video)(video_input=video)
if video_duration is None:
video_duration = await run_remote(get_video_duration)(video=video_obj)

if video_duration > settings.max_video_len:
raise VideoTooLongException(
video=video_obj,
video_len=video_duration,
max_len=settings.max_video_len,
)
yield {
"captions": captioning_output,
"timestamps": frames_dict["timestamps"],
}

extended_video_repo.save(video=video_obj, duration=video_duration)
yield {
"media_id": media_id,
"metadata": VideoMetadata(
title=video_obj.title,
description=video_obj.description,
duration=video_duration,
),
}

try:
extended_video_repo.update_status(
media_id, VideoProcessingStatus.RUNNING
)
audio: Audio = extract_audio(video=video_obj)

# TODO: Update once batched whisper PR is merged
# vad_output = await self.vad_handle.asr_preprocess_vad(
# audio=audio, params=vad_params
# )
# vad_segments = vad_output["segments"]

transcription_list = []
segments_list = []
transcription_info_list = []
async for whisper_output in self.asr_handle.transcribe_stream(
audio=audio, params=whisper_params
):
transcription_list.append(whisper_output["transcription"])
segments_list.append(whisper_output["segments"])
transcription_info_list.append(whisper_output["transcription_info"])
yield {
"transcription": whisper_output["transcription"],
"segments": whisper_output["segments"],
"transcription_info": whisper_output["transcription_info"],
}
transcription = sum(transcription_list, AsrTranscription())
segments = sum(segments_list, AsrSegments())
transcription_info = sum(transcription_info_list, AsrTranscriptionInfo())

captions = []
timestamps = []
frame_ids = []

async for frames_dict in run_remote(generate_frames)(
video=video_obj, params=video_params
):
if len(frames_dict["frames"]) == 0:
break

timestamps.extend(frames_dict["timestamps"])
frame_ids.extend(frames_dict["frame_ids"])
chat_prompt = "Describe the content of the following image in a single sentence:"
dialogs = [
ImageChatDialog.from_prompt(prompt=chat_prompt, images=[frame]) for frame in frames_dict["frames"]
]

# Collect the tasks to run concurrently and wait for them to finish
tasks = [self.captioning_handle.chat(dialog) for dialog in dialogs]
captioning_output = await asyncio.gather(*tasks)
captioning_output = [caption["message"].content for caption in captioning_output]
captions.extend(captioning_output)

yield {
"captions": captioning_output,
"timestamps": frames_dict["timestamps"],
}

transcription_entity = transcript_repo.save(
with get_session() as session:
transcription_entity = ExtendedVideoTranscriptRepository(session).save(
model_name=settings.asr_model_name,
media_id=video_obj.media_id,
transcription=transcription,
segments=segments,
transcription_info=transcription_info,
)

caption_entities = caption_repo.save_all(
caption_entities = ExtendedVideoCaptionRepository(session).save_all(
model_name=settings.captioning_model_name,
media_id=video_obj.media_id,
captions=captions,
Expand All @@ -198,12 +199,14 @@ async def run( # noqa: C901
"transcription_id": transcription_entity.id,
"caption_ids": [c.id for c in caption_entities],
}
except BaseException:
extended_video_repo.update_status(
except BaseException:
with get_session() as session:
ExtendedVideoRepository(session).update_status(
media_id, VideoProcessingStatus.FAILED
)
raise
else:
extended_video_repo.update_status(
raise
else:
with get_session() as session:
ExtendedVideoRepository(session).update_status(
media_id, VideoProcessingStatus.COMPLETED
)
4 changes: 0 additions & 4 deletions aana_chat_with_video/endpoints/load_video_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,6 @@ class LoadVideoMetadataOutput(TypedDict):
class LoadVideoMetadataEndpoint(Endpoint):
"""Load video metadata endpoint."""

async def initialize(self):
"""Initialize the endpoint."""
await super().initialize()

async def run(self, media_id: MediaId) -> LoadVideoMetadataOutput:
"""Load video metadata."""
with get_session() as session:
Expand Down

0 comments on commit 2d8aae4

Please sign in to comment.