
Commit dd7da4e

Merge pull request #265 from MLSysOps/feat/local-rag
[MRG] Code RAG for Chatbot
huangyz0918 authored Nov 19, 2024
2 parents ce39ddf + 0e1ca7e commit dd7da4e
Showing 13 changed files with 714 additions and 111 deletions.
37 changes: 25 additions & 12 deletions mle/agents/chat.py
@@ -1,14 +1,12 @@
import sys
import json
from rich.console import Console

from mle.function import *
from mle.utils import get_config, print_in_box, WorkflowCache
from mle.utils import get_config, WorkflowCache


class ChatAgent:

def __init__(self, model, working_dir='.', console=None):
def __init__(self, model, memory=None, working_dir='.', console=None):
"""
ChatAgent assists users with planning and debugging ML projects.
@@ -18,7 +16,10 @@ def __init__(self, model, working_dir='.', console=None):
config_data = get_config()

self.model = model
self.memory = memory
self.chat_history = []
if working_dir == '.':
working_dir = os.getcwd()
self.working_dir = working_dir
self.cache = WorkflowCache(working_dir, 'baseline')

@@ -56,7 +57,9 @@ def __init__(self, model, working_dir='.', console=None):
schema_search_papers_with_code,
schema_web_search,
schema_execute_command,
schema_preview_csv_data
schema_preview_csv_data,
schema_unzip_data,
schema_preview_zip_structure
]

if config_data.get('search_key'):
@@ -69,9 +72,9 @@ def __init__(self, model, working_dir='.', console=None):
advisor_report = self.cache.resume_variable("advisor_report")
self.sys_prompt += f"""
The overall project information: \n
{'Dataset: ' + dataset if dataset else ''} \n
{'Requirement: ' + ml_requirement if ml_requirement else ''} \n
{'Advisor: ' + advisor_report if advisor_report else ''} \n
{'Dataset: ' + str(dataset) if dataset else ''} \n
{'Requirement: ' + str(ml_requirement) if ml_requirement else ''} \n
{'Advisor: ' + str(advisor_report) if advisor_report else ''} \n
"""

self.chat_history.append({"role": 'system', "content": self.sys_prompt})
@@ -84,9 +87,8 @@ def greet(self):
Returns:
str: The generated greeting message.
"""
system_prompt = """
You are a Chatbot designed to collaborate with users on planning and debugging ML projects.
Your goal is to provide concise and friendly greetings within 50 words, including:
greet_prompt = """
Can you provide a concise and friendly greeting within 50 words, including:
1. Infer the project's purpose or objective.
2. Summarize the previous conversations if any exist.
3. Offer a brief overview of the assistance and support you can provide to the user, such as:
@@ -96,7 +98,7 @@ def greet(self):
- Providing resources and references for further learning.
Make sure your greeting is inviting and sets a positive tone for collaboration.
"""
self.chat_history.append({"role": "system", "content": system_prompt})
self.chat_history.append({"role": "user", "content": greet_prompt})
greets = self.model.query(
self.chat_history,
function_call='auto',
@@ -116,7 +118,18 @@ def chat(self, user_prompt):
user_prompt: the user prompt.
"""
text = ''
if self.memory:
table_name = 'mle_chat_' + self.working_dir.split('/')[-1]
query = self.memory.query([user_prompt], table_name=table_name, n_results=1) # TODO: adjust the n_results.
user_prompt += f"""
\nThese reference files and their snippets may be useful for the question:\n\n
"""

for t in query[0]:
snippet, metadata = t.get('text'), t.get('metadata')
user_prompt += f"**File**: {metadata.get('file')}\n**Snippet**: {snippet}\n"
self.chat_history.append({"role": "user", "content": user_prompt})

for content in self.model.stream(
self.chat_history,
function_call='auto',
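Taken together, the chat.py changes retrieve code snippets from the local memory and fold them into the user prompt before the model is queried. A minimal wiring sketch, assuming a LanceDBMemory built as in cli.py below and an LLM wrapper exposing the query()/stream() interface used above (the streaming return of chat() is truncated in this diff, so the final loop is an assumption):

    # Hypothetical wiring; 'model' stands in for whatever mle's model loader returns.
    from mle.agents.chat import ChatAgent
    from mle.utils import LanceDBMemory

    memory = LanceDBMemory('.')              # vector store rooted at the project path
    agent = ChatAgent(model, memory=memory)  # memory=None falls back to plain chat
    print(agent.greet())
    for token in agent.chat("Why is my validation loss flat?"):  # assumed to yield streamed text
        print(token, end='')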
42 changes: 39 additions & 3 deletions mle/cli.py
@@ -7,6 +7,7 @@
import questionary
from pathlib import Path
from rich.console import Console
from rich.progress import Progress, SpinnerColumn, TimeElapsedColumn, TextColumn, BarColumn

import mle
from mle.server import app
@@ -18,8 +19,11 @@
startup_web,
print_in_box,
)
from mle.utils import LanceDBMemory, list_files, read_file
from mle.utils import CodeChunker

console = Console()
memory = LanceDBMemory(os.getcwd())


@click.group()
@@ -127,7 +131,7 @@ def report_local(ctx, path, email, start_date, end_date):
).ask()

return workflow.report_local(os.getcwd(), path, email, start_date=start_date, end_date=end_date)


@cli.command()
@click.option('--model', default=None, help='The model to use for the chat.')
@@ -187,14 +191,46 @@ def kaggle(

@cli.command()
@click.option('--model', default=None, help='The model to use for the chat.')
def chat(model):
@click.option('--build_mem', is_flag=True, help='Build and enable the local memory for the chat.')
def chat(model, build_mem):
"""
chat: start an interactive chat with LLM to work on your ML project.
"""
if not check_config(console):
return

return workflow.chat(os.getcwd(), model)
if build_mem:
working_dir = os.getcwd()
table_name = 'mle_chat_' + working_dir.split('/')[-1]
source_files = list_files(working_dir, ['*.py']) # TODO: support more file types

chunker = CodeChunker(os.path.join(working_dir, '.mle', 'cache'), 'py')
with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
BarColumn(),
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
TimeElapsedColumn(),
console=console,
) as progress:
process_task = progress.add_task("Processing files...", total=len(source_files))

for file_path in source_files:
raw_code = read_file(file_path)
progress.update(
process_task,
advance=1,
description=f"Adding {os.path.basename(file_path)} to memory..."
)

chunks = chunker.chunk(raw_code, token_limit=100)
memory.add(
texts=list(chunks.values()),
table_name=table_name,
metadata=[{'file': file_path, 'chunk_key': k} for k, _ in chunks.items()]
)

return workflow.chat(os.getcwd(), model=model, memory=memory)


@cli.command()
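The indexing path above is opt-in. A hypothetical invocation of the new flag (name taken from this diff):

    $ mle chat --build_mem

On first run this walks every *.py file under the working directory, chunks each one with CodeChunker, and stores the embeddings in a per-project LanceDB table before handing off to workflow.chat.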
1 change: 1 addition & 0 deletions mle/utils/__init__.py
@@ -2,3 +2,4 @@
from .cache import *
from .memory import *
from .data import *
from .chunk import *
130 changes: 130 additions & 0 deletions mle/utils/chunk.py
@@ -0,0 +1,130 @@
# Source modified from https://github.com/CintraAI/code-chunker/blob/main/Chunker.py
import tiktoken
from .parser import CodeParser
from abc import ABC, abstractmethod


def count_tokens(string: str, encoding_name: str) -> int:
encoding = tiktoken.encoding_for_model(encoding_name)
num_tokens = len(encoding.encode(string))
return num_tokens


class Chunker(ABC):
def __init__(self, encoding_name="gpt-4"):
self.encoding_name = encoding_name

@abstractmethod
def chunk(self, content, token_limit):
pass

@abstractmethod
def get_chunk(self, chunked_content, chunk_number):
pass

@staticmethod
def print_chunks(chunks):
for chunk_number, chunk_code in chunks.items():
print(f"Chunk {chunk_number}:")
print("=" * 40)
print(chunk_code)
print("=" * 40)

@staticmethod
def consolidate_chunks_into_file(chunks):
return "\n".join(chunks.values())

@staticmethod
def count_lines(consolidated_chunks):
lines = consolidated_chunks.split("\n")
return len(lines)


class CodeChunker(Chunker):
def __init__(self, cache_dir, file_extension, encoding_name="gpt-4o-mini"):
super().__init__(encoding_name)
self.file_extension = file_extension
self.cache_dir = cache_dir

def chunk(self, code, token_limit) -> dict:
code_parser = CodeParser(self.cache_dir, self.file_extension)
chunks = {}
token_count = 0
lines = code.split("\n")
i = 0
chunk_number = 1
start_line = 0
breakpoints = sorted(code_parser.get_lines_for_points_of_interest(code, self.file_extension))
comments = sorted(code_parser.get_lines_for_comments(code, self.file_extension))
adjusted_breakpoints = []
for bp in breakpoints:
current_line = bp - 1
highest_comment_line = None # Initialize with None to indicate no comment line has been found yet
while current_line in comments:
highest_comment_line = current_line # Update highest comment line found
current_line -= 1 # Move to the previous line

if highest_comment_line: # If a highest comment line exists, add it
adjusted_breakpoints.append(highest_comment_line)
else:
adjusted_breakpoints.append(
bp) # If no comments were found before the breakpoint, add the original breakpoint

breakpoints = sorted(set(adjusted_breakpoints)) # Ensure breakpoints are unique and sorted

while i < len(lines):
line = lines[i]
new_token_count = count_tokens(line, self.encoding_name)
if token_count + new_token_count > token_limit:

# Set the stop line to the last breakpoint before the current line
if i in breakpoints:
stop_line = i
else:
stop_line = max(max([x for x in breakpoints if x < i], default=start_line), start_line)

# If the stop line is the same as the start line, it means we haven't reached a breakpoint yet, and we need to move to the next line to find one
if stop_line == start_line and i not in breakpoints:
token_count += new_token_count
i += 1

# If the stop line equals the start line and we are sitting exactly on it, keep accumulating and advance to find a later breakpoint
elif stop_line == start_line and i == stop_line:
token_count += new_token_count
i += 1

# If the current line is itself a breakpoint, flush anything before it and start the next chunk at this line
elif stop_line == start_line and i in breakpoints:
current_chunk = "\n".join(lines[start_line:stop_line])
if current_chunk.strip(): # If the current chunk is not just whitespace
chunks[chunk_number] = current_chunk # Using chunk_number as key
chunk_number += 1

token_count = 0
start_line = i
i += 1

# If the stop line is different from the start line, it means we're at the end of a block
else:
current_chunk = "\n".join(lines[start_line:stop_line])
if current_chunk.strip():
chunks[chunk_number] = current_chunk # Using chunk_number as key
chunk_number += 1

i = stop_line
token_count = 0
start_line = stop_line
else:
# If the token count is still within the limit, add the line to the current chunk
token_count += new_token_count
i += 1

# Append remaining code, if any, ensuring it's not empty or whitespace
current_chunk_code = "\n".join(lines[start_line:])
if current_chunk_code.strip(): # Checks if the chunk is not just whitespace
chunks[chunk_number] = current_chunk_code # Using chunk_number as key

return chunks

def get_chunk(self, chunked_codebase, chunk_number):
return chunked_codebase[chunk_number]
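A minimal usage sketch for CodeChunker, assuming the same tree-sitter cache directory cli.py passes in and a hypothetical input file:

    from mle.utils.chunk import Chunker, CodeChunker

    chunker = CodeChunker('.mle/cache', 'py')        # cache_dir, file_extension
    with open('mle/cli.py') as f:
        chunks = chunker.chunk(f.read(), token_limit=100)
    Chunker.print_chunks(chunks)                     # dict keyed by 1-based chunk number

Because splits only happen at parser-reported points of interest, a chunk can run past token_limit when no breakpoint is available, rather than cutting mid-function.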
28 changes: 28 additions & 0 deletions mle/utils/data.py
@@ -1,6 +1,34 @@
import re
import os
import json
from typing import Dict, Any


def dict_to_markdown(data: Dict[str, Any], file_path: str) -> None:
"""
Write a dictionary to a markdown file.
:param data: the dictionary to write.
:param file_path: the file path to write the dictionary to.
:return:
"""

def write_item(k, v, indent_level=0):
if isinstance(v, dict):
md_file.write(f"{'##' * (indent_level + 1)} {k}\n")
for sub_key, sub_value in v.items():
write_item(sub_key, sub_value, indent_level + 1)
elif isinstance(v, list):
md_file.write(f"{'##' * (indent_level + 1)} {k}\n")
for item in v:
md_file.write(f"{' ' * indent_level}- {item}\n")
else:
md_file.write(f"{'##' * (indent_level + 1)} {k}\n")
md_file.write(f"{' ' * indent_level}{v}\n")

with open(file_path, 'w') as md_file:
for key, value in data.items():
write_item(key, value)
md_file.write("\n")


def is_markdown_file(file_path):
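A hedged illustration of dict_to_markdown above, with hypothetical data and output path (note that nested headings deepen by two '#' characters per level, as written):

    from mle.utils.data import dict_to_markdown

    dict_to_markdown({'project': {'name': 'demo', 'tags': ['ml', 'rag']}}, 'report.md')
    # report.md then roughly contains:
    # ## project
    # #### name
    #  demo
    # #### tags
    #  - ml
    #  - rag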
19 changes: 8 additions & 11 deletions mle/utils/memory.py
@@ -160,10 +160,7 @@ def reset(self):

class LanceDBMemory:

def __init__(
self,
project_path: str,
):
def __init__(self, project_path: str):
"""
Memory: A base class for memory and external knowledge management.
Args:
@@ -180,11 +177,11 @@ def __init__(
raise NotImplementedError

def add(
self,
texts: List[str],
metadata: Optional[List[Dict]] = None,
table_name: Optional[str] = None,
ids: Optional[List[str]] = None,
self,
texts: List[str],
metadata: Optional[List[Dict]] = None,
table_name: Optional[str] = None,
ids: Optional[List[str]] = None,
) -> List[str]:
"""
Adds a list of text items to the specified memory table in the database.
@@ -200,12 +197,12 @@ def add(
List[str]: A list of IDs associated with the added text items.
"""
if isinstance(texts, str):
texts = (texts, )
texts = (texts,)

if metadata is None:
metadata = [None, ] * len(texts)
elif isinstance(metadata, dict):
metadata = (metadata, )
metadata = (metadata,)
else:
assert len(texts) == len(metadata)

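To close the loop on the memory API reshaped above, a small sketch mirroring the access pattern chat.py uses earlier in this diff (table name and text are hypothetical; embedding configuration is assumed to be handled inside LanceDBMemory):

    memory = LanceDBMemory('.')
    memory.add(
        texts=['def foo():\n    return 42'],
        metadata=[{'file': 'mle/cli.py', 'chunk_key': 1}],
        table_name='mle_chat_demo',
    )
    hits = memory.query(['what does foo return?'], table_name='mle_chat_demo', n_results=1)
    snippet = hits[0][0].get('text')   # same shape chat.py iterates over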
(The remaining changed files in this commit are not shown in this view.)
