-
Notifications
You must be signed in to change notification settings - Fork 49
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #10 from MLSysOps/hz/platform-selection
[MRG] [POC] using langchain to quickly POC training and eval steps
- Loading branch information
Showing
12 changed files
with
581 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -164,3 +164,4 @@ cython_debug/ | |
|
||
# cache | ||
.rich-chat.history | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
*.json | ||
*.csv |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
from langchain_openai import ChatOpenAI | ||
from cfg import load_config | ||
|
||
|
||
from enum import Enum | ||
from langchain.output_parsers import EnumOutputParser | ||
from langchain.prompts import PromptTemplate | ||
from langchain_core.language_models import BaseChatModel | ||
from langchain_core.runnables import chain | ||
|
||
class MLDevelopmentStage(Enum):
    """Canonical stages of the machine-learning development lifecycle.

    The member *values* are the human-readable stage names that the LLM is
    asked to emit, so they double as the parse targets for EnumOutputParser.
    """

    PROBLEM_DEFINITION = "Problem Definition"
    DATA_COLLECTION = "Data Collection"
    DATA_ENGINEERING = "Data Engineering"
    MODEL_TRAINING = "Model Training"
    MODEL_EVALUATION = "Model Evaluation"
    MODEL_DEPLOYMENT = "Model Deployment"
|
||
@chain
def analyze_ml_development_stage(input: str, llm: BaseChatModel) -> MLDevelopmentStage:
    """Classify which ML development stage the user is currently in.

    Args:
        input (str): The user's free-form description of their current work.
        llm (BaseChatModel): The chat model used to classify the description.

    Returns:
        MLDevelopmentStage: The stage inferred from the description.

    Raises:
        OutputParserException: If the model's reply cannot be matched to a
            MLDevelopmentStage value by the EnumOutputParser.
    """
    parser = EnumOutputParser(enum=MLDevelopmentStage)
    # BUG FIX: the original template never referenced {format_instructions},
    # so the parser's formatting guidance was silently dropped from the
    # prompt even though it was supplied via partial_variables.
    prompt = PromptTemplate(
        template="""
Welcome to the ML development stage analysis tool. We cover various stages like:
- Problem Definition
- Data Collection
- Data Engineering
- Model Training
- Model Evaluation
- Model Deployment
Please provide some information about what you are currently working on in your ML project,
and I will help identify the stage you are likely at.
{format_instructions}
The user's input description is: {input}
Identified ML development stage is:
""",
        input_variables=["input"],
        partial_variables={"format_instructions": parser.get_format_instructions()},
    )

    # Use a distinct local name; `chain` is the imported decorator above and
    # shadowing it here is confusing.
    pipeline = prompt | llm | parser
    return pipeline.invoke({"input": input})
|
||
|
||
if __name__ == "__main__":
    # The OpenAI key lives in a local, untracked JSON file.
    config = load_config('open_ai_key.json')
    OPENAI_API_KEY = config["OPENAI_API_KEY"]

    llm = ChatOpenAI(api_key=OPENAI_API_KEY)

    # Ask the agent to classify a sample work description.
    stage_result = analyze_ml_development_stage.invoke(
        input="I am currently selecting features and cleaning data",
        llm=llm
    )

    print("You are likely at the:", stage_result.name)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
from langchain_openai import ChatOpenAI | ||
from langchain.prompts import PromptTemplate | ||
from langchain_core.language_models import BaseChatModel | ||
from langchain_core.runnables import chain | ||
from langchain_core.output_parsers import StrOutputParser | ||
from snowflake.connector import connect | ||
|
||
from cfg import load_config | ||
|
||
def snowflake_config(config):
    """Open a Snowflake connection from the "snowflake" section of *config*.

    Args:
        config: Mapping containing a "snowflake" sub-mapping with the keys
            user, password, account, warehouse, database and schema.

    Returns:
        A live snowflake.connector connection; the caller is responsible
        for closing it.
    """
    # Renamed from `snowflake_config` — the original shadowed this
    # function's own name with the local variable.
    creds = config["snowflake"]
    # SECURITY FIX: the original printed the whole credentials dict,
    # leaking the password to stdout/logs. Echo only non-secret fields.
    print({k: creds[k] for k in ("account", "warehouse", "database", "schema")})
    conn = connect(
        user=creds["user"],
        password=creds["password"],
        account=creds["account"],
        warehouse=creds["warehouse"],
        database=creds["database"],
        schema=creds["schema"]
    )
    return conn
