Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

feat: allow models to return non-json / unenforced responses #12

Merged
merged 5 commits into the base branch from the feature branch on
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 58 additions & 18 deletions narrative_llm_tools/handlers/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@
class HandlerResponse(BaseModel):
"""Response from the handler."""

tool_calls: list[dict[str, Any]]
warnings: list[str] | None
text_response: str | None = None
tool_calls: list[dict[str, Any]] | None = None
warnings: list[str] | None = None


class ModelConfig(BaseModel):
Expand Down Expand Up @@ -285,18 +286,25 @@ def __call__(self, data: dict[str, Any]) -> dict[str, Any]:
}:
self._process_conversation_turn(conversation_state)

return_msg = json.loads(conversation_state.get_last_message().content)
if conversation_state.tool_choice != "none":
return_msg = json.loads(conversation_state.get_last_message().content)

if not isinstance(return_msg, list):
raise ModelOutputError("Model output is not a list of tool calls.")

for tool_call in return_msg:
if not isinstance(tool_call, dict):
if not isinstance(return_msg, list):
raise ModelOutputError("Model output is not a list of tool calls.")

return HandlerResponse(tool_calls=return_msg, warnings=None).model_dump(
exclude_none=True
)
for tool_call in return_msg:
if not isinstance(tool_call, dict):
raise ModelOutputError("Model output is not a list of tool calls.")

return HandlerResponse(
tool_calls=return_msg, warnings=None, text_response=None
).model_dump(exclude_none=True)
else:
return HandlerResponse(
tool_calls=None,
text_response=conversation_state.get_last_message().content,
warnings=None,
).model_dump(exclude_none=True)

