✨ feat: added load model function
huangyz0918 committed Apr 19, 2024
1 parent 2cd88db commit 4eb0347
Showing 4 changed files with 62 additions and 12 deletions.
28 changes: 28 additions & 0 deletions agent/cli/cli.py
@@ -3,6 +3,7 @@
 
 import agent
 from agent.utils import *
+from agent.model import OpenAIModel
 
 console = Console()
 # avoid the tokenizers parallelism issue
@@ -77,6 +78,25 @@ def build_config(general: bool = False):
     configuration.write_section(platform, platform_config)
 
 
+def load_model():
+    """
+    load_model: load the model based on the configuration.
+    """
+    configuration = Config()
+    config_dict = configuration.read()
+    plat = config_dict['general']['platform']
+
+    model = None
+    if plat == LLM_TYPE_OPENAI:
+        model = OpenAIModel(
+            api_key=config_dict[CONFIG_SEC_GENERAL][CONFIG_SEC_API_KEY],
+            model=config_dict[LLM_TYPE_OPENAI].get('model'),
+            temperature=float(config_dict[LLM_TYPE_OPENAI]['temperature'])
+        )
+
+    return model
+
+
 @click.group(cls=DefaultCommandGroup)
 @click.version_option(version=agent.__version__)
 def cli():
@@ -111,19 +131,27 @@ def go():
     go: start the working your ML project.
     """
     configuration = Config()
+    model = load_model()
     console.log("Welcome to MLE-Agent! :sunglasses:")
+    if model:
+        console.log(f"Model loaded: {model.model_type}, {model.model}")
+    console.line()
 
     if configuration.read().get('project') is None:
         console.log("You have not set up a project yet.")
         console.log("Please create a new project first using 'mle new <project_name>' command.")
         return
 
     # ask for the project language.
     console.log("> Current project:", configuration.read()['project']['path'])
     if configuration.read()['project'].get('lang') is None:
         lang = questionary.text("What is your major language for this project?").ask()
         configuration.write_section(CONFIG_SEC_PROJECT, {'lang': lang})
 
+    console.log("> Project language:", configuration.read()['project']['lang'])
+
+    # ask for the project description.
+
 
 @cli.command()
 @click.argument('name')
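
For reference, load_model() reads only a handful of keys from the configuration. The sketch below mirrors that shape as a plain dictionary; the section and key names 'openai' and 'api_key' are assumptions standing in for the LLM_TYPE_OPENAI and CONFIG_SEC_API_KEY constants from agent.const, which are not shown in this diff.

# Hypothetical configuration contents matching the lookups in load_model().
config_dict = {
    'general': {
        'platform': 'openai',      # assumed value of LLM_TYPE_OPENAI; selects OpenAIModel
        'api_key': 'sk-...',       # placeholder key under the assumed CONFIG_SEC_API_KEY name
    },
    'openai': {
        'model': 'gpt-3.5-turbo',  # example model name passed through to the OpenAI API
        'temperature': '0.7',      # stored as text; load_model() casts it to float
    },
}
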
2 changes: 2 additions & 0 deletions agent/model/__init__.py
@@ -0,0 +1,2 @@
+from .base import *
+from .openai import *
6 changes: 1 addition & 5 deletions agent/model/base.py
@@ -7,9 +7,5 @@ def __init__(self):
         self.model_type = None
 
     @abstractmethod
-    def load_data(self, data):
-        pass
-
-    @abstractmethod
-    def load_model(self, model_path):
+    def chat(self, context: str, text: str):
         pass
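
With load_data and load_model removed, chat() is now the only method a backend has to provide. A minimal sketch of a hypothetical extra backend built against this interface (EchoModel is illustrative only, not part of the repository):

from agent.model.base import Model


class EchoModel(Model):
    # Hypothetical backend: echoes the prompt back instead of calling an LLM,
    # just to show that implementing chat(context, text) is all that is required.
    def __init__(self):
        super().__init__()
        self.model_type = 'echo'  # placeholder type label

    def chat(self, context: str, text: str):
        return f"[{self.model_type}] {text}"
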
38 changes: 31 additions & 7 deletions agent/model/openai.py
@@ -1,15 +1,16 @@
 import importlib.util
 
 from .base import Model
+from agent.const import LLM_TYPE_OPENAI
 
 
 class OpenAIModel(Model):
-    def __init__(self, api_key, version, temperature):
+    def __init__(self, api_key, model, temperature):
         """
         Initialize the OpenAI model.
         Args:
             api_key (str): The OpenAI API key.
-            version (str): The model version.
+            model (str): The model with version.
             temperature (float): The temperature value.
         """
         super().__init__()
@@ -26,12 +27,35 @@ def __init__(self, api_key, version, temperature):
                 "More information, please refer to: https://openai.com/product"
             )
 
-        self.version = version
+        self.model = model
+        self.model_type = LLM_TYPE_OPENAI
         self.temperature = temperature
         self.client = self.OpenAI(api_key=api_key)
 
-    def load_data(self, data):
-        pass
+    def chat(self, context: str, text: str):
+        """
+        Chat with the model.
+        Args:
+            context (str): The context (chat history) prompt.
+            text (str): The text prompt.
+        """
+        try:
+            chat_history = [
+                {"role": "system", "content": context},
+                {"role": "user", "content": text}
+            ]
+
+            completion = self.client.chat.completions.create(
+                model=self.model,
+                messages=chat_history,
+                temperature=self.temperature
+            )
 
-    def load_model(self, model_path):
-        pass
+            response = completion.choices[0].message.content
+            return response
+        except self.RateLimitError as e:
+            print("Rate limit exceeded. Please try again later.")
+            print(f"Error message: {e}")
+        except Exception as e:
+            print("OpenAI error occurred.")
+            print(f"Error message: {e}")
