Skip to content

Commit

Permalink
fix: embeddings that are too large
Browse files Browse the repository at this point in the history
  • Loading branch information
tnunamak committed Mar 7, 2024
1 parent cee0022 commit 313a2df
Show file tree
Hide file tree
Showing 7 changed files with 80 additions and 31 deletions.
6 changes: 5 additions & 1 deletion selfie/api/index_documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@

router = APIRouter()

from selfie.config import get_app_config

config = get_app_config()

@router.get("/index_documents")
async def get_documents(offset: int = 0, limit: int = 10):
Expand Down Expand Up @@ -78,7 +81,8 @@ async def load_data(request: DataLoaderRequest):
print(documents)

text_parser = SentenceSplitter(
chunk_size=1024,
chunk_size=config.embedding_chunk_size,
chunk_overlap=config.embedding_chunk_overlap,
# separator=" ",
)

Expand Down
2 changes: 2 additions & 0 deletions selfie/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ class AppConfig(BaseModel):
local_gpu_model: str = Field(default='TheBloke/Mistral-7B-OpenOrca-GPTQ', description="Local GPU model")
local_functionary_model: str = Field(default="meetkai/functionary-7b-v2-GGUF/functionary-7b-v2.q4_0.gguf", description="Local functionary model")
hosted_model: str = Field(default="openai/gpt-3.5-turbo", description="Hosted model")
embedding_chunk_size: int = Field(default=512, description="Embedding chunk size")
embedding_chunk_overlap: int = Field(default=50, description="Embedding chunk overlap")

@property
def base_url(self):
Expand Down
10 changes: 9 additions & 1 deletion selfie/connectors/text_files/connector.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from abc import ABC
from typing import Any, List

from selfie.config import get_app_config


from llama_index.core.node_parser import SentenceSplitter

from selfie.connectors.base_connector import BaseConnector
Expand All @@ -9,6 +12,8 @@
from selfie.types.documents import DocumentDTO
from selfie.utils import data_uri_to_dict

config = get_app_config()


class TextFilesConfiguration(BaseModel):
files: List[str]
Expand Down Expand Up @@ -46,5 +51,8 @@ def transform_for_embedding(self, configuration: dict[str, Any], documents: List
source_document_id=document.id,
)
for document in documents
for text_chunk in SentenceSplitter(chunk_size=1024).split_text(document.content)
for text_chunk in SentenceSplitter(
chunk_size=config.embedding_chunk_size,
chunk_overlap=config.embedding_chunk_overlap,
).split_text(document.content)
]
25 changes: 13 additions & 12 deletions selfie/data_generators/chat_training_data.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
#!/usr/bin/env python3

from typing import List, Dict
from typing import List, Dict, Callable
from enum import Enum
import os
import time
import json
import random
import argparse
import logging
from itertools import groupby

from enum import Enum

from selfie.parsers.chat import ChatFileParser, Parser
from selfie.types.share_gpt import ShareGPTMessage

logger = logging.getLogger(__name__)


class Strategy(Enum):
BUNDLE = 'bundle'
Expand Down Expand Up @@ -78,26 +79,26 @@ def extract_message_bundles(conversations: List[ShareGPTMessage]):
return message_bundles

@staticmethod
def group_messages_into_chunks(conversations: List[ShareGPTMessage], overlap: int = 0, max_messages: int = 3, max_characters: int = 0) -> List[List[ShareGPTMessage]]:
def group_messages_into_chunks(conversations: List[ShareGPTMessage], tokenizer: Callable, overlap: int = 0, max_messages: int = 3, max_tokens: int = 0) -> List[List[ShareGPTMessage]]:
chunks = []
index = 0
while index < len(conversations):
end_index = index + max_messages
chunk = conversations[index:end_index]

# If there's a max characters limit, adjust the chunk to not exceed it
if max_characters > 0:
characters_count = sum(len(msg.value.split()) for msg in chunk)
while characters_count > max_characters and len(chunk) > 0:
chunk.pop() # Remove the last message
characters_count = sum(len(msg.value.split()) for msg in chunk)
if max_tokens > 0:
tokens_count = sum(len(tokenizer(msg.value)) for msg in chunk)
while tokens_count > max_tokens and len(chunk) > 0:
if len(chunk) == 1:
logger.warning(f"Warning: A single message exceeds the max tokens limit ({max_tokens}).")
chunk.pop()
tokens_count = sum(len(tokenizer(msg.value)) for msg in chunk)