except (
ValidationError,
Expand All @@ -316,17 +324,42 @@ def __call__(self, data: dict[str, Any]) -> dict[str, Any]:
def _process_conversation_turn(self, state: ConversationState) -> None:
"""Process a single turn of the conversation."""
conversation_text = self._format_conversation(state)
format_enforcer = get_format_enforcer(self.pipeline.tokenizer, state.update_current_tools())
format_enforcer = self._get_format_enforcer(state)

model_output = self._generate_prediction(
conversation_text, format_enforcer, state.pipeline_params
)

tool_calls = self._format_model_output(model_output)
serialized = [tool.model_dump() for tool in tool_calls]
state.add_message(ConversationMessage(role="tool_calls", content=json.dumps(serialized)))
formatted_output = self._format_model_output(model_output, state.tool_choice)

if state.tool_choice != "none":
if not isinstance(formatted_output, list):
logger.warning("Expected list of tool calls but got different type")
return

serialized = [tool.model_dump() for tool in formatted_output]
state.add_message(
ConversationMessage(role="tool_calls", content=json.dumps(serialized))
)

if state.only_called_rest_api_tools(tool_calls):
self._execute_tool_calls(tool_calls, state)
if state.only_called_rest_api_tools(formatted_output):
self._execute_tool_calls(formatted_output, state)
else:
if not isinstance(formatted_output, str):
logger.warning("Expected string response but got different type")
return

state.add_message(ConversationMessage(role="assistant", content=formatted_output))

def _get_format_enforcer(self, state: ConversationState) -> FormatEnforcer | None:
"""Get the format enforcer based on current tools state."""
if not state.tools_catalog:
return None

current_tools = state.update_current_tools()
return (
get_format_enforcer(self.pipeline.tokenizer, current_tools) if current_tools else None
)

def _execute_tool_calls(self, tool_calls: list[Tool], state: ConversationState) -> None:
"""Execute tool calls and update conversation state."""
Expand Down Expand Up @@ -400,7 +433,11 @@ def _execute_tool_calls(self, tool_calls: list[Tool], state: ConversationState)
if return_to_user and state.status != ConversationStatus.COMPLETED:
state.transition_to(ConversationStatus.COMPLETED)

def _format_model_output(self, model_output: list[dict[str, Any]]) -> list[Tool]:
def _format_model_output(
self,
model_output: list[dict[str, Any]],
tool_choice: Literal["required", "none", "auto"],
) -> list[Tool] | str:
"""Format the model output into a list of dictionaries."""
if not model_output:
return []
Expand All @@ -413,6 +450,9 @@ def _format_model_output(self, model_output: list[dict[str, Any]]) -> list[Tool]
if generated_text is None:
raise ModelOutputError("No generated_text found in the model output.")

if tool_choice == "none":
return generated_text

try:
logger.debug(f"Generated text: {generated_text}")
parsed_output: list[Tool] = [
Expand Down
48 changes: 35 additions & 13 deletions narrative_llm_tools/state/conversation_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ class ConversationState(BaseModel):
raw_messages: list[ConversationMessage]
max_tool_rounds: int = 5
tool_choice: Literal["required", "auto", "none"] = "required"
tools_catalog: JsonSchemaTools = JsonSchemaTools.only_user_response_tool()
tools_catalog: JsonSchemaTools | None = JsonSchemaTools.only_user_response_tool()
pipeline_params: dict[str, Any]
status: ConversationStatus = ConversationStatus.RUNNING

Expand All @@ -103,14 +103,14 @@ def from_api_request(cls, request_data: dict[str, Any]) -> "ConversationState":
if k not in cls.RESERVED_KEYS and not k.startswith("_")
}

tool_choice = request_data.get("tool_choice", "required")
tools_data = request_data.get("tools", {})
tools_instance = (
JsonSchemaTools.model_validate(tools_data)
if tools_data
else JsonSchemaTools.only_user_response_tool()
if tools_data and tools_data != {} and tool_choice != "none"
else JsonSchemaTools.only_user_response_tool() if tool_choice != "none" else None
)

tool_choice = request_data.get("tool_choice", "required")
status = (
ConversationStatus.WRAP_THINGS_UP
if tool_choice == "none"
Expand Down Expand Up @@ -188,7 +188,9 @@ def _has_non_rest_tool(self) -> bool:
"""
Internal helper to check if there's at least one non-REST API tool available.
"""
return len(self.rest_api_names) != len(self.tools_catalog.items.anyOf)
return not self.tools_catalog or len(self.rest_api_names) != len(
self.tools_catalog.items.anyOf
)

def _has_rest_api_tools(self, content: str) -> bool:
"""Checks if the given content calls any REST API tools."""
Expand Down Expand Up @@ -235,13 +237,14 @@ def can_respond(self) -> bool:

def get_rest_api_catalog(self) -> dict[str, RestApiClient]:
"""Returns all REST API tools from the current catalog."""
return self.tools_catalog.get_rest_apis()
return self.tools_catalog.get_rest_apis() if self.tools_catalog else {}

def remove_tool(self, tool_name: str) -> None:
"""
Removes the specified tool from the catalog if it exists.
"""
self.tools_catalog = self.tools_catalog.remove_tool_by_name(tool_name)
if self.tools_catalog:
self.tools_catalog = self.tools_catalog.remove_tool_by_name(tool_name)

@property
def tool_calls_count(self) -> int:
Expand All @@ -256,7 +259,11 @@ def _tool_catalog_message(self) -> ConversationMessage:
"""
return ConversationMessage(
role="tool_catalog",
content=json.dumps(self.tools_catalog.model_dump(), separators=(",", ":")),
content=(
json.dumps(self.tools_catalog.model_dump(), separators=(",", ":"))
if self.tools_catalog
else ""
),
)

def add_message(self, message: ConversationMessage) -> None:
Expand All @@ -269,12 +276,24 @@ def add_message(self, message: ConversationMessage) -> None:
Raises:
ValueError: If the message role is invalid or adding it violates state constraints.
"""
logger.info(f"Adding message: {message}")

if message.role == "tool_calls":
self._handle_tool_call(message)
elif message.role == "tool_response":
self._handle_tool_response(message)
elif message.role == "assistant":
self._handle_assistant_response(message)

logger.info(f"Conversation state after adding message: {self}")

def _handle_assistant_response(self, message: ConversationMessage) -> None:
"""
Handles adding an assistant response message and updating state accordingly.
"""
logger.info(f"Handling assistant response: {message}")
self.raw_messages.append(message)
self.transition_to(ConversationStatus.COMPLETED)

def _handle_tool_call(self, message: ConversationMessage) -> None:
"""
Expand Down Expand Up @@ -348,16 +367,17 @@ def _remove_rest_api_tools(self) -> None:
"""
Removes all REST API tools from the catalog.
"""
self.tools_catalog = self.tools_catalog.remove_rest_api_tools()
if self.tools_catalog:
self.tools_catalog = self.tools_catalog.remove_rest_api_tools()

def update_current_tools(self) -> JsonSchemaTools:
def update_current_tools(self) -> JsonSchemaTools | None:
"""
Returns the appropriate tool catalog for the current conversation state:
- If status is WRAP_THINGS_UP, only return user-response tool.
- If status is RUNNING but there's no way to respond, return a catalog
that includes a user-response tool. Otherwise, return the current tools.
"""
if len(self.tools_catalog.items.anyOf) == 0:
if self.tools_catalog and len(self.tools_catalog.items.anyOf) == 0:
self.tools_catalog = JsonSchemaTools.only_user_response_tool()
elif self.status == ConversationStatus.WRAP_THINGS_UP:
logger.info(
Expand All @@ -369,11 +389,13 @@ def update_current_tools(self) -> JsonSchemaTools:
"After removing rest API tools, "
"we have {len(self.tools_catalog.items.anyOf)} tools.",
)
if len(self.tools_catalog.items.anyOf) == 0:
if self.tools_catalog and len(self.tools_catalog.items.anyOf) == 0:
self.tools_catalog = JsonSchemaTools.only_user_response_tool()
elif self.status == ConversationStatus.RUNNING:
if not self.can_respond():
self.tools_catalog = self.tools_catalog.with_user_response_tool()
self.tools_catalog = (
self.tools_catalog.with_user_response_tool() if self.tools_catalog else None
)
elif self.status in [
ConversationStatus.WAITING_TOOL_RESPONSE,
ConversationStatus.COMPLETED,
Expand Down
2 changes: 1 addition & 1 deletion tests/handlers/test_huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def test_endpoint_handler_format_model_output(endpoint_handler: EndpointHandler)
Test parsing valid JSON from the pipeline's model output.
"""
mock_output = [{"generated_text": '[{"name": "tool1", "parameters": {"p": 1}}]'}]
tools = endpoint_handler._format_model_output(mock_output)
tools = endpoint_handler._format_model_output(mock_output, "auto")
assert len(tools) == 1
assert tools[0].name == "tool1"
assert tools[0].parameters == {"p": 1}
Expand Down
Loading
Loading