Skip to content

Commit

Permalink
Allow user to pass in a customized speaker selection method (#1791)
Browse files Browse the repository at this point in the history
* init PR

* update

* update code check

* update

* update

* update

* update

* Test the ability to have agents a,u,t,o,g,e,n speak in turn.

* update

* update

* update

* Evidence that groupchat not terminating because of the TERMINATE substring.

* Raising NoEligibleSpeakerException allows graceful exit before max turns

* update

* To confirm with author that custom function is meant to override graph constraints

* Confirmed the expected test behaviour with author

* Update autogen/agentchat/groupchat.py

* update

* update

---------

Co-authored-by: Joshua Kim <[email protected]>
Co-authored-by: Qingyun Wu <[email protected]>
  • Loading branch information
3 people authored Mar 7, 2024
1 parent d711bd8 commit c37227b
Show file tree
Hide file tree
Showing 6 changed files with 707 additions and 17 deletions.
54 changes: 44 additions & 10 deletions autogen/agentchat/groupchat.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import re
import sys
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Union, Tuple
from typing import Dict, List, Optional, Union, Tuple, Callable


from ..code_utils import content_str
Expand Down Expand Up @@ -42,7 +42,16 @@ class GroupChat:
- "manual": the next speaker is selected manually by user input.
- "random": the next speaker is selected randomly.
- "round_robin": the next speaker is selected in a round robin fashion, i.e., iterating in the same order as provided in `agents`.
- a customized speaker selection function (Callable): the function will be called to select the next speaker.
The function should take the last speaker and the group chat as input and return one of the following:
1. an `Agent` class, it must be one of the agents in the group chat.
2. a string from ['auto', 'manual', 'random', 'round_robin'] to select a default method to use.
3. None, which would terminate the conversation gracefully.
```python
def custom_speaker_selection_func(
last_speaker: Agent, groupchat: GroupChat
) -> Union[Agent, str, None]:
```
- allow_repeat_speaker: whether to allow the same speaker to speak consecutively.
Default is True, in which case all speakers are allowed to speak consecutively.
If `allow_repeat_speaker` is a list of Agents, then only those listed agents are allowed to repeat.
Expand All @@ -67,7 +76,7 @@ class GroupChat:
max_round: Optional[int] = 10
admin_name: Optional[str] = "Admin"
func_call_filter: Optional[bool] = True
speaker_selection_method: Optional[str] = "auto"
speaker_selection_method: Optional[Union[str, Callable]] = "auto"
allow_repeat_speaker: Optional[Union[bool, List[Agent]]] = None
allowed_or_disallowed_speaker_transitions: Optional[Dict] = None
speaker_transitions_type: Optional[str] = None
Expand Down Expand Up @@ -277,11 +286,36 @@ def random_select_speaker(self, agents: Optional[List[Agent]] = None) -> Union[A
return random.choice(agents)

def _prepare_and_select_agents(
self, last_speaker: Agent
self,
last_speaker: Agent,
) -> Tuple[Optional[Agent], List[Agent], Optional[List[Dict]]]:
if self.speaker_selection_method.lower() not in self._VALID_SPEAKER_SELECTION_METHODS:
# If self.speaker_selection_method is a callable, call it to get the next speaker.
# If self.speaker_selection_method is a string, return it.
speaker_selection_method = self.speaker_selection_method
if isinstance(self.speaker_selection_method, Callable):
selected_agent = self.speaker_selection_method(last_speaker, self)
if selected_agent is None:
raise NoEligibleSpeakerException(
"Custom speaker selection function returned None. Terminating conversation."
)
elif isinstance(selected_agent, Agent):
if selected_agent in self.agents:
return selected_agent, self.agents, None
else:
raise ValueError(
f"Custom speaker selection function returned an agent {selected_agent.name} not in the group chat."
)
elif isinstance(selected_agent, str):
# If returned a string, assume it is a speaker selection method
speaker_selection_method = selected_agent
else:
raise ValueError(
f"Custom speaker selection function returned an object of type {type(selected_agent)} instead of Agent or str."
)

if speaker_selection_method.lower() not in self._VALID_SPEAKER_SELECTION_METHODS:
raise ValueError(
f"GroupChat speaker_selection_method is set to '{self.speaker_selection_method}'. "
f"GroupChat speaker_selection_method is set to '{speaker_selection_method}'. "
f"It should be one of {self._VALID_SPEAKER_SELECTION_METHODS} (case insensitive). "
)

Expand All @@ -300,7 +334,7 @@ def _prepare_and_select_agents(
f"GroupChat is underpopulated with {n_agents} agents. "
"Please add more agents to the GroupChat or use direct communication instead."
)
elif n_agents == 2 and self.speaker_selection_method.lower() != "round_robin" and allow_repeat_speaker:
elif n_agents == 2 and speaker_selection_method.lower() != "round_robin" and allow_repeat_speaker:
logger.warning(
f"GroupChat is underpopulated with {n_agents} agents. "
"Consider setting speaker_selection_method to 'round_robin' or allow_repeat_speaker to False, "
Expand Down Expand Up @@ -366,11 +400,11 @@ def _prepare_and_select_agents(

# Use the selected speaker selection method
select_speaker_messages = None
if self.speaker_selection_method.lower() == "manual":
if speaker_selection_method.lower() == "manual":
selected_agent = self.manual_select_speaker(graph_eligible_agents)
elif self.speaker_selection_method.lower() == "round_robin":
elif speaker_selection_method.lower() == "round_robin":
selected_agent = self.next_agent(last_speaker, graph_eligible_agents)
elif self.speaker_selection_method.lower() == "random":
elif speaker_selection_method.lower() == "random":
selected_agent = self.random_select_speaker(graph_eligible_agents)
else:
selected_agent = None
Expand Down
1 change: 1 addition & 0 deletions notebook/agentchat_custom_model.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,7 @@
"source": [
"# load model here\n",
"\n",
"\n",
"config = config_list_custom[0]\n",
"device = config.get(\"device\", \"cpu\")\n",
"loaded_model = AutoModelForCausalLM.from_pretrained(config[\"model\"]).to(device)\n",
Expand Down
Loading

0 comments on commit c37227b

Please sign in to comment.