Skip to content

Commit

Permalink
Merge pull request #213 from leeeizhang/lei/kaggle-summarize
Browse files Browse the repository at this point in the history
[MRG] add kaggle requirement summary
  • Loading branch information
huangyz0918 authored Sep 15, 2024
2 parents 6737185 + aeb822b commit 06dca25
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 4 deletions.
24 changes: 23 additions & 1 deletion mle/agents/summarizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

class SummaryAgent:

def __init__(self, model, github_repo: str, username: str, github_token: str = None, console=None):
def __init__(self, model, github_repo: str = None, username: str = None, github_token: str = None, console=None):
"""
SummaryAgent: summary the workspace provided by the user.
Expand Down Expand Up @@ -133,3 +133,25 @@ def summarize(self):
summary.update({"user_activity": user_activity})

return summary

def kaggle_request_summarize(self, kaggle_overview):
    """
    Ask the model to condense a Kaggle competition overview into a requirement summary.

    :param kaggle_overview: the overview of the kaggle competition (described as JSON);
        it is converted with ``str()`` and sent as the user message.
    :return: whatever ``self.model.query`` returns for the two-message chat history.
    """
    # The prompt lines are kept flush-left so the literal's content stays unchanged.
    instructions = """
You are a seasoned data science expert in Kaggle competitions. Your task is to summarize the
requirements of a specific Kaggle competition in a clear and concise manner. Please ensure that
your summary includes the following aspects:
1. **Overview**: Describe the competition's objective and significance.
2. **Data**: Detail the datasets, including file types, structure, and key features.
3. **Evaluation**: Explain the judging metric and its calculation.
4. **Submission**: Outline the format and requirements for submissions.
5. **Rules**: Highlight important rules, including data usage, team composition, and resources.
"""
    # System message carries the summarization instructions; the user message
    # carries the stringified overview.
    system_message = {"role": "system", "content": instructions}
    user_message = {"role": "user", "content": str(kaggle_overview)}
    return self.model.query([system_message, user_message])
7 changes: 4 additions & 3 deletions mle/workflow/kaggle.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from rich.console import Console
from mle.model import load_model
from mle.utils import ask_text, WorkflowCache
from mle.agents import CodeAgent, DebugAgent, AdviseAgent, PlanAgent
from mle.agents import CodeAgent, DebugAgent, AdviseAgent, PlanAgent, SummaryAgent
from mle.integration import KaggleIntegration


Expand Down Expand Up @@ -49,7 +49,8 @@ def kaggle(work_dir: str, model=None, kaggle_username=None, kaggle_token=None):
if ml_requirement is None:
with console.status("MLE Agent is fetching the kaggle competition overview..."):
overview = kaggle.get_competition_overview(competition)
ml_requirement = f"Finish a kaggle competition: {overview}"
summary = SummaryAgent(model, console=console)
ml_requirement = summary.kaggle_request_summarize(overview)
ca.store("ml_requirement", ml_requirement)

# advisor agent gives suggestions in a report
Expand All @@ -58,7 +59,7 @@ def kaggle(work_dir: str, model=None, kaggle_username=None, kaggle_token=None):
if advisor_report is None:
advisor = AdviseAgent(model, console)
advisor_report = advisor.interact(
f"[green]User Requirement:[/green] {ml_requirement}\n"
f"[green]Competition Requirement:[/green] {ml_requirement}\n"
f"Dataset is downloaded in path: {dataset}"
)
ca.store("advisor_report", advisor_report)
Expand Down

0 comments on commit 06dca25

Please sign in to comment.