diff --git a/mle/agents/summarizer.py b/mle/agents/summarizer.py index 5b2abe7..551a2eb 100644 --- a/mle/agents/summarizer.py +++ b/mle/agents/summarizer.py @@ -7,7 +7,7 @@ class SummaryAgent: - def __init__(self, model, github_repo: str, username: str, github_token: str = None, console=None): + def __init__(self, model, github_repo: str = None, username: str = None, github_token: str = None, console=None): """ SummaryAgent: summary the workspace provided by the user. @@ -133,3 +133,25 @@ def summarize(self): summary.update({"user_activity": user_activity}) return summary + + def kaggle_request_summarize(self, kaggle_overview): + """ + Summarize the kaggle requests. + :params: kaggle_overview: the overview json of kaggle competition + """ + system_prompt = """ + You are a seasoned data science expert in Kaggle competitions. Your task is to summarize the + requirements of a specific Kaggle competition in a clear and concise manner. Please ensure that + your summary includes the following aspects: + + 1. **Overview**: Describe the competition's objective and significance. + 2. **Data**: Detail the datasets, including file types, structure, and key features. + 3. **Evaluation**: Explain the judging metric and its calculation. + 4. **Submission**: Outline the format and requirements for submissions. + 5. **Rules**: Highlight important rules, including data usage, team composition, and resources. + """ + chat_history = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": str(kaggle_overview)} + ] + return self.model.query(chat_history) diff --git a/mle/workflow/kaggle.py b/mle/workflow/kaggle.py index f342f94..e9e2d6e 100644 --- a/mle/workflow/kaggle.py +++ b/mle/workflow/kaggle.py @@ -6,7 +6,7 @@ from rich.console import Console from mle.model import load_model from mle.utils import ask_text, WorkflowCache -from mle.agents import CodeAgent, DebugAgent, AdviseAgent, PlanAgent +from mle.agents import CodeAgent, DebugAgent, AdviseAgent, PlanAgent, SummaryAgent from mle.integration import KaggleIntegration @@ -49,7 +49,8 @@ def kaggle(work_dir: str, model=None, kaggle_username=None, kaggle_token=None): if ml_requirement is None: with console.status("MLE Agent is fetching the kaggle competition overview..."): overview = kaggle.get_competition_overview(competition) - ml_requirement = f"Finish a kaggle competition: {overview}" + summary = SummaryAgent(model, console=console) + ml_requirement = summary.kaggle_request_summarize(overview) ca.store("ml_requirement", ml_requirement) # advisor agent gives suggestions in a report @@ -58,7 +59,7 @@ def kaggle(work_dir: str, model=None, kaggle_username=None, kaggle_token=None): if advisor_report is None: advisor = AdviseAgent(model, console) advisor_report = advisor.interact( - f"[green]User Requirement:[/green] {ml_requirement}\n" + f"[green]Competition Requirement:[/green] {ml_requirement}\n" f"Dataset is downloaded in path: {dataset}" ) ca.store("advisor_report", advisor_report)