Skip to content

Commit

Permalink
feat: allow models to return non-json / unenforced responses (#12)
Browse files Browse the repository at this point in the history
* feat: turn off enforcement

* refactor: add text response option to handler response

* refactor: handle non-tool formatting

* refactor: cleanup logic

* refactor: handle non-tool call assistant messages
  • Loading branch information
NJordan72 authored Jan 10, 2025
1 parent 170ee19 commit f7d28e4
Show file tree
Hide file tree
Showing 4 changed files with 178 additions and 32 deletions.
76 changes: 58 additions & 18 deletions narrative_llm_tools/handlers/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@
class HandlerResponse(BaseModel):
"""Response from the handler."""

tool_calls: list[dict[str, Any]]
warnings: list[str] | None
text_response: str | None = None
tool_calls: list[dict[str, Any]] | None = None
warnings: list[str] | None = None


class ModelConfig(BaseModel):
Expand Down Expand Up @@ -285,18 +286,25 @@ def __call__(self, data: dict[str, Any]) -> dict[str, Any]:
}:
self._process_conversation_turn(conversation_state)

return_msg = json.loads(conversation_state.get_last_message().content)
if conversation_state.tool_choice != "none":
return_msg = json.loads(conversation_state.get_last_message().content)

if not isinstance(return_msg, list):
raise ModelOutputError("Model output is not a list of tool calls.")

for tool_call in return_msg:
if not isinstance(tool_call, dict):
if not isinstance(return_msg, list):
raise ModelOutputError("Model output is not a list of tool calls.")

return HandlerResponse(tool_calls=return_msg, warnings=None).model_dump(
exclude_none=True
)
for tool_call in return_msg:
if not isinstance(tool_call, dict):
raise ModelOutputError("Model output is not a list of tool calls.")

return HandlerResponse(
tool_calls=return_msg, warnings=None, text_response=None
).model_dump(exclude_none=True)
else:
return HandlerResponse(
tool_calls=None,
text_response=conversation_state.get_last_message().content,
warnings=None,
).model_dump(exclude_none=True)

