Skip to content

Commit

Permalink
split subconversation into two functions
Browse files Browse the repository at this point in the history
  • Loading branch information
bvreede committed Nov 28, 2023
1 parent 90c4ac6 commit 69912f6
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 61 deletions.
78 changes: 39 additions & 39 deletions sktalk/corpus/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,57 +70,58 @@ def asdict(self):
"""
return self._metadata | {"Utterances": [u.asdict() for u in self._utterances]}

def _subconversation(self,
index: int,
before: int = 0,
after: Optional[int] = None,
exclude_utterance_overlap: bool = False,
time_or_index: str = "index") -> "Conversation":
def _subconversation_by_index(self,
index: int,
before: int = 0,
after: Optional[int] = None) -> "Conversation":
"""Select utterances to provide context as a sub-conversation
Args:
index (int): The index of the utterance for which to provide context
before (int, optional): Either the number of utterances prior to indicated utterance,
or the time in ms preceding the utterance's begin. Defaults to 0.
after (int, optional): Either the number of utterances after the indicated utterance,
or the time in ms following the utterance's end. Defaults to None,
which then assumes the same value as `before`.
exclude_utterance_overlap (bool, optional): Only used when `time_or_index` is "time",
and either `before` or `after` is 0. If True, the duration of the
utterance itself is not used to identify overlapping utterances, and only
the window before or after the utterance is used. Defaults to False.
time_or_index (str, optional): Use "time" to select based on time (in ms), or "index"
to select a set number of utterances irrespective of timing.
Defaults to "index".
before (int, optional): The number of utterances prior to the indicated utterance. Defaults to 0.
after (int, optional): The number of utterances after the indicated utterance. Defaults to None,
which then assumes the same value as `before`.
Raises:
IndexError: Index provided must be within range of utterances
ValueError: time_or_index must be either "time" or "index"
Returns:
Conversation: Conversation object containing a reduced set of utterances
Conversation: Conversation object without metadata, containing a reduced set of utterances
"""
# TODO consider adding parameter 'strict' that only returns utterances entirely inside the window
if index < 0 or index >= len(self._utterances):
raise IndexError("Index out of range")
if after is None:
after = before
if time_or_index == "index":
# if before/after would exceed the bounds of the list, adjust
if index - before < 0:
before = index
if index + after + 1 > len(self._utterances):
after = len(self._utterances) - index - 1
returned_utterances = self._utterances[index-before:index+after+1]
elif time_or_index == "time":
returned_utterances = self._subconversation_by_time(
index, before, after, exclude_utterance_overlap)
else:
raise ValueError(
"`time_or_index` must be either 'time' or 'index'")
if index - before < 0:
before = index
if index + after + 1 > len(self._utterances):
after = len(self._utterances) - index - 1
returned_utterances = self._utterances[index-before:index+after+1]
return Conversation(utterances=returned_utterances, suppress_warnings=True)

def _subconversation_by_time(self, index, before, after, exclude_utterance_overlap):
def _subconversation_by_time(self,
index: int,
before: int = 0,
after: Optional[int] = None,
exclude_utterance_overlap: bool = False) -> "Conversation":
"""Select utterances to provide context as a sub-conversation
Args:
index (int): The index of the utterance for which to provide context
before (int, optional): The time in ms preceding the utterance's begin. Defaults to 0.
after (int, optional): The time in ms following the utterance's end. Defaults to None,
which then assumes the same value as `before`.
exclude_utterance_overlap (bool, optional): If True, the duration of the
utterance itself is not used to identify overlapping utterances, and only
the window before or after the utterance is used. Defaults to False.
Returns:
Conversation: Conversation object without metadata, containing a reduced set of utterances
"""
if index < 0 or index >= len(self._utterances):
raise IndexError("Index out of range")
if after is None:
after = before
try:
begin = self._utterances[index].time[0] - before
end = self._utterances[index].time[1] + after
Expand All @@ -131,8 +132,8 @@ def _subconversation_by_time(self, index, before, after, exclude_utterance_overl
returned_utterances = [
u for u in self._utterances if self.overlap(begin, end, u.time) or u == self._utterances[index]]
except (TypeError, IndexError):
return []
return returned_utterances
returned_utterances = []
return Conversation(utterances=returned_utterances, suppress_warnings=True)

def count_participants(self, except_none: bool = False) -> int:
"""Count the number of participants in a conversation
Expand Down Expand Up @@ -203,9 +204,8 @@ def calculate_FTO(self, window: int = 10000, planning_buffer: int = 200, n_parti
"""
values = []
for index, utterance in enumerate(self.utterances):
sub = self._subconversation(
sub = self._subconversation_by_time(
index=index,
time_or_index="time",
before=window,
after=0,
exclude_utterance_overlap=True)
Expand Down
51 changes: 29 additions & 22 deletions tests/corpus/test_conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,36 +52,43 @@ def test_write_json(self, convo, tmp_path, user_path, expected_path):
class TestConversationMetrics:
@pytest.mark.parametrize("args, error",
[
([0, 0, 1, "index"], does_not_raise()),
([20, 1, 1, "index"], pytest.raises(IndexError)),
([0, 50, 50, "index"], does_not_raise()),
([0, 0, 0, "neither_time_nor_index"],
pytest.raises(ValueError))
([0, 0, 1], does_not_raise()),
([20, 1, 1], pytest.raises(IndexError)),
([0, 50, 50], does_not_raise())
])
def test_subconversation_errors(self, convo, args, error):
index, before, after, time_or_index = args
index, before, after = args
with error:
convo._subconversation(index=index, # noqa W0212
convo._subconversation_by_index(index=index, # noqa W0212
before=before,
after=after,
time_or_index=time_or_index)
after=after)

@pytest.mark.parametrize("args, expected_length",
[
([0, 0, 1, "index"], 2),
([5, 2, 0, "index"], 3),
([0, 2, 2, "index"], 3),
([0, 2, None, "index"], 3),
([0, 0, 0, "time"], 2), # A, B
([5, 3000, 3000, "time"], 6), # B,C,E,U,F,H
([5, 0, 0, "time"], 3), # C, U, F
([0, 0, 1], 2),
([5, 2, 0], 3),
([0, 2, 2], 3),
([0, 2, None], 3)
])
def test_subconversation_index(self, convo, args, expected_length):
index, before, after = args
sub = convo._subconversation_by_index(index=index, # noqa W0212
before=before,
after=after)
assert isinstance(sub, Conversation)
assert len(sub.utterances) == expected_length

@pytest.mark.parametrize("args, expected_length",
[
([0, 0, 0], 2), # A, B
([5, 3000, 3000], 6), # B,C,E,U,F,H
([5, 0, 0], 3), # C, U, F
])
def test_subconversation(self, convo, args, expected_length):
index, before, after, time_or_index = args
sub = convo._subconversation(index=index, # noqa W0212
index, before, after = args
sub = convo._subconversation_by_time(index=index, # noqa W0212
before=before,
after=after,
time_or_index=time_or_index)
after=after)
assert isinstance(sub, Conversation)
assert len(sub.utterances) == expected_length

Expand All @@ -102,9 +109,9 @@ def test_overlap(self):
def test_count_participants(self, convo):
assert convo.count_participants() == 4
assert convo.count_participants(except_none=True) == 3
convo2 = convo._subconversation(index=0, before=2) # noqa W0212
convo2 = convo._subconversation_by_index(index=0, before=2) # noqa W0212
assert convo2.count_participants() == 2
convo3 = convo._subconversation(index=0) # noqa W0212
convo3 = convo._subconversation_by_index(index=0) # noqa W0212
assert convo3.count_participants() == 1

def test_calculate_FTO(self, convo):
Expand Down

0 comments on commit 69912f6

Please sign in to comment.