|
||
|
||
def data_engineering_data_loading():
    """Interactive console menu: greet the user and let them pick a data store.

    Prints the selected store, or an error message when the choice is not
    one of the offered options. Returns nothing.
    """
    # Menu of supported data stores, keyed by the digit the user types.
    data_stores = {
        '1': 'Snowflake',
        '2': 'Databricks',
        '3': 'AWS S3'
    }

    # Greeting and stage banner.
    print("Welcome to the MLE-agent!")
    print("You are currently in the Data Engineering stage.")

    # Present the options.
    print("Please choose a data store by entering the corresponding number:")
    for key, store in data_stores.items():
        print(f"{key}. {store}")

    # Read and validate the user's choice.
    choice = input("Enter your choice (1, 2, or 3): ")
    if choice in data_stores:
        print(f"MLE-agent will now help you to load data from {data_stores[choice]}.")
    else:
        print("Invalid choice. Please run the program again and select 1, 2, or 3.")
|
||
|
||
@chain
def data_load_agent(input: str, llm: BaseChatModel):
    """Generate Python code that loads the requested data from Snowflake.

    The original docstring was copy-pasted from the stage-analysis agent and
    described the wrong contract; this one matches what the code does.

    Args:
        input (str): The user's description of the data they want to load.
        llm (BaseChatModel): The chat model used to generate the code.

    Returns:
        str: Runnable Python source code (SQL query + connection handling +
        DataFrame preview) produced by the model.
    """
    output_parser = StrOutputParser()
    prompt = PromptTemplate(
        template="""
You play as a professional data scientist. You are currently in the Data Engineering stage. You will understand
users input and generate code to help users load data from Snowflake step by step. Please import the necessary
packages to make sure the code can run successfully.
First, you should understand users' input and generate a SQL query to load data from Snowflake.
Do not add ; at the end of the query.
The snowflake credentials are stored in ../snowflake_key.json. Please make sure to use the file to
load the credentials and connect to snowflake.
Then, you should write code to execute the SQL query and load the data into a DataFrame.
Remember to close the connection.
Finally, you will write code to show some data from the DataFrame.
The output should be a pure python code block including above both that can be run to load data from Snowflake.
Do not include ``` in front of and after the code block.
The user's input description is: {input}
The output:
""",
        input_variables=["input"]
    )

    # Use a distinct local name; `chain` is the imported decorator above.
    pipeline = prompt | llm | output_parser
    return pipeline.invoke({"input": input})
|
||
|
||
if __name__ == "__main__":
    # Interactive prompt: ask the user which data store to load from.
    data_engineering_data_loading()

    # The OpenAI key lives in a local, untracked JSON file.
    config = load_config('open_ai_key.json')
    OPENAI_API_KEY = config["OPENAI_API_KEY"]

    llm = ChatOpenAI(api_key=OPENAI_API_KEY)

    # prompt = "Show me top 5 records from IMDB_TRAIN"
    prompt = "I want analyze IMDB_TRAIN dataset and use it to train a sentiment analysis model"

    data_load_code = data_load_agent.invoke(
        input=prompt,
        llm=llm
    )

    # Persist the generated loader script for inspection and execution.
    with open("imdb_project/data_load_GENERATED.py", "w") as py_file:
        py_file.write(data_load_code)

    # BUG FIX: the original printed data_load_code twice (before and after
    # writing the file); once is enough to visualize the generated code.
    print(data_load_code)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
