This repository has been archived by the owner on Jan 7, 2025. It is now read-only.
Update embeddings batching based on new understanding of limits (#423)
granawkins authored Dec 23, 2023
1 parent 3195310 commit 9498bdd
Showing 2 changed files with 39 additions and 72 deletions.
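In short: the previous implementation packed inputs into token-budgeted batches with First Fit Decreasing and fanned them out as up to 10 concurrent requests behind a semaphore; this commit replaces that with sequential, fixed-size batches of 2048 inputs per embeddings call (presumably the per-request input cap that the title's "new understanding of limits" refers to). A minimal sketch of the new flow, with a stubbed `call_embedding_api` standing in for the handler method used in the diff:

    import asyncio

    EMBEDDINGS_API_BATCH_SIZE = 2048  # inputs per embeddings request

    async def call_embedding_api(batch: list[str], model: str) -> list[list[float]]:
        # Stub for llm_api_handler.call_embedding_api: one vector per input.
        return [[0.0, 0.0, 0.0] for _ in batch]

    async def embed_all(texts: list[str], model: str) -> list[list[float]]:
        # Same arithmetic as the new loop in the diff: consecutive windows of
        # EMBEDDINGS_API_BATCH_SIZE items, one request per window, in order.
        embeddings: list[list[float]] = []
        n_batches = len(texts) // EMBEDDINGS_API_BATCH_SIZE + 1
        for batch in range(n_batches):
            i_start = batch * EMBEDDINGS_API_BATCH_SIZE
            i_end = (batch + 1) * EMBEDDINGS_API_BATCH_SIZE
            embeddings += await call_embedding_api(texts[i_start:i_end], model)
        return embeddings

    print(len(asyncio.run(embed_all(["text"] * 3000, "some-embedding-model"))))  # 3000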
99 changes: 37 additions & 62 deletions mentat/embeddings.py
@@ -1,4 +1,3 @@
-import asyncio
 import json
 import os
 import sqlite3
@@ -18,7 +17,7 @@
 from mentat.session_input import ask_yes_no
 from mentat.utils import mentat_dir_path, sha256
 
-MAX_SIMULTANEOUS_REQUESTS = 10
+EMBEDDINGS_API_BATCH_SIZE = 2048
 
 
 class EmbeddingsDatabase:
@@ -67,35 +66,6 @@ def __del__(self):
 database = EmbeddingsDatabase()
 
 
-def _batch_ffd(data: dict[str, int], batch_size: int) -> list[list[str]]:
-    """Batch files using the First Fit Decreasing algorithm."""
-    # Sort the data by the length of the strings in descending order
-    sorted_data = sorted(data.items(), key=lambda x: x[1], reverse=True)
-    batches = list[list[str]]()
-    for key, value in sorted_data:
-        # Place each item in the first batch that it fits in
-        placed = False
-        for batch in batches:
-            if sum(data[k] for k in batch) + value <= batch_size:
-                batch.append(key)
-                placed = True
-                break
-        if not placed:
-            batches.append([key])
-    return batches
-
-
-embedding_api_semaphore = asyncio.Semaphore(MAX_SIMULTANEOUS_REQUESTS)
-
-
-async def _fetch_embeddings(model: str, batch: list[str]):
-    ctx = SESSION_CONTEXT.get()
-
-    async with embedding_api_semaphore:
-        response = await ctx.llm_api_handler.call_embedding_api(batch, model)
-        return response
-
-
 def _cosine_similarity(v1: list[float], v2: list[float]) -> float:
     """Calculate the cosine similarity between two vectors."""
     dot_product = np.dot(v1, v2)
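The `_batch_ffd` helper removed above packed items into token-budgeted batches by First Fit Decreasing: sort by size descending, put each item into the first batch with room, and open a new batch when none fits. Its behavior is pinned down by the test deleted below in tests/embeddings_test.py, reproduced here as a worked example against the definition above:

    # Budget of 6: "b" (5) opens batch 1 and nothing else fits; "a" (4)
    # opens batch 2 and "d" (2) fills it to exactly 6; "c" (3) fits in
    # neither, so it opens batch 3.
    assert _batch_ffd({"a": 4, "b": 5, "c": 3, "d": 2}, batch_size=6) == [
        ["b"],
        ["a", "d"],
        ["c"],
    ]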
@@ -115,33 +85,36 @@ async def get_feature_similarity_scores(
     stream = session_context.stream
     cost_tracker = session_context.cost_tracker
     embedding_model = session_context.config.embedding_model
+    llm_api_handler = session_context.llm_api_handler
 
     max_model_tokens = model_context_size(embedding_model)
     if max_model_tokens is None:
         raise MentatError(f"Missing model context size for {embedding_model}.")
+    prompt_tokens = count_tokens(prompt, embedding_model, False)
+    if prompt_tokens > max_model_tokens:
+        stream.send(
+            f"Warning: Prompt contains {prompt_tokens} tokens, but the model"
+            f" can only handle {max_model_tokens} tokens. Ignoring embeddings."
+        )
+        return [0.0 for _ in features]
 
+    # Keep things in the same order
+    prompt_checksum = sha256(prompt)
     checksums: list[str] = [f.get_checksum() for f in features]
     tokens: list[int] = await count_feature_tokens(features, embedding_model)
 
-    # Make a checksum:content dict of all items that need to be embedded
-    items_to_embed = dict[str, str]()
-    items_to_embed_tokens = dict[str, int]()
-    prompt_checksum = sha256(prompt)
-    num_prompt_tokens = 0
+    embed_texts = list[str]()
+    embed_checksums = list[str]()
+    embed_tokens = list[int]()
     if not database.exists(prompt_checksum):
-        items_to_embed[prompt_checksum] = prompt
-        items_to_embed_tokens[prompt_checksum] = count_tokens(
-            prompt, embedding_model, False
-        )
+        embed_texts.append(prompt)
+        embed_checksums.append(prompt_checksum)
+        embed_tokens.append(prompt_tokens)
     for feature, checksum, token in zip(features, checksums, tokens):
-        if token > max_model_tokens:
-            continue
         if not database.exists(checksum):
-            feature_content = feature.get_code_message()
-            # Remove line numbering
-            items_to_embed[checksum] = "\n".join(feature_content)
-            items_to_embed_tokens[checksum] = token
-            num_prompt_tokens += token
+            embed_texts.append("\n".join(feature.get_code_message()))
+            embed_checksums.append(checksum)
+            embed_tokens.append(token)
 
     # If it costs more than $1, get confirmation from user.
     cost = model_price_per_1000_tokens(embedding_model)
@@ -151,40 +124,42 @@ async def get_feature_similarity_scores(
             color="light_yellow",
         )
     else:
-        expected_cost = (num_prompt_tokens / 1000) * cost[0]
+        expected_cost = (sum(embed_tokens) / 1000) * cost[0]
         if expected_cost > 1.0:
             stream.send(
-                f"Embedding {num_prompt_tokens} tokens will cost ${cost[0]:.2f}."
+                f"Embedding {sum(embed_tokens)} tokens will cost ${cost[0]:.2f}."
                 " Continue anyway?"
             )
             if not await ask_yes_no(default_yes=True):
                 stream.send("Ignoring embeddings for now.")
                 return [0.0 for _ in checksums]
 
     # Fetch embeddings in batches
-    batches = _batch_ffd(items_to_embed_tokens, max_model_tokens)
-
-    tasks = list[tuple[asyncio.Task[list[list[float]]], list[str]]]()
-    for batch in batches:
-        batch_content = [items_to_embed[k] for k in batch]
-        task = asyncio.create_task(_fetch_embeddings(embedding_model, batch_content))
-        tasks.append((task, batch))
-    for i, (task, batch) in enumerate(tasks):
+    n_batches = len(embed_texts) // EMBEDDINGS_API_BATCH_SIZE + 1
+    for batch in range(n_batches):
         if loading_multiplier:
             stream.send(
-                f"Fetching embeddings, batch {i+1}/{len(batches)}",
+                f"Fetching embeddings, batch {batch+1}/{n_batches}",
                 channel="loading",
-                progress=(100 / len(batches)) * loading_multiplier,
+                progress=(100 / n_batches) * loading_multiplier,
             )
         start_time = default_timer()
-        response = await task
+        i_start, i_end = (
+            batch * EMBEDDINGS_API_BATCH_SIZE,
+            (batch + 1) * EMBEDDINGS_API_BATCH_SIZE,
+        )
+        _texts = embed_texts[i_start:i_end]
+        _checksums = embed_checksums[i_start:i_end]
+        _tokens = embed_tokens[i_start:i_end]
+
+        response = await llm_api_handler.call_embedding_api(_texts, embedding_model)
         cost_tracker.log_api_call_stats(
-            sum(items_to_embed_tokens[k] for k in batch),
+            sum(_tokens),
            0,
            embedding_model,
            start_time - default_timer(),
        )
-        database.set({k: v for k, v in zip(batch, response)})
+        database.set({k: v for k, v in zip(_checksums, response)})
 
     # Calculate similarity score for each feature
     prompt_embedding = database.get([prompt_checksum])[prompt_checksum]
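Two details of the new loop worth flagging. First, `n_batches = len(embed_texts) // EMBEDDINGS_API_BATCH_SIZE + 1` yields one extra, empty batch whenever the input count is an exact multiple of the batch size (the slice comes back empty, but the empty-list API call is still made); ceiling division avoids that. Second, `start_time - default_timer()` in the retained `log_api_call_stats` context lines reports a negative elapsed time; the operands look swapped, though that predates this commit. A standalone sketch of the slice arithmetic, using ceiling division for comparison:

    import math

    EMBEDDINGS_API_BATCH_SIZE = 2048

    def batch_bounds(n_items: int) -> list[tuple[int, int]]:
        # Ceiling division: 4096 items -> 2 batches, where
        # `n_items // EMBEDDINGS_API_BATCH_SIZE + 1` would give 3 (one empty).
        n_batches = math.ceil(n_items / EMBEDDINGS_API_BATCH_SIZE)
        return [
            (i * EMBEDDINGS_API_BATCH_SIZE, (i + 1) * EMBEDDINGS_API_BATCH_SIZE)
            for i in range(n_batches)
        ]

    print(batch_bounds(5000))  # [(0, 2048), (2048, 4096), (4096, 6144)]
    print(batch_bounds(4096))  # [(0, 2048), (2048, 4096)]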
12 changes: 2 additions & 10 deletions tests/embeddings_test.py
@@ -3,15 +3,7 @@
 import pytest
 
 from mentat.code_feature import CodeFeature
-from mentat.embeddings import _batch_ffd, get_feature_similarity_scores
-
-
-def test_batch_ffd():
-    data = {"a": 4, "b": 5, "c": 3, "d": 2}
-    batch_size = 6
-    result = _batch_ffd(data, batch_size)
-    expected = [["b"], ["a", "d"], ["c"]]
-    assert result == expected
+from mentat.embeddings import get_feature_similarity_scores
 
 
 def _make_code_feature(path, text):
@@ -29,10 +21,10 @@ async def test_get_feature_similarity_scores(mocker, mock_call_embedding_api):
     ]
     mock_call_embedding_api.set_embedding_values(
         [
-            [0.7, 0.7, 0.7],  # The prompt
             [0.4, 0.4, 0.4],
             [0.5, 0.6, 0.7],
             [0.69, 0.7, 0.71],
+            [0.7, 0.7, 0.7],  # The prompt
         ]
     )
     result = await get_feature_similarity_scores(prompt, features)
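For intuition about the mock vectors: cosine similarity compares direction only, so any positive scaling of the prompt vector scores exactly 1.0. Assuming the elided body of `_cosine_similarity` normalizes the dot product by both magnitudes (the standard formula, consistent with the `np.dot` line shown in the diff above), the scores against the prompt `[0.7, 0.7, 0.7]` work out as follows:

    import numpy as np

    def cosine_similarity(v1: list[float], v2: list[float]) -> float:
        # Standard cosine similarity: dot(v1, v2) / (|v1| * |v2|).
        return float(np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2)))

    prompt = [0.7, 0.7, 0.7]
    print(cosine_similarity(prompt, [0.4, 0.4, 0.4]))    # 1.0  (same direction)
    print(cosine_similarity(prompt, [0.69, 0.7, 0.71]))  # ~0.99993
    print(cosine_similarity(prompt, [0.5, 0.6, 0.7]))    # ~0.9909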