Skip to content

Commit

Permalink
Merge pull request #271 from leeeizhang/lei/memory-cli
Browse files Browse the repository at this point in the history
[MRG] add mle memory crud api
  • Loading branch information
huangyz0918 authored Dec 2, 2024
2 parents b9765e7 + bd5fe58 commit 766e505
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 15 deletions.
67 changes: 67 additions & 0 deletions mle/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,3 +353,70 @@ def integrate(reset):
"token": pickle.dumps(token, fix_imports=False),
}
write_config(config)


@cli.command()
@click.option('--add', default=None, help='Add files or directories into the local memory.')
@click.option('--rm', default=None, help='Remove files or directories into the local memory.')
@click.option('--update', default=None, help='Update files or directories into the local memory.')
def memory(add, rm, update):
memory = LanceDBMemory(os.getcwd())
path = add or rm or update
if path is None:
return

source_files = []
if os.path.isdir(path):
source_files = list_files(path, ['*.py'])
else:
source_files = [path]

working_dir = os.getcwd()
table_name = 'mle_chat_' + working_dir.split('/')[-1]
chunker = CodeChunker(os.path.join(working_dir, '.mle', 'cache'), 'py')
with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
BarColumn(),
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
TimeElapsedColumn(),
console=console,
) as progress:
process_task = progress.add_task("Processing files...", total=len(source_files))

for file_path in source_files:
raw_code = read_file(file_path)
progress.update(
process_task,
advance=1,
description=f"Process {os.path.basename(file_path)} for memory..."
)

if add:
# add file into memory
chunks = chunker.chunk(raw_code, token_limit=100)
memory.add(
texts=list(chunks.values()),
table_name=table_name,
metadata=[{'file': file_path, 'chunk_key': k} for k, _ in chunks.items()]
)
elif rm:
# remove file from memory
memory.delete_by_metadata(
key="file",
value=file_path,
table_name=table_name,
)
elif update:
# update file into memory
chunks = chunker.chunk(raw_code, token_limit=100)
memory.delete_by_metadata(
key="file",
value=file_path,
table_name=table_name,
)
memory.add(
texts=list(chunks.values()),
table_name=table_name,
metadata=[{'file': file_path, 'chunk_key': k} for k, _ in chunks.items()]
)
61 changes: 46 additions & 15 deletions mle/utils/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,19 @@ def __init__(self, project_path: str):
else:
raise NotImplementedError

def _open_table(self, table_name: str = None):
"""
Open a LanceDB table by table name. (Return None if not exists)
Args:
table_name (Optional[str]): The name of the table. Defaults to self.table_name.
"""
table_name = table_name or self.table_name
try:
table = self.client.open_table(table_name)
except FileNotFoundError:
return None
return table

def add(
self,
texts: List[str],
Expand Down Expand Up @@ -73,7 +86,7 @@ def add(
table = self.client.create_table(table_name, data=data)
table.create_fts_index("id")
else:
self.client.open_table(table_name).add(data=data)
self._open_table(table_name).add(data=data)

return ids

Expand All @@ -90,8 +103,10 @@ def query(self, query_texts: List[str], table_name: Optional[str] = None, n_resu
List[List[dict]]: A list of results for each query text, each result being a dictionary with
keys such as "vector", "text", and "id".
"""
table_name = table_name or self.table_name
table = self.client.open_table(table_name)
table = self._open_table(table_name)
if table is None:
return []

query_embeds = self.text_embedding.compute_source_embeddings(query_texts)

results = [table.search(query).limit(n_results).to_list() for query in query_embeds]
Expand All @@ -107,8 +122,10 @@ def list_all_keys(self, table_name: Optional[str] = None):
Returns:
List[str]: A list of all IDs in the table.
"""
table_name = table_name or self.table_name
table = self.client.open_table(table_name)
table = self._open_table(table_name)
if table is None:
return []

return [item["id"] for item in table.search(query_type="fts").to_list()]

def get(self, record_id: str, table_name: Optional[str] = None):
Expand All @@ -122,8 +139,10 @@ def get(self, record_id: str, table_name: Optional[str] = None):
Returns:
List[dict]: A list containing the matching record, or an empty list if not found.
"""
table_name = table_name or self.table_name
table = self.client.open_table(table_name)
table = self._open_table(table_name)
if table is None:
return []

return table.search(query_type="fts") \
.where(f"id = '{record_id}'") \
.limit(1).to_list()
Expand All @@ -141,8 +160,10 @@ def get_by_metadata(self, key: str, value: str, table_name: Optional[str] = None
Returns:
List[dict]: A list of records matching the metadata criteria.
"""
table_name = table_name or self.table_name
table = self.client.open_table(table_name)
table = self._open_table(table_name)
if table is None:
return []

return table.search(query_type="fts") \
.where(f"metadata.{key} = '{value}'") \
.limit(n_results).to_list()
Expand All @@ -158,8 +179,10 @@ def delete(self, record_id: str, table_name: Optional[str] = None) -> bool:
Returns:
bool: True if the deletion was successful, False otherwise.
"""
table_name = table_name or self.table_name
table = self.client.open_table(table_name)
table = self._open_table(table_name)
if table is None:
return True

return table.delete(f"id = '{record_id}'")

def delete_by_metadata(self, key: str, value: str, table_name: Optional[str] = None):
Expand All @@ -174,8 +197,10 @@ def delete_by_metadata(self, key: str, value: str, table_name: Optional[str] = N
Returns:
bool: True if deletion was successful, False otherwise.
"""
table_name = table_name or self.table_name
table = self.client.open_table(table_name)
table = self._open_table(table_name)
if table is None:
return True

return table.delete(f"metadata.{key} = '{value}'")

def drop(self, table_name: Optional[str] = None) -> bool:
Expand All @@ -189,6 +214,10 @@ def drop(self, table_name: Optional[str] = None) -> bool:
bool: True if the table was successfully dropped, False otherwise.
"""
table_name = table_name or self.table_name
table = self._open_table(table_name)
if table is None:
return True

return self.client.drop_table(table_name)

def count(self, table_name: Optional[str] = None) -> int:
Expand All @@ -201,8 +230,10 @@ def count(self, table_name: Optional[str] = None) -> int:
Returns:
int: The number of records in the table.
"""
table_name = table_name or self.table_name
table = self.client.open_table(table_name)
table = self._open_table(table_name)
if table is None:
return 0

return table.count_rows()

def reset(self) -> None:
Expand Down

0 comments on commit 766e505

Please sign in to comment.