from langchain_core.language_models import BaseChatModel | ||
from langchain_core.output_parsers import StrOutputParser | ||
from langchain_core.prompts import PromptTemplate | ||
from langchain_core.runnables import chain | ||
from langchain_openai import ChatOpenAI | ||
|
||
from cfg import load_config | ||
|
||
|
||
@chain
def vis_agent(input: str, llm: BaseChatModel):
    """Generate Python code that visualizes the data described by *input*.

    The original docstring was copy-pasted from the stage-analysis agent and
    described the wrong contract; this one matches what the code does.

    Args:
        input (str): First line is a data file path; the remaining lines
            describe the data (e.g. column headers) to visualize.
        llm (BaseChatModel): The chat model used to generate the code.

    Returns:
        str: Runnable Python source code that reads and plots the data.
    """
    output_parser = StrOutputParser()
    prompt = PromptTemplate(
        template="""
System: You play as a professional data scientist. You are currently in the Data Engineering stage. You can read data
and then use tools like pandas, numpy, and matplotlib to visualize the data.
The first line of input is a file name with a path.
First, you should write code to read the data from {input_file_path} base on its suffix using proper tools like pandas,
etc.
The second line of input is the header of the data you want to visualize.
Second, based on the {sample_data}, you should write code to visualize the data using matplotlib, seaborn, or any other.
The code must be tailored to the user's input and the text in the figure must be too.
Your output should be purely python code that can be run to visualize the data. Please do not include
```in front
of and after the code block.
input: {input_file_path} \n {sample_data}
Answer as a python code block:
""",
        input_variables=["input_file_path", "sample_data"]
    )

    # Use a distinct local name; `chain` is the imported decorator above.
    pipeline = prompt | llm | output_parser
    # Split once instead of twice: first line is the path, the rest is the
    # sample/header description (a list of lines, matching the original).
    lines = input.split("\n")
    return pipeline.invoke({"input_file_path": lines[0], "sample_data": lines[1:]})
|
||
|
||
if __name__ == "__main__":
    # The OpenAI key lives in a local, untracked JSON file.
    config = load_config('open_ai_key.json')
    OPENAI_API_KEY = config["OPENAI_API_KEY"]

    llm = ChatOpenAI(api_key=OPENAI_API_KEY)

    # First line: data file; second line: the column headers to visualize.
    prompt = ("output.csv \n TEXT, LABEL")

    vis_code = vis_agent.invoke(
        input=prompt,
        llm=llm
    )

    # Save the generated visualization script, then echo it.
    with open("imdb_project/data_vis_GENERATED.py", "w") as py_file:
        py_file.write(vis_code)

    print(vis_code)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
from langchain_core.language_models import BaseChatModel | ||
from langchain_core.output_parsers import StrOutputParser | ||
from langchain_core.prompts import PromptTemplate | ||
from langchain_core.runnables import chain | ||
from langchain_openai import ChatOpenAI | ||
|
||
from cfg import load_config | ||
|
||
|
||
@chain
def model_selection_agent(input: str, llm: BaseChatModel):
    """Generate Python code that selects and loads a baseline model.

    The original docstring was copy-pasted from the stage-analysis agent and
    described the wrong contract; this one matches what the code does.

    Args:
        input (str): The user's description of the model/baseline they want.
        llm (BaseChatModel): The chat model used to generate the code.

    Returns:
        str: Runnable Python source code that loads the chosen model (and
        its preprocessor) from the appropriate library.
    """
    output_parser = StrOutputParser()
    prompt = PromptTemplate(
        template="""
System: You play as a professional machine learning engineer. You are currently in the model training stage.
You need to select and load a model for users to build a baseline given users query.
Your output should be purely python code that can be run to load the model from appropriate library.
The user's input description is: {query}
The python code generated is:
""",
        input_variables=["query"]
    )

    # Use a distinct local name; `chain` is the imported decorator above.
    pipeline = prompt | llm | output_parser
    return pipeline.invoke({"query": input})
|
||
|
||
if __name__ == "__main__":
    # The OpenAI key lives in a local, untracked JSON file.
    config = load_config('open_ai_key.json')
    OPENAI_API_KEY = config["OPENAI_API_KEY"]

    llm = ChatOpenAI(api_key=OPENAI_API_KEY)

    # Request a lightweight sentiment-analysis baseline (implicit string
    # concatenation — same single prompt string as before).
    prompt = ("I want to use some lightweight model to build a sentiment analysis baseline on my dataset."
              "make sure you only import necessary packages"
              "make sure you also load the corresponding preprocessor like tokenizer and model.")

    model_selection_code = model_selection_agent.invoke(
        input=prompt,
        llm=llm
    )

    # Save the generated model-loading script, then echo it.
    with open("imdb_project/model_selection_GENERATED.py", "w") as py_file:
        py_file.write(model_selection_code)

    print(model_selection_code)
Oops, something went wrong.