This repository has been archived by the owner on Jan 7, 2025. It is now read-only.

Merge branch 'main' into add-run
jakethekoenig authored Nov 27, 2023
2 parents 8ab785f + 4f82f2f commit a5b6379
Showing 5 changed files with 24 additions and 24 deletions.
6 changes: 1 addition & 5 deletions mentat/code_edit_feedback.py
@@ -61,11 +61,7 @@ async def get_user_feedback_on_edits(
conversation.add_message(
ChatCompletionSystemMessageParam(
role="system",
-content=(
-    "User chose not to apply any of your changes. Please adjust"
-    " your previous plan and changes to reflect their feedback."
-    " Respond with a full new set of changes."
-),
+content="User chose not to apply any of your changes.",
)
)
conversation.add_user_message(user_response)
12 changes: 6 additions & 6 deletions mentat/conversation.py
@@ -17,7 +17,7 @@
)

from mentat.errors import MentatError
-from mentat.llm_api_handler import conversation_tokens, count_tokens, model_context_size
+from mentat.llm_api_handler import count_tokens, model_context_size, prompt_tokens
from mentat.session_context import SESSION_CONTEXT
from mentat.transcripts import ModelMessage, TranscriptMessage, UserMessage
from mentat.utils import add_newline
@@ -75,7 +75,7 @@ async def display_token_count(self):
content=await code_context.get_code_message("", max_tokens=0),
)
]
-tokens = conversation_tokens(
+tokens = prompt_tokens(
messages,
config.model,
)
@@ -187,7 +187,7 @@ async def _stream_model_response(

start_time = default_timer()

-num_prompt_tokens = conversation_tokens(messages, config.model)
+num_prompt_tokens = prompt_tokens(messages, config.model)
context_size = model_context_size(config.model)
if context_size:
if num_prompt_tokens > context_size - config.token_buffer:
@@ -205,7 +205,7 @@
)
response = await llm_api_handler.call_llm_api(
messages,
-config.model,
+config.model
stream=True,
response_format=parser.response_format(),
)
@@ -237,7 +237,7 @@ async def get_model_response(self) -> list[FileEdit]:
messages_snapshot = self.get_messages()

# Rebuild code context with active code and available tokens
-tokens = conversation_tokens(messages_snapshot, config.model)
+tokens = prompt_tokens(messages_snapshot, config.model)

loading_multiplier = 1.0 if config.auto_context else 0.0
try:
Expand Down Expand Up @@ -275,7 +275,7 @@ async def get_model_response(self) -> list[FileEdit]:
cost_tracker.display_api_call_stats(
num_prompt_tokens,
count_tokens(
-parsed_llm_response.full_response, config.model, full_message=True
+parsed_llm_response.full_response, config.model, full_message=False
),
config.model,
time_elapsed,
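The _stream_model_response hunks above count the prompt with prompt_tokens and refuse to call the API when it cannot fit in the model's context window. A minimal sketch of that guard, reusing the prompt_tokens and model_context_size helpers imported above; the check_prompt_fits wrapper and the ValueError are illustrative stand-ins, not code from this commit:

from mentat.llm_api_handler import model_context_size, prompt_tokens


def check_prompt_fits(messages, model: str, token_buffer: int) -> int:
    """Return the prompt's token count, refusing prompts that cannot fit in the context window."""
    num_prompt_tokens = prompt_tokens(messages, model)
    context_size = model_context_size(model)
    if context_size is not None and num_prompt_tokens > context_size - token_buffer:
        # The real code surfaces this to the user; a plain exception keeps the sketch small.
        raise ValueError(
            f"Prompt is {num_prompt_tokens} tokens, but {model} has a context size of"
            f" {context_size} and a reserved buffer of {token_buffer}."
        )
    return num_prompt_tokens
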
4 changes: 3 additions & 1 deletion mentat/embeddings.py
@@ -128,7 +128,9 @@ async def get_feature_similarity_scores(
num_prompt_tokens = 0
if not database.exists(prompt_checksum):
items_to_embed[prompt_checksum] = prompt
-items_to_embed_tokens[prompt_checksum] = count_tokens(prompt, embedding_model)
+items_to_embed_tokens[prompt_checksum] = count_tokens(
+    prompt, embedding_model, False
+)
for feature, checksum, token in zip(features, checksums, tokens):
if token > max_model_tokens:
continue
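The embeddings change now spells out full_message=False, since strings sent to the embeddings endpoint are not wrapped in chat-message scaffolding and should not pick up the extra 4 tokens. A small usage sketch under that assumption (the model names are examples only):

from mentat.llm_api_handler import count_tokens

text = "def add(a, b):\n    return a + b\n"
# Raw embedding input: no chat-message overhead.
embedding_tokens = count_tokens(text, "text-embedding-ada-002", False)
# The same text counted as one chat message in a prompt: expected to be
# 4 tokens larger, per the count_tokens docstring below.
chat_message_tokens = count_tokens(text, "gpt-4", True)
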
18 changes: 10 additions & 8 deletions mentat/llm_api_handler.py
@@ -46,11 +46,12 @@ def chunk_to_lines(chunk: ChatCompletionChunk) -> list[str]:
return ("" if content is None else content).splitlines(keepends=True)


-def count_tokens(message: str, model: str, full_message: bool = False) -> int:
+def count_tokens(message: str, model: str, full_message: bool) -> int:
"""
-Calculates the tokens in this message. Will NOT be accurate for a full conversation!
-Use conversation_tokens to get the exact amount of tokens in a conversation.
-If full_message is true, will include the extra 4 tokens used in a chat completion by this message.
+Calculates the tokens in this message. Will NOT be accurate for a full prompt!
+Use prompt_tokens to get the exact amount of tokens for a prompt.
+If full_message is true, will include the extra 4 tokens used in a chat completion by this message
+if this message is part of a prompt. The majority of the time, you'll want full_message to be true.
"""
try:
encoding = tiktoken.encoding_for_model(model)
@@ -61,9 +61,9 @@ def count_tokens(message: str, model: str, full_message: bool = False) -> int:
)


-def conversation_tokens(messages: list[ChatCompletionMessageParam], model: str):
+def prompt_tokens(messages: list[ChatCompletionMessageParam], model: str):
"""
-Returns the number of tokens used by a full conversation.
+Returns the number of tokens used by a prompt if it was sent to OpenAI for a chat completion.
Adapted from https://platform.openai.com/docs/guides/text-generation/managing-tokens
"""
try:
@@ -73,7 +73,8 @@ def conversation_tokens(messages: list[ChatCompletionMessageParam], model: str):

num_tokens = 0
for message in messages:
-# every message follows <im_start>{role/name}\n{content}<im_end>\n
+# every message follows <|start|>{role/name}\n{content}<|end|>\n
+# this has 5 tokens (start token, role, \n, end token, \n), but we count the role token later
num_tokens += 4
for key, value in message.items():
if isinstance(value, list) and key == "content":
@@ -98,7 +100,7 @@ def conversation_tokens(messages: list[ChatCompletionMessageParam], model: str):
num_tokens += len(encoding.encode(value))
if key == "name": # if there's a name, the role is omitted
num_tokens -= 1 # role is always required and always 1 token
-num_tokens += 2  # every reply is primed with <im_start>assistant
+num_tokens += 2  # every reply is primed with <|start|>assistant
return num_tokens


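Reading the new counting rule end to end: each message costs 4 scaffolding tokens plus the encoded role and content, and the whole prompt gets 2 more tokens to prime the assistant's reply. A worked sketch of the 24-token figure asserted in the test below, assuming cl100k_base splits the strings as annotated (the token splits are an assumption, not part of this diff):

from mentat.llm_api_handler import prompt_tokens

messages = [
    # 4 scaffolding + 1 ("user") + 2 ("Hello", "!") = 7
    {"role": "user", "content": "Hello!"},
    # 4 scaffolding + 1 ("assistant") + 10 content tokens = 15
    {"role": "assistant", "content": "Hi there! How can I help you today?"},
]
# 7 + 15 + 2 (reply priming) = 24
assert prompt_tokens(messages, "gpt-4-vision-preview") == 24
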
8 changes: 4 additions & 4 deletions tests/llm_api_handler_test.py
@@ -3,17 +3,17 @@

from PIL import Image

-from mentat.llm_api_handler import conversation_tokens
+from mentat.llm_api_handler import prompt_tokens


-def test_conversation_tokens():
+def test_prompt_tokens():
messages = [
{"role": "user", "content": "Hello!"},
{"role": "assistant", "content": "Hi there! How can I help you today?"},
]
model = "gpt-4-vision-preview"

-assert conversation_tokens(messages, model) == 24
+assert prompt_tokens(messages, model) == 24

# An image that must be scaled twice and then fits in 6 512x512 panels
img = Image.new("RGB", (768 * 4, 1050 * 4), color="red")
@@ -30,4 +30,4 @@ def test_conversation_tokens():
}
)

-assert conversation_tokens(messages, model) == 24 + 6 * 170 + 85 + 5
+assert prompt_tokens(messages, model) == 24 + 6 * 170 + 85 + 5
