diff --git a/sktalk/corpus/conversation.py b/sktalk/corpus/conversation.py index 0b90abf..2af1580 100644 --- a/sktalk/corpus/conversation.py +++ b/sktalk/corpus/conversation.py @@ -120,25 +120,17 @@ def _subconversation(self, "`time_or_index` must be either 'time' or 'index'") return Conversation(utterances=returned_utterances) - # TODO should metadata be part of this? - return Conversation(returned_utterances, self.metadata) - - def _time_to_next(self) -> int: - # if len(self.utterances) != 2: - # return None - try: - return self.utterances[0].until(self.utterances[1]) - except (TypeError, IndexError): - return None - - def _dyadic(self) -> bool: - participants = [u.participant for u in self.utterances] - return len(set(participants)) == 2 + def _count_participants(self) -> int: + """Count the number of participants in a conversation + + Importantly: if one of the utterances has no participant, it is counted + as a separate participant (None). - CONVERSATION_FUNCTIONS = { - "dyadic": _dyadic, - "time_to_next": _time_to_next, - } + Returns: + int: number of participants + """ + participants = [u.participant for u in self.utterances] + return len(set(participants)) def apply(self, field, **kwargs): """ diff --git a/tests/corpus/test_conversation.py b/tests/corpus/test_conversation.py index e90a284..b4cf81b 100644 --- a/tests/corpus/test_conversation.py +++ b/tests/corpus/test_conversation.py @@ -99,15 +99,9 @@ def test_overlap(self): assert not Conversation.overlap( 70, 80, [90, 110]) # utterance after window - def test_dyadic(self, convo): - assert not convo._dyadic() - convo2 = convo.subconversation(0, 2) - assert convo2._dyadic() - convo3 = convo.subconversation(0) - assert not convo3._dyadic() - - def test_apply_dyadic(self, convo): - convo.apply("dyadic", before=1) - assert convo.utterances[0].dyadic - assert convo.utterances[2].dyadic - assert not convo.utterances[8].dyadic + def test_count_participants(self, convo): + assert convo._count_participants() == 3 + convo2 = convo._subconversation(index=0, before=2) + assert convo2._count_participants() == 2 + convo3 = convo._subconversation(index=0) + assert convo3._count_participants() == 1