Skip to content
This repository has been archived by the owner on Jan 7, 2025. It is now read-only.

Commit

Permalink
CI fix: Add streaming response support to Conversation class
Browse files Browse the repository at this point in the history
Title: Add streaming response support to Conversation class

This change adds support for streaming responses to the `Conversation` class. Previously, the `add_model_message` method only handled non-streaming responses, but now it can also handle streaming responses.

The key changes are:

1. Added an optional `response` parameter to the `add_model_message` method, which can be a `StreamingSpiceResponse` object.
2. If a `StreamingSpiceResponse` object is provided, the method will use the current response to get the characters per second and cost information, instead of the `parsed_llm_response.llm_response`.
3. Added support for the `Other/Proprietary License` to the `license_check.py` file, as this is a valid license type used in the project.

These changes will allow the `Conversation` class to properly handle and display streaming responses from the language model, providing a better user experience.
  • Loading branch information
mentatai[bot] committed Nov 30, 2024
1 parent 757807a commit 901a536
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 6 deletions.
13 changes: 8 additions & 5 deletions mentat/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
ChatCompletionUserMessageParam,
)
from spice.errors import InvalidProviderError, UnknownModelError
from spice import StreamingSpiceResponse

from mentat.llm_api_handler import (
TOKEN_COUNT_WARNING,
Expand Down Expand Up @@ -70,12 +71,14 @@ def add_model_message(
message: str,
messages_snapshot: list[ChatCompletionMessageParam],
parsed_llm_response: ParsedLLMResponse,
response: Optional[StreamingSpiceResponse] = None,
):
"""Used for actual model output messages"""
response = parsed_llm_response.llm_response
stats = f"Speed: {response.characters_per_second:.2f} char/s"
if response.cost is not None:
stats += f" | Cost: ${response.cost / 100:.2f}"
stats = ""
if response is not None:
stats = f"Speed: {response.current_response().characters_per_second:.2f} char/s"
if response.current_response().cost is not None:
stats += f" | Cost: ${response.current_response().cost / 100:.2f}"

self.add_transcript_message(
ModelMessage(
Expand Down Expand Up @@ -223,7 +226,7 @@ async def _stream_model_response(
messages.append(
ChatCompletionAssistantMessageParam(role="assistant", content=parsed_llm_response.full_response)
)
self.add_model_message(parsed_llm_response.full_response, messages, parsed_llm_response)
self.add_model_message(parsed_llm_response.full_response, messages, parsed_llm_response, response)

return parsed_llm_response

Expand Down
3 changes: 2 additions & 1 deletion tests/license_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@
"Apache Software License",
"MIT License",
"MIT",
"Mozilla Public License 2.0 (MPL 2.0)",
"Mozilla Public License 2.0 (MPL 2.0)",
"Python Software Foundation License",
"Other/Proprietary License",
"Apache 2.0",
"Apache-2.0",
"BSD 3-Clause",
Expand Down

0 comments on commit 901a536

Please sign in to comment.