Skip to content

Commit

Permalink
ensure participant count does not include future utterances
Browse files Browse the repository at this point in the history
  • Loading branch information
bvreede committed Nov 24, 2023
1 parent 00ad82a commit 90c4ac6
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 25 deletions.
45 changes: 32 additions & 13 deletions sktalk/corpus/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def _subconversation(self,
index: int,
before: int = 0,
after: Optional[int] = None,
exclude_utterance_overlap: bool = False,
time_or_index: str = "index") -> "Conversation":
"""Select utterances to provide context as a sub-conversation
Expand All @@ -84,6 +85,10 @@ def _subconversation(self,
after (int, optional): Either the number of utterances after the indicated utterance,
or the time in ms following the utterance's end. Defaults to None,
which then assumes the same value as `before`.
exclude_utterance_overlap (bool, optional): Only used when `time_or_index` is "time",
and either `before` or `after` is 0. If True, the duration of the
utterance itself is not used to identify overlapping utterances, and only
the window before or after the utterance is used. Defaults to False.
time_or_index (str, optional): Use "time" to select based on time (in ms), or "index"
to select a set number of utterances irrespective of timing.
Defaults to "index".
Expand All @@ -108,17 +113,26 @@ def _subconversation(self,
after = len(self._utterances) - index - 1
returned_utterances = self._utterances[index-before:index+after+1]
elif time_or_index == "time":
try:
begin = self._utterances[index].time[0] - before
end = self._utterances[index].time[1] + after
returned_utterances = [
u for u in self._utterances if self.overlap(begin, end, u.time)]
except (TypeError, IndexError):
return Conversation([], suppress_warnings=True)
returned_utterances = self._subconversation_by_time(
index, before, after, exclude_utterance_overlap)
else:
raise ValueError(
"`time_or_index` must be either 'time' or 'index'")
return Conversation(utterances=returned_utterances)
return Conversation(utterances=returned_utterances, suppress_warnings=True)

def _subconversation_by_time(self, index, before, after, exclude_utterance_overlap):
try:
begin = self._utterances[index].time[0] - before
end = self._utterances[index].time[1] + after
if exclude_utterance_overlap and before == 0: # only overlap with window following utterance
begin = self._utterances[index].time[1]
elif exclude_utterance_overlap and after == 0: # only overlap with window preceding utterance
end = self._utterances[index].time[0]
returned_utterances = [
u for u in self._utterances if self.overlap(begin, end, u.time) or u == self._utterances[index]]
except (TypeError, IndexError):
return []
return returned_utterances

def count_participants(self, except_none: bool = False) -> int:
"""Count the number of participants in a conversation
Expand Down Expand Up @@ -175,21 +189,26 @@ def calculate_FTO(self, window: int = 10000, planning_buffer: int = 200, n_parti
- the utterance must be by another speaker than U
- the utterance by the other speaker must be the most recent utterance by that speaker
- the utterance must have started before utterance U, more than `planning_buffer` ms before.
- the utterance must be partly or entirely within the context window (`window` ms prior to the start of utterance U)
- the utterance must be partly or entirely within the context window (`window` ms prior
to the start of utterance U)
- within the context window, there must be a maximum of `n_participants` speakers.
Args:
window (int, optional): _description_. Defaults to 10000.
planning_buffer (int, optional): _description_. Defaults to 200.
n_participants (int, optional): _description_. Defaults to 2.
window (int, optional): the time in ms prior to utterance in which a
relevant preceding utterance can be found. Defaults to 10000.
planning_buffer (int, optional): minimum speaking time in ms to allow for a response.
Defaults to 200.
n_participants (int, optional): maximum number of participants overlapping with
the utterance and preceding window. Defaults to 2.
"""
values = []
for index, utterance in enumerate(self.utterances):
sub = self._subconversation(
index=index,
time_or_index="time",
before=window,
after=0)
after=0,
exclude_utterance_overlap=True)
if not 2 <= sub.count_participants() <= n_participants:
values.append(None)
continue
Expand Down
22 changes: 11 additions & 11 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,52 +26,52 @@ def convo_meta():
def convo_utts():
return [
Utterance(
utterance="Hello A",
utterance="0 utterance A",
participant="A",
time=[0, 1000]
),
Utterance(
utterance="Monde B",
utterance="1 utterance B",
participant="B",
time=[900, 3500]
),
Utterance(
utterance="Hello C",
utterance="2 utterance C",
participant="A",
time=[1001, 12000]
),
Utterance(
utterance="Monde D",
utterance="3 utterance D",
participant="B",
time=[1200, 1999]
),
Utterance(
utterance="Hello E",
utterance="4 utterance E",
participant="A",
time=[3500, 4500]
),
Utterance(
utterance="Utterance U",
utterance="5 utterance U",
participant="B",
time=[5000, 8000]
),
Utterance(
utterance="Monde F",
utterance="6 utterance F",
participant="C",
time=[5500, 7500]
),
Utterance(
utterance="Hello G",
participant="A",
utterance="7 utterance G",
participant=None,
time=None
),
Utterance(
utterance="Monde H",
utterance="8 utterance H",
participant="B",
time=[9000, 12500]
),
Utterance(
utterance="Hello I",
utterance="9 utterance I",
participant="C",
time=[12000, 13000]
)
Expand Down
6 changes: 5 additions & 1 deletion tests/corpus/test_conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,8 @@ def test_overlap(self):
70, 80, [90, 110]) # utterance after window

def test_count_participants(self, convo):
assert convo.count_participants() == 3
assert convo.count_participants() == 4
assert convo.count_participants(except_none=True) == 3
convo2 = convo._subconversation(index=0, before=2) # noqa W0212
assert convo2.count_participants() == 2
convo3 = convo._subconversation(index=0) # noqa W0212
Expand All @@ -115,3 +116,6 @@ def test_calculate_FTO(self, convo):
"window": 10, "planning_buffer": 200, "n_participants": 2}
assert convo.utterances[0].FTO is None
assert convo.utterances[1].FTO == -100
assert convo.utterances[2].FTO == None
convo.calculate_FTO(planning_buffer=0)
assert convo.utterances[2].FTO == -2499

0 comments on commit 90c4ac6

Please sign in to comment.