except (
ValidationError,
Expand All @@ -316,17 +324,42 @@ def __call__(self, data: dict[str, Any]) -> dict[str, Any]:
def _process_conversation_turn(self, state: ConversationState) -> None:
"""Process a single turn of the conversation."""
conversation_text = self._format_conversation(state)
format_enforcer = get_format_enforcer(self.pipeline.tokenizer, state.update_current_tools())
format_enforcer = self._get_format_enforcer(state)

model_output = self._generate_prediction(
conversation_text, format_enforcer, state.pipeline_params
)

tool_calls = self._format_model_output(model_output)
serialized = [tool.model_dump() for tool in tool_calls]
state.add_message(ConversationMessage(role="tool_calls", content=json.dumps(serialized)))
formatted_output = self._format_model_output(model_output, state.tool_choice)

if state.tool_choice != "none":
if not isinstance(formatted_output, list):
logger.warning("Expected list of tool calls but got different type")
return

serialized = [tool.model_dump() for tool in formatted_output]
state.add_message(
ConversationMessage(role="tool_calls", content=json.dumps(serialized))
)

if state.only_called_rest_api_tools(tool_calls):
self._execute_tool_calls(tool_calls, state)
if state.only_called_rest_api_tools(formatted_output):
self._execute_tool_calls(formatted_output, state)
else:
if not isinstance(formatted_output, str):
logger.warning("Expected string response but got different type")
return

state.add_message(ConversationMessage(role="assistant", content=formatted_output))

def _get_format_enforcer(self, state: ConversationState) -> FormatEnforcer | None:
"""Get the format enforcer based on current tools state."""
if not state.tools_catalog:
return None

current_tools = state.update_current_tools()
return (
get_format_enforcer(self.pipeline.tokenizer, current_tools) if current_tools else None
)

def _execute_tool_calls(self, tool_calls: list[Tool], state: ConversationState) -> None:
"""Execute tool calls and update conversation state."""
Expand Down Expand Up @@ -400,7 +433,11 @@ def _execute_tool_calls(self, tool_calls: list[Tool], state: ConversationState)
if return_to_user and state.status != ConversationStatus.COMPLETED:
state.transition_to(ConversationStatus.COMPLETED)

def _format_model_output(self, model_output: list[dict[str, Any]]) -> list[Tool]:
def _format_model_output(
self,
model_output: list[dict[str, Any]],
tool_choice: Literal["required", "none", "auto"],
) -> list[Tool] | str:
"""Format the model output into a list of dictionaries."""
if not model_output:
return []
Expand All @@ -413,6 +450,9 @@ def _format_model_output(self, model_output: list[dict[str, Any]]) -> list[Tool]
if generated_text is None:
raise ModelOutputError("No generated_text found in the model output.")

if tool_choice == "none":
return generated_text

try:
logger.debug(f"Generated text: {generated_text}")
parsed_output: list[Tool] = [
Expand Down
48 changes: 35 additions & 13 deletions narrative_llm_tools/state/conversation_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ class ConversationState(BaseModel):
raw_messages: list[ConversationMessage]
max_tool_rounds: int = 5
tool_choice: Literal["required", "auto", "none"] = "required"
tools_catalog: JsonSchemaTools = JsonSchemaTools.only_user_response_tool()
tools_catalog: JsonSchemaTools | None = JsonSchemaTools.only_user_response_tool()
pipeline_params: dict[str, Any]
status: ConversationStatus = ConversationStatus.RUNNING

Expand All @@ -103,14 +103,14 @@ def from_api_request(cls, request_data: dict[str, Any]) -> "ConversationState":
if k not in cls.RESERVED_KEYS and not k.startswith("_")
}

tool_choice = request_data.get("tool_choice", "required")
tools_data = request_data.get("tools", {})
tools_instance = (
JsonSchemaTools.model_validate(tools_data)
if tools_data
else JsonSchemaTools.only_user_response_tool()
if tools_data and tools_data != {} and tool_choice != "none"
else JsonSchemaTools.only_user_response_tool() if tool_choice != "none" else None
)

tool_choice = request_data.get("tool_choice", "required")
status = (
ConversationStatus.WRAP_THINGS_UP
if tool_choice == "none"
Expand Down Expand Up @@ -188,7 +188,9 @@ def _has_non_rest_tool(self) -> bool:
"""
Internal helper to check if there's at least one non-REST API tool available.
"""
return len(self.rest_api_names) != len(self.tools_catalog.items.anyOf)
return not self.tools_catalog or len(self.rest_api_names) != len(
self.tools_catalog.items.anyOf
)

def _has_rest_api_tools(self, content: str) -> bool:
"""Checks if the given content calls any REST API tools."""
Expand Down Expand Up @@ -235,13 +237,14 @@ def can_respond(self) -> bool:

def get_rest_api_catalog(self) -> dict[str, RestApiClient]:
"""Returns all REST API tools from the current catalog."""
return self.tools_catalog.get_rest_apis()
return self.tools_catalog.get_rest_apis() if self.tools_catalog else {}

def remove_tool(self, tool_name: str) -> None:
"""
Removes the specified tool from the catalog if it exists.
"""
self.tools_catalog = self.tools_catalog.remove_tool_by_name(tool_name)
if self.tools_catalog:
self.tools_catalog = self.tools_catalog.remove_tool_by_name(tool_name)

@property
def tool_calls_count(self) -> int:
Expand All @@ -256,7 +259,11 @@ def _tool_catalog_message(self) -> ConversationMessage:
"""
return ConversationMessage(
role="tool_catalog",
content=json.dumps(self.tools_catalog.model_dump(), separators=(",", ":")),
content=(
json.dumps(self.tools_catalog.model_dump(), separators=(",", ":"))
if self.tools_catalog
else ""
),
)

def add_message(self, message: ConversationMessage) -> None:
Expand All @@ -269,12 +276,24 @@ def add_message(self, message: ConversationMessage) -> None:
Raises:
ValueError: If the message role is invalid or adding it violates state constraints.
"""
logger.info(f"Adding message: {message}")

if message.role == "tool_calls":
self._handle_tool_call(message)
elif message.role == "tool_response":
self._handle_tool_response(message)
elif message.role == "assistant":
self._handle_assistant_response(message)

logger.info(f"Conversation state after adding message: {self}")

def _handle_assistant_response(self, message: ConversationMessage) -> None:
"""
Handles adding an assistant response message and updating state accordingly.
"""
logger.info(f"Handling assistant response: {message}")
self.raw_messages.append(message)
self.transition_to(ConversationStatus.COMPLETED)

def _handle_tool_call(self, message: ConversationMessage) -> None:
"""
Expand Down Expand Up @@ -348,16 +367,17 @@ def _remove_rest_api_tools(self) -> None:
"""
Removes all REST API tools from the catalog.
"""
self.tools_catalog = self.tools_catalog.remove_rest_api_tools()
if self.tools_catalog:
self.tools_catalog = self.tools_catalog.remove_rest_api_tools()

def update_current_tools(self) -> JsonSchemaTools:
def update_current_tools(self) -> JsonSchemaTools | None:
"""
Returns the appropriate tool catalog for the current conversation state:
- If status is WRAP_THINGS_UP, only return user-response tool.
- If status is RUNNING but there's no way to respond, return a catalog
that includes a user-response tool. Otherwise, return the current tools.
"""
if len(self.tools_catalog.items.anyOf) == 0:
if self.tools_catalog and len(self.tools_catalog.items.anyOf) == 0:
self.tools_catalog = JsonSchemaTools.only_user_response_tool()
elif self.status == ConversationStatus.WRAP_THINGS_UP:
logger.info(
Expand All @@ -369,11 +389,13 @@ def update_current_tools(self) -> JsonSchemaTools:
"After removing rest API tools, "
"we have {len(self.tools_catalog.items.anyOf)} tools.",
)
if len(self.tools_catalog.items.anyOf) == 0:
if self.tools_catalog and len(self.tools_catalog.items.anyOf) == 0:
self.tools_catalog = JsonSchemaTools.only_user_response_tool()
elif self.status == ConversationStatus.RUNNING:
if not self.can_respond():
self.tools_catalog = self.tools_catalog.with_user_response_tool()
self.tools_catalog = (
self.tools_catalog.with_user_response_tool() if self.tools_catalog else None
)
elif self.status in [
ConversationStatus.WAITING_TOOL_RESPONSE,
ConversationStatus.COMPLETED,
Expand Down
2 changes: 1 addition & 1 deletion tests/handlers/test_huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def test_endpoint_handler_format_model_output(endpoint_handler: EndpointHandler)
Test parsing valid JSON from the pipeline's model output.
"""
mock_output = [{"generated_text": '[{"name": "tool1", "parameters": {"p": 1}}]'}]
tools = endpoint_handler._format_model_output(mock_output)
tools = endpoint_handler._format_model_output(mock_output, "auto")
assert len(tools) == 1
assert tools[0].name == "tool1"
assert tools[0].parameters == {"p": 1}
Expand Down
Loading

0 comments on commit f7d28e4

Please sign in to comment.