diff --git a/mentat/code_edit_feedback.py b/mentat/code_edit_feedback.py
index 3557da9d4..07ebdd59b 100644
--- a/mentat/code_edit_feedback.py
+++ b/mentat/code_edit_feedback.py
@@ -61,11 +61,7 @@ async def get_user_feedback_on_edits(
         conversation.add_message(
             ChatCompletionSystemMessageParam(
                 role="system",
-                content=(
-                    "User chose not to apply any of your changes. Please adjust"
-                    " your previous plan and changes to reflect their feedback."
-                    " Respond with a full new set of changes."
-                ),
+                content="User chose not to apply any of your changes.",
             )
         )
         conversation.add_user_message(user_response)
diff --git a/mentat/conversation.py b/mentat/conversation.py
index 6aab783c7..7940a6a2f 100644
--- a/mentat/conversation.py
+++ b/mentat/conversation.py
@@ -17,7 +17,7 @@
 )
 
 from mentat.errors import MentatError
-from mentat.llm_api_handler import conversation_tokens, count_tokens, model_context_size
+from mentat.llm_api_handler import count_tokens, model_context_size, prompt_tokens
 from mentat.session_context import SESSION_CONTEXT
 from mentat.transcripts import ModelMessage, TranscriptMessage, UserMessage
 from mentat.utils import add_newline
@@ -75,7 +75,7 @@ async def display_token_count(self):
                 content=await code_context.get_code_message("", max_tokens=0),
             )
         ]
-        tokens = conversation_tokens(
+        tokens = prompt_tokens(
            messages,
            config.model,
        )
@@ -187,7 +187,7 @@ async def _stream_model_response(
 
         start_time = default_timer()
 
-        num_prompt_tokens = conversation_tokens(messages, config.model)
+        num_prompt_tokens = prompt_tokens(messages, config.model)
         context_size = model_context_size(config.model)
         if context_size:
             if num_prompt_tokens > context_size - config.token_buffer:
@@ -205,7 +205,7 @@
             )
         response = await llm_api_handler.call_llm_api(
             messages,
-            config.model,
+            config.model,
             stream=True,
             response_format=parser.response_format(),
         )
@@ -237,7 +237,7 @@ async def get_model_response(self) -> list[FileEdit]:
         messages_snapshot = self.get_messages()
 
         # Rebuild code context with active code and available tokens
-        tokens = conversation_tokens(messages_snapshot, config.model)
+        tokens = prompt_tokens(messages_snapshot, config.model)
 
         loading_multiplier = 1.0 if config.auto_context else 0.0
         try:
@@ -275,7 +275,7 @@ async def get_model_response(self) -> list[FileEdit]:
             cost_tracker.display_api_call_stats(
                 num_prompt_tokens,
                 count_tokens(
-                    parsed_llm_response.full_response, config.model, full_message=True
+                    parsed_llm_response.full_response, config.model, full_message=False
                 ),
                 config.model,
                 time_elapsed,
diff --git a/mentat/embeddings.py b/mentat/embeddings.py
index bef812cfa..ae40b082a 100644
--- a/mentat/embeddings.py
+++ b/mentat/embeddings.py
@@ -128,7 +128,9 @@ async def get_feature_similarity_scores(
     num_prompt_tokens = 0
     if not database.exists(prompt_checksum):
         items_to_embed[prompt_checksum] = prompt
-        items_to_embed_tokens[prompt_checksum] = count_tokens(prompt, embedding_model)
+        items_to_embed_tokens[prompt_checksum] = count_tokens(
+            prompt, embedding_model, False
+        )
     for feature, checksum, token in zip(features, checksums, tokens):
         if token > max_model_tokens:
             continue
diff --git a/mentat/llm_api_handler.py b/mentat/llm_api_handler.py
index 4b6e49577..ab8bbde50 100644
--- a/mentat/llm_api_handler.py
+++ b/mentat/llm_api_handler.py
@@ -46,11 +46,12 @@ def chunk_to_lines(chunk: ChatCompletionChunk) -> list[str]:
     return ("" if content is None else content).splitlines(keepends=True)
 
 
-def count_tokens(message: str, model: str, full_message: bool = False) -> int:
+def count_tokens(message: str, model: str, full_message: bool) -> int:
     """
-    Calculates the tokens in this message. Will NOT be accurate for a full conversation!
-    Use conversation_tokens to get the exact amount of tokens in a conversation.
-    If full_message is true, will include the extra 4 tokens used in a chat completion by this message.
+    Calculates the tokens in this message. Will NOT be accurate for a full prompt!
+    Use prompt_tokens to get the exact amount of tokens for a prompt.
+    If full_message is true, will include the extra 4 tokens used in a chat completion by this message
+    if this message is part of a prompt. The majority of the time, you'll want full_message to be true.
     """
     try:
         encoding = tiktoken.encoding_for_model(model)
@@ -61,9 +62,9 @@ def count_tokens(message: str, model: str, full_message: bool = False) -> int:
     )
 
 
-def conversation_tokens(messages: list[ChatCompletionMessageParam], model: str):
+def prompt_tokens(messages: list[ChatCompletionMessageParam], model: str):
     """
-    Returns the number of tokens used by a full conversation.
+    Returns the number of tokens used by a prompt if it was sent to OpenAI for a chat completion.
     Adapted from https://platform.openai.com/docs/guides/text-generation/managing-tokens
     """
     try:
@@ -73,7 +74,8 @@ def conversation_tokens(messages: list[ChatCompletionMessageParam], model: str):
 
     num_tokens = 0
     for message in messages:
-        # every message follows {role/name}\n{content}\n
+        # every message follows <|start|>{role/name}\n{content}<|end|>\n
+        # this has 5 tokens (start token, role, \n, end token, \n), but we count the role token later
         num_tokens += 4
         for key, value in message.items():
             if isinstance(value, list) and key == "content":
@@ -98,7 +100,7 @@ def conversation_tokens(messages: list[ChatCompletionMessageParam], model: str):
                 num_tokens += len(encoding.encode(value))
             if key == "name":  # if there's a name, the role is omitted
                 num_tokens -= 1  # role is always required and always 1 token
-    num_tokens += 2  # every reply is primed with assistant
+    num_tokens += 2  # every reply is primed with <|start|>assistant
     return num_tokens
 
 
diff --git a/tests/llm_api_handler_test.py b/tests/llm_api_handler_test.py
index 6a2e34329..31fd1d691 100644
--- a/tests/llm_api_handler_test.py
+++ b/tests/llm_api_handler_test.py
@@ -3,17 +3,17 @@
 
 from PIL import Image
 
-from mentat.llm_api_handler import conversation_tokens
+from mentat.llm_api_handler import prompt_tokens
 
 
-def test_conversation_tokens():
+def test_prompt_tokens():
     messages = [
         {"role": "user", "content": "Hello!"},
         {"role": "assistant", "content": "Hi there! How can I help you today?"},
     ]
 
     model = "gpt-4-vision-preview"
-    assert conversation_tokens(messages, model) == 24
+    assert prompt_tokens(messages, model) == 24
 
     # An image that must be scaled twice and then fits in 6 512x512 panels
     img = Image.new("RGB", (768 * 4, 1050 * 4), color="red")
@@ -30,4 +30,4 @@ def test_conversation_tokens():
         }
     )
 
-    assert conversation_tokens(messages, model) == 24 + 6 * 170 + 85 + 5
+    assert prompt_tokens(messages, model) == 24 + 6 * 170 + 85 + 5
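
Note on the new call signatures: count_tokens now takes an explicit full_message argument, and conversation_tokens is renamed to prompt_tokens. A minimal sketch of how the two helpers are expected to be called after this change; the per-token breakdown in the comments is an assumption based on the cl100k_base encoding tiktoken returns for gpt-4 models and on the 4-token per-message overhead described in the count_tokens docstring, not part of the diff itself:

    from mentat.llm_api_handler import count_tokens, prompt_tokens

    messages = [
        {"role": "user", "content": "Hello!"},
        {"role": "assistant", "content": "Hi there! How can I help you today?"},
    ]

    # prompt_tokens counts the whole prompt: 4 tokens of per-message overhead * 2 messages,
    # 1 token per role value ("user", "assistant"), 2 + 10 content tokens, and 2 priming
    # tokens for the assistant reply, giving the 24 asserted in the test above.
    assert prompt_tokens(messages, "gpt-4") == 24

    # count_tokens measures a single string; full_message=True adds the 4-token overhead
    # the message would incur as part of a prompt, full_message=False counts content only.
    assert count_tokens("Hello!", "gpt-4", full_message=False) == 2  # "Hello" + "!"
    assert count_tokens("Hello!", "gpt-4", full_message=True) == 6   # 2 content + 4 overhead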