[MRG] ask users to provide absolute path and test deepseek #282

Merged 2 commits on Feb 3, 2025
mle/agents/planner.py (1 addition, 1 deletion)
@@ -94,7 +94,7 @@ def plan(self, user_prompt):
            self.chat_history,
            response_format={"type": "json_object"}
        )

        self.chat_history.append({"role": "assistant", "content": text})

        try:
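For context, the line above appends the model's reply to the chat history so that subsequent queries carry it as context. A minimal sketch of the pattern; fake_query and the messages are hypothetical stand-ins, not code from this repository:

# Minimal sketch of the append-to-history pattern; fake_query and the
# messages are hypothetical stand-ins, not code from this repository.
chat_history = [
    {"role": "system", "content": "You are a planning agent."},
    {"role": "user", "content": "Plan an ML experiment."},
]

def fake_query(history):
    # Stand-in for the real model call.
    return '{"tasks": ["profile data", "train baseline"]}'

text = fake_query(chat_history)
chat_history.append({"role": "assistant", "content": text})
# The next query now includes the assistant's earlier answer as context.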
mle/model/ollama.py (38 additions, 2 deletions)
@@ -1,4 +1,5 @@
import importlib.util
import re

from .common import Model

@@ -27,29 +28,64 @@ def __init__(self, model, host_url=None):
            "More information, please refer to: https://github.com/ollama/ollama-python"
        )

    def _clean_think_tags(self, text):
        """
        Remove content between <think> tags and empty think tags from the text.

        Args:
            text (str): The input text to clean.

        Returns:
            str: The cleaned text with think tags and their content removed.
        """
        # Remove content between <think> tags
        text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL)
        # Remove empty think tags
        text = re.sub(r'<think></think>', '', text)
        return text.strip()

    def _process_message(self, message, **kwargs):
        """
        Process the message before sending to the model.

        Args:
            message: The message to process.
            **kwargs: Additional arguments.

        Returns:
            dict: The processed message.
        """
        if isinstance(message, dict) and 'content' in message:
            message['content'] = self._clean_think_tags(message['content'])
        return message

    def query(self, chat_history, **kwargs):
        """
        Query the LLM model.

        Args:
            chat_history: The context (chat history).
            **kwargs: Additional arguments for the model.

        Returns:
            str: The model's response.
        """
        # Check if 'response_format' in kwargs requests a JSON object
        format = None
        if 'response_format' in kwargs and kwargs['response_format'].get('type') == 'json_object':
            format = 'json'

-        return self.client.chat(model=self.model, messages=chat_history, format=format)['message']['content']
+        response = self.client.chat(model=self.model, messages=chat_history, format=format)
+        return self._clean_think_tags(response['message']['content'])

    def stream(self, chat_history, **kwargs):
        """
        Stream the output from the LLM model.

        Args:
            chat_history: The context (chat history).
            **kwargs: Additional arguments for the model.

        Yields:
            str: Chunks of the model's response.
        """
        for chunk in self.client.chat(
            model=self.model,
            messages=chat_history,
            stream=True
        ):
-            yield chunk['message']['content']
+            yield self._clean_think_tags(chunk['message']['content'])
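To illustrate what the new helper strips from a DeepSeek-R1-style reply, here is a standalone sketch using the same two regexes as the diff above; the sample response is invented:

import re

def clean_think_tags(text):
    # Same two substitutions as _clean_think_tags in the diff above.
    text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL)
    text = re.sub(r'<think></think>', '', text)
    return text.strip()

raw = '<think>\nThe user asked for JSON, so emit an object.\n</think>\n{"answer": 42}'
print(clean_think_tags(raw))  # -> {"answer": 42}

One caveat worth noting: in stream, the cleaning is applied to each chunk independently, so a <think> block whose opening and closing tags arrive in different chunks would presumably pass through unstripped; query sees the full response, so the DOTALL pattern removes multi-line think blocks there.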
mle/workflow/baseline.py (1 addition, 1 deletion)
@@ -46,7 +46,7 @@ def baseline(work_dir: str, model=None):
    dataset = ca.resume("dataset")
    if dataset is None:
        advisor = AdviseAgent(model, console)
-        dataset = ask_text("Please provide your dataset information (a public dataset name or a local file path)")
+        dataset = ask_text("Please provide your dataset information (a public dataset name or a local absolute filepath)")
        if not dataset:
            print_in_box("The dataset is empty. Aborted", console, title="Error", color="red")
            return
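Since the prompt now asks for an absolute filepath, one plausible follow-up is validating the input before use. The helper below is a hypothetical sketch, not part of this PR:

import os

def is_usable_local_path(dataset: str) -> bool:
    # Hypothetical check (not in the PR): accept only absolute paths
    # that actually exist on disk; anything else would be treated as
    # a public dataset name.
    return os.path.isabs(dataset) and os.path.exists(dataset)

print(is_usable_local_path("/data/train.csv"))  # True only if the file exists
print(is_usable_local_path("train.csv"))        # False: relative path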