chunks.append(chunk)
index += max_messages - overlap

return chunks


@staticmethod
def generate_sharegpt_jsonl_line(messages: List[ShareGPTMessage]) -> str:
"""
Expand Down
5 changes: 4 additions & 1 deletion selfie/database/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,10 @@ def _map_selfie_documents_to_index_documents(selfie_document: DocumentDTO):
timestamp=DataManager._extract_timestamp(selfie_document),
source_document_id=selfie_document.id,
)
for text_chunk in SentenceSplitter(chunk_size=1024).split_text(selfie_document.content)
for text_chunk in SentenceSplitter(
chunk_size=config.embedding_chunks_size,
chunk_overlap=config.embedding_chunk_overlap
).split_text(selfie_document.content)
]

@staticmethod
Expand Down
60 changes: 45 additions & 15 deletions selfie/embeddings/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
from typing import Optional, List, Dict, Any, Coroutine, Callable

import humanize
import logging
import tiktoken
from llama_index.core.node_parser import SentenceSplitter

from selfie.config import get_app_config
from selfie.data_generators.chat_training_data import (
Expand All @@ -18,7 +21,6 @@
from selfie.embeddings.recency_scorer import RecencyScorer
from selfie.embeddings.relevance_scorer import RelevanceScorer
from txtai.embeddings import Embeddings
import logging

from txtai.pipeline import LLM

Expand All @@ -28,12 +30,25 @@

config = get_app_config()

llm = LLM(
verbose=config.verbose,
path=config.local_model,
method="llama.cpp",
n_ctx=8192,
n_gpu_layers=-1 if config.gpu else 0,

def get_default_completion():
return LLM(
verbose=config.verbose,
path=config.local_model,
method="llama.cpp",
n_ctx=8192,
n_gpu_layers=-1 if config.gpu else 0,
)


# TODO: Probably a minor issue, so hard-coding the tokenizer for now:
# 1. The default tokenizer should probably be based on the user's default/configured model
# 2. If the user changes their default model, already-indexed documents could be larger than max_embedding_size_tokens
tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo").encode
splitter = SentenceSplitter(
tokenizer=tokenizer,
chunk_size=config.embedding_chunk_size,
chunk_overlap=config.embedding_chunk_overlap
)


Expand All @@ -51,11 +66,7 @@ def __init__(self, character_name, storage_path: str = config.embeddings_storage
self.storage_path = os.path.join(storage_path, "index")
os.makedirs(storage_path, exist_ok=True)

async def completion_async(prompt):
return llm(prompt)

self.completion = completion or completion_async

self.completion = completion or get_default_completion()
self.character_name = character_name
self.embeddings = Embeddings(
sqlite={"wal": True},
Expand Down Expand Up @@ -149,11 +160,30 @@ async def enqueue_delete(self, ids: List[int]):
def map_share_gpt_data(
conversation: List[ShareGPTMessage], source: str = "Unknown", source_document_id: int = None
) -> List[EmbeddingDocumentModel]:
chunks = ChatTrainingDataGenerator.group_messages_into_chunks(
conversation, overlap=1, max_messages=8, max_characters=0
conversation_with_chunked_messages = [
ShareGPTMessage(**{
"from": msg.from_user,
"value": chunk,
"timestamp": msg.timestamp,
})
for msg in conversation
for chunk in (
splitter.split_text(msg.value)
if len(tokenizer(msg.value)) > config.embedding_chunk_size
else [msg.value]
)
]

message_chunks = ChatTrainingDataGenerator.group_messages_into_chunks(
conversation_with_chunked_messages,
overlap=2,
max_messages=32,
max_tokens=config.embedding_chunk_size,
tokenizer=tokenizer
)

documents = []
for i, conv in enumerate(chunks):
for i, conv in enumerate(message_chunks):
if any("REDACTED" in msg.value for msg in conv):
continue
last_user = ""
Expand Down
3 changes: 2 additions & 1 deletion selfie/parsers/chat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ def __init__(self, blacklist_patterns=None, rewrite_placeholder: str = "REDACTED
self.parser_cache = {}
self.blacklist_patterns = [
re.compile(pattern, re.IGNORECASE)
for pattern in default_blacklist_patterns + (blacklist_patterns or [])
# TODO: Disabling blacklisting until it is more configurable
for pattern in [] #default_blacklist_patterns + (blacklist_patterns or [])
]
self.rewrite_placeholder = rewrite_placeholder

Expand Down

0 comments on commit 313a2df

Please sign in to comment.