Skip to content

Commit

Permalink
Merge pull request #291 from clamsproject/281-improved-rep-frame-extr
Browse files Browse the repository at this point in the history
improving representative frame extraction
  • Loading branch information
keighrim authored Aug 21, 2024
2 parents af83e13 + ab40d2b commit e776810
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 18 deletions.
52 changes: 39 additions & 13 deletions mmif/utils/video_document_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,12 @@ def extract_frames_as_images(video_document: Document, framenums: List[int], as_
return frames


def get_mid_framenum(mmif: Mmif, time_frame: Annotation):
def get_mid_framenum(mmif: Mmif, time_frame: Annotation) -> int:
warnings.warn('This function is deprecated. Use ``get_representative_framenums()`` instead.', DeprecationWarning, stacklevel=2)
return _get_mid_framenum(mmif, time_frame)


def _get_mid_framenum(mmif: Mmif, time_frame: Annotation) -> int:
"""
Calculates the middle frame number of a time interval annotation.
Expand All @@ -112,7 +117,7 @@ def get_mid_framenum(mmif: Mmif, time_frame: Annotation):
timeunit = time_frame.get_property('timeUnit')
video_document = mmif[time_frame.get_property('document')]
fps = get_framerate(video_document)
return convert(time_frame.get_property('start') + time_frame.get_property('end'), timeunit, 'frame', fps) // 2
return int(convert(time_frame.get_property('start') + time_frame.get_property('end'), timeunit, 'frame', fps) // 2)


def extract_mid_frame(mmif: Mmif, time_frame: Annotation, as_PIL: bool = False):
Expand All @@ -124,44 +129,65 @@ def extract_mid_frame(mmif: Mmif, time_frame: Annotation, as_PIL: bool = False):
:param as_PIL: return :py:class:`~PIL.Image.Image` instead of :py:class:`~numpy.ndarray`
:return: frame as a :py:class:`numpy.ndarray` or :py:class:`PIL.Image.Image`
"""
warnings.warn('This function is deprecated. Use ``extract_representative_frames()`` instead.', DeprecationWarning, stacklevel=2)
vd = mmif[time_frame.get_property('document')]
return extract_frames_as_images(vd, [get_mid_framenum(mmif, time_frame)], as_PIL=as_PIL)[0]


def get_representative_framenum(mmif: Mmif, time_frame: Annotation):
def get_representative_framenums(mmif: Mmif, time_frame: Annotation) -> List[int]:
"""
Calculates the representative frame number from an annotation.
Calculates the representative frame numbers from an annotation. To pick the representative frames, it first looks
up the ``representatives`` property of the ``TimeFrame`` annotation. If it is not found, it will calculate the
number of the middle frame.
:param mmif: :py:class:`~mmif.serialize.mmif.Mmif` instance
:param time_frame: :py:class:`~mmif.serialize.annotation.Annotation` instance that holds a time interval annotation, optionally containing a ``representatives`` property (``"@type": ".../TimeFrame/..."``)
:return: representative frame numbers as a list of integers
"""
if 'representatives' not in time_frame.properties:
raise ValueError(f'The time frame {time_frame.id} does not have a representative.')
return [_get_mid_framenum(mmif, time_frame)]
timeunit = time_frame.get_property('timeUnit')
video_document = mmif[time_frame.get_property('document')]
fps = get_framerate(video_document)
representatives = time_frame.get_property('representatives')
top_representative_id = representatives[0]
ref_frams = []
for rep in representatives:
if Mmif.id_delimiter in rep:
rep_long_id = rep
else:
rep_long_id = time_frame._parent_view_id+time_frame.id_delimiter+rep
try:
rep_anno = mmif[rep_long_id]
except KeyError:
raise ValueError(f'Representative timepoint {rep_long_id} not found in any view.')
ref_frams.append(int(convert(rep_anno.get_property('timePoint'), timeunit, 'frame', fps)))
return ref_frams


def get_representative_framenum(mmif: Mmif, time_frame: Annotation) -> int:
"""
A thin wrapper around :py:func:`get_representative_framenums` that returns a single representative frame number. Always
returns the first frame number found, and raises a ``ValueError`` if no frame number is found.
"""
try:
representative_timepoint_anno = mmif[time_frame._parent_view_id+time_frame.id_delimiter+top_representative_id]
except KeyError:
raise ValueError(f'Representative timepoint {top_representative_id} not found in any view.')
return convert(representative_timepoint_anno.get_property('timePoint'), timeunit, 'frame', fps)
return get_representative_framenums(mmif, time_frame)[0]
except IndexError:
raise ValueError(f'No representative frame found in the TimeFrame annotation {time_frame.id}.')


def extract_representative_frame(mmif: Mmif, time_frame: Annotation, as_PIL: bool = False):
def extract_representative_frame(mmif: Mmif, time_frame: Annotation, as_PIL: bool = False, first_only: bool = True):
"""
Extracts the representative frame of an annotation as a numpy ndarray or PIL Image.
:param mmif: :py:class:`~mmif.serialize.mmif.Mmif` instance
:param time_frame: :py:class:`~mmif.serialize.annotation.Annotation` instance that holds a time interval annotation (``"@type": ".../TimeFrame/..."``)
:param as_PIL: return :py:class:`~PIL.Image.Image` instead of :py:class:`~numpy.ndarray`
:param first_only: return the first representative frame only
:return: frame as a :py:class:`numpy.ndarray` or :py:class:`PIL.Image.Image`
"""
video_document = mmif[time_frame.get_property('document')]
rep_frame_num = get_representative_framenum(mmif, time_frame)
return extract_frames_as_images(video_document, [rep_frame_num], as_PIL=as_PIL)[0]
rep_frame_num = [get_representative_framenum(mmif, time_frame)] if first_only else get_representative_framenums(mmif, time_frame)
return extract_frames_as_images(video_document, rep_frame_num, as_PIL=as_PIL)[0]


def sample_frames(start_frame: int, end_frame: int, sample_rate: float = 1) -> List[int]:
Expand Down
8 changes: 3 additions & 5 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,10 @@ def test_extract_representative_frame(self):
rep_frame_num = vdh.get_representative_framenum(self.mmif_obj, tf)
expected_frame_num = vdh.millisecond_to_framenum(self.video_doc, tp.get_property('timePoint'))
self.assertEqual(expected_frame_num, rep_frame_num)
# check there is an error if no representatives
# and should work even if no representatives are provided
tf = self.a_view.new_annotation(AnnotationTypes.TimeFrame, start=1000, end=2000, timeUnit='milliseconds', document='d1')
with pytest.raises(ValueError):
vdh.get_representative_framenum(self.mmif_obj, tf)
# check there is an error if there is a representative referencing a timepoint that
# does not exist
self.assertEqual(vdh.get_representative_framenum(self.mmif_obj, tf), vdh.get_mid_framenum(self.mmif_obj, tf))
# check there is an error if there is a representative referencing a timepoint that does not exist
tf.add_property('representatives', ['fake_tp_id'])
with pytest.raises(ValueError):
vdh.get_representative_framenum(self.mmif_obj, tf)
Expand Down

0 comments on commit e776810

Please sign in to comment.