From dc660cb128158f3b1a6033e7efed5f853a5e8f9e Mon Sep 17 00:00:00 2001 From: Jingyi Date: Fri, 17 Nov 2023 18:18:14 +0800 Subject: [PATCH 01/16] improved image generation logic --- .gitignore | 2 + source/panel/image_generation.py | 470 ++++++++++++++++++++++++------- source/panel/requirements.txt | 9 +- 3 files changed, 383 insertions(+), 98 deletions(-) diff --git a/.gitignore b/.gitignore index 97bc48be..7701115c 100644 --- a/.gitignore +++ b/.gitignore @@ -65,3 +65,5 @@ package-lock.json **/model/embedding/model/models--BAAI--bge-large-zh-v1.5 **/model/embedding/model/models--csdc-atl--buffer-embedding-002 **/model/instruct/model/models--csdc-atl--buffer-instruct-InternLM-001 +.env_sd +venv diff --git a/source/panel/image_generation.py b/source/panel/image_generation.py index b0bb8624..417430de 100644 --- a/source/panel/image_generation.py +++ b/source/panel/image_generation.py @@ -1,37 +1,43 @@ -import os -import boto3 import json import logging +import os import time -import json +from datetime import datetime +from typing import List +import boto3 +import requests +import streamlit as st +from dotenv import load_dotenv from langchain import PromptTemplate -from langchain.llms.bedrock import Bedrock -from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple -from langchain.docstore.document import Document -from langchain.embeddings import OpenAIEmbeddings -from langchain.vectorstores import FAISS -from dotenv import load_dotenv # load .env file with specific name load_dotenv(dotenv_path='.env_sd') logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -import streamlit as st -import requests -import time - # load the URL from .env file -API_KEY = os.getenv("API_KEY") # 'https://xxxx.execute-api.us-west-2.amazonaws.com/prod/' COMMAND_API_URL = os.getenv("COMMON_API_URL") -GENERATE_API_URL = COMMAND_API_URL + "inference-api/inference" +API_KEY = os.getenv("API_KEY") +API_USERNAME = os.getenv("API_USERNAME") +# The service support varies in different regions +BEDROCK_REGION = os.getenv("BEDROCK_REGION") + +GENERATE_API_URL = COMMAND_API_URL + "inference/v2" STATUS_API_URL = COMMAND_API_URL + "inference/get-inference-job" -IMAGE_API_URL = COMMAND_API_URL + "inference/get-inference-job-param-output" +PARAM_API_URL = COMMAND_API_URL + "inference/get-inference-job-param-output" +IMAGE_API_URL = COMMAND_API_URL + "inference/get-inference-job-image-output" + +support_model_list = ["sd_xl_base_1.0.safetensors", "majicmixRealistic_v7.safetensors", + "x2AnimeFinal_gzku.safetensors", "v1-5-pruned-emaonly.safetensors"] -def deploy_sagemaker_endpoint(instance_type: str = "ml.g4dn.4xlarge", initial_instance_count: int = 1, endpoint_name: str = "default-endpoint-for-llm-bot"): +default_models = ["v1-5-pruned-emaonly.safetensors"] + + +def deploy_sagemaker_endpoint(instance_type: str = "ml.g4dn.4xlarge", initial_instance_count: int = 1, + endpoint_name: str = "default-endpoint-for-llm-bot"): headers = { 'Content-Type': 'application/json', 'Accept': 'application/json', @@ -41,29 +47,32 @@ def deploy_sagemaker_endpoint(instance_type: str = "ml.g4dn.4xlarge", initial_in "instance_type": instance_type, "initial_instance_count": initial_instance_count, "endpoint_name": endpoint_name - } + } # https://.execute-api..amazonaws.com/{basePath}/inference/deploy-sagemaker-endpoint - res = requests.post(COMMAND_API_URL + 'inference/deploy-sagemaker-endpoint', headers = headers, json = inputBody) + res = requests.post(COMMAND_API_URL + 'inference/deploy-sagemaker-endpoint', 
headers=headers, json=inputBody) logger.info("deploy_sagemaker_endpoint: {}".format(res.json())) + def upload_model(): pass + def get_bedrock_client(): # specify the profile_name to call the bedrock api if needed - bedrock_client = boto3.client('bedrock-runtime') + bedrock_client = boto3.client('bedrock-runtime', region_name=BEDROCK_REGION) return bedrock_client + def claude_template(initial_prompt: str, placeholder: str): sd_prompt = PromptTemplate( - input_variables=["initial_prompt", "placeholder"], + input_variables=["initial_prompt", "placeholder"], template=""" - Transform the input prompt {initial_prompt} into a detailed prompt for an image generation model, describing the scene with vivid and specific attributes that enhance the original concept, only adjective and noun are allowed, verb and adverb are not allowed, each words speperated by comma. - Generate a negative prompt that specifies what should be avoided in the image, including any elements that contradict the desired style or tone. - Recommend a list of suitable models from the stable diffusion lineup that best match the style and content described in the detailed prompt. - Other notes please refer to {placeholder} - The output should be a plain text in Python List format shown follows, no extra content added beside Positive Prompt, Negative Prompt and Recommended Model List. The model list can only be chosen from the fixed list: "sd_xl_base_1.0.safetensors", "majicmixRealistic_v7.safetensors", "x2AnimeFinal_gzku.safetensors": + The output should be a plain text in Python List format shown follows, no extra content added beside Positive Prompt, Negative Prompt and Recommended Model List. The model list can only be chosen from the fixed list: ["sd_xl_base_1.0.safetensors", "majicmixRealistic_v7.safetensors", "x2AnimeFinal_gzku.safetensors"]: [Positive Prompt: , Negative Prompt: , @@ -80,16 +89,17 @@ def claude_template(initial_prompt: str, placeholder: str): """ ) # Pass in values to the input variables - prompt = sd_prompt.format(initial_prompt="a cute dog", placeholder="") + prompt = sd_prompt.format(initial_prompt="a cute dog", placeholder="") return prompt -def get_llm_processed_prompts(initial_prompt): + +def get_llm_processed_prompts(initial_prompt: str): # get the bedrock client bedrock_client = get_bedrock_client() prompt = claude_template(initial_prompt, '') prompt = "\n\nHuman:{}".format(prompt) + "\n\nAssistant:" - logger.debug("final prompt: {}".format(prompt)) + logger.debug("final prompt: {}".format(prompt)) body = json.dumps({ "prompt": prompt, "temperature": 0.7, @@ -99,105 +109,377 @@ def get_llm_processed_prompts(initial_prompt): "stop_sequences": ["\n\nHuman:"] }) # note v2 is not output chinese characters - modelId = "anthropic.claude-v2" + model_id = "anthropic.claude-v2" accept = "*/*" - contentType = "application/json" + content_type = "application/json" response = bedrock_client.invoke_model( - body=body, modelId=modelId, accept=accept, contentType=contentType + body=body, modelId=model_id, accept=accept, contentType=content_type ) response_body = json.loads(response.get("body").read()) raw_completion = response_body.get("completion").split('\n') logger.info("raw_completion: {}".format(raw_completion)) - # TODO: extract positive prompt, negative prompt and model list from the raw_completion + positive_prompt = "" + negative_prompt = "" + model_list = [] + + # todo need to check the length of raw_completion + if len(raw_completion) == 3: + positive_prompt = raw_completion[0].split(':')[1].strip() + 
negative_prompt = raw_completion[1].split(':')[1].strip() + model_list = raw_completion[2].split(':')[1].strip().split(',') + + if len(raw_completion) == 5: + positive_prompt = raw_completion[2].split(':')[1].strip() + negative_prompt = raw_completion[3].split(':')[1].strip() + model_list = raw_completion[4].split(':')[1].strip().split(',') + + if len(model_list) > 0: + model_list = model_list[0].replace('[', '').replace(']', '').replace('"', '').split(' ') - logger.info("positive_prompt: {}".format(positive_prompt)) - logger.info("negative_prompt: {}".format(negative_prompt)) - logger.info("model_list: {}".format(model_list)) return positive_prompt, negative_prompt, model_list -def generate_image(endpoint_name: str, positive_prompt: str, negative_prompt: str, model: List[str]): - # Construct the API request (this is a placeholder) + +def generate_image(positive_prompts: str, negative_prompts: str, model: List[str]): + st.write("Generate Image Process:") + + # set progress bar for user experience + progess = 5 + bar = st.progress(progess) + + job = create_inference_job(model) + progess += 5 + bar.progress(progess) + + inference = job["inference"] + + upload_inference_job_api_params(inference["api_params_s3_upload_url"], positive_prompts, negative_prompts) + progess += 5 + bar.progress(progess) + + run_inference_job(inference["id"]) + progess += 5 + bar.progress(progess) + + while True: + status_response = get_inference_job(inference["id"]) + if progess < 90: + progess += 10 + bar.progress(progess) + if status_response['status'] == 'succeed': + image_url = get_inference_image_output(inference["id"])[0] + st.image(image_url, caption=positive_prompts, use_column_width=True) + break + elif status_response['status'] == 'failed': + st.error("Image generation failed.") + break + else: + time.sleep(1) + + bar.progress(100) + return inference["id"] + + +def get_inference_job(inference_id: str): headers = { "Content-Type": "application/json", "Accept": "application/json", 'x-api-key': API_KEY } + + job = requests.get(STATUS_API_URL, headers=headers, params={"jobID": inference_id}) + + return job.json() + + +def get_inference_param_output(inference_id: str): + headers = { + "Content-Type": "application/json", + "Accept": "application/json", + 'x-api-key': API_KEY + } + + job = requests.get(PARAM_API_URL, headers=headers, params={"jobID": inference_id}) + + return job.json() + + +def get_inference_image_output(inference_id: str): + headers = { + "Content-Type": "application/json", + "Accept": "application/json", + 'x-api-key': API_KEY + } + + job = requests.get(IMAGE_API_URL, headers=headers, params={"jobID": inference_id}) + + return job.json() + + +def create_inference_job(models: List[str]): + if len(models) == 0: + models = default_models + + if not set(models).issubset(set(support_model_list)): + models = default_models + st.warning("use default model {} because LLM recommend not in support list".format(models)) + + headers = { + "Content-Type": "application/json", + "Accept": "application/json", + 'x-api-key': API_KEY + } + + # todo use default api params body = { + "user_id": API_USERNAME, "task_type": "txt2img", "models": { - model + "Stable-diffusion": models, + "embeddings": [] }, - "sagemaker_endpoint_name": endpoint_name, - "prompt": positive_prompt, - "negative_prompt": negative_prompt, - "denoising_strength": 0.75 + "filters": { + "createAt": datetime.now().timestamp(), + "creator": "sd-webui" + } } - response = requests.post(COMMAND_API_URL + "inference-api/inference", headers = 
headers, json = body) - return response.json() + job = requests.post(GENERATE_API_URL, headers=headers, json=body) + user_models = [] + for model in job.json()['inference']['models']: + user_models.append(model['name'][0]) + st.write("use models: {}".format(user_models)) + return job.json() + -def check_image_status(inference_id: str): - """Check the status of the image generation.""" +def run_inference_job(inference_id: str): headers = { - 'Accept': 'application/json', + "Content-Type": "application/json", + "Accept": "application/json", 'x-api-key': API_KEY } - # TODO, the schema is not completed according to the API document - response = requests.get(GENERATE_API_URL, headers = headers) - return response.json() -def get_image_url(inference_id): - """Get the URL of the generated image.""" - response = requests.get(f"{IMAGE_API_URL}/{inference_id}") - return response.json() + job = requests.put(COMMAND_API_URL + 'inference/v2/' + inference_id + '/run', headers=headers) -def streamlit(): - # Streamlit layout - st.title("Image Generation Application") + return job.json() - # User input - prompt = st.text_input("Enter a prompt for the image:", "A cute dog") - # Button to start the image generation process - if st.button('Generate Image'): +def upload_inference_job_api_params(s3_url, positive: str, negative: str): + # todo use default api params + api_params = { + "prompt": positive, + "negative_prompt": negative, + "styles": [], + "seed": -1, + "subseed": -1, + "subseed_strength": 0.0, + "seed_resize_from_h": -1, + "seed_resize_from_w": -1, + "sampler_name": "DPM++ 2M Karras", + "batch_size": 1, + "n_iter": 1, + "steps": 20, + "cfg_scale": 7.0, + "width": 512, + "height": 512, + "restore_faces": None, + "tiling": None, + "do_not_save_samples": False, + "do_not_save_grid": False, + "eta": None, + "denoising_strength": None, + "s_min_uncond": 0.0, + "s_churn": 0.0, + "s_tmax": "Infinity", + "s_tmin": 0.0, + "s_noise": 1.0, + "override_settings": {}, + "override_settings_restore_afterwards": True, + "refiner_checkpoint": None, + "refiner_switch_at": None, + "disable_extra_networks": False, + "comments": {}, + "enable_hr": False, + "firstphase_width": 0, + "firstphase_height": 0, + "hr_scale": 2.0, + "hr_upscaler": "Latent", + "hr_second_pass_steps": 0, + "hr_resize_x": 0, + "hr_resize_y": 0, + "hr_checkpoint_name": None, + "hr_sampler_name": None, + "hr_prompt": "", + "hr_negative_prompt": "", + "sampler_index": "DPM++ 2M Karras", + "script_name": None, + "script_args": [], + "send_images": True, + "save_images": False, + "alwayson_scripts": { + "refiner": { + "args": [False, "", 0.8] + }, + "seed": { + "args": [-1, False, -1, 0, 0, 0] + }, + "controlnet": { + "args": [ + { + "enabled": False, + "module": "none", + "model": "None", + "weight": 1, + "image": None, + "resize_mode": "Crop and Resize", + "low_vram": False, + "processor_res": -1, + "threshold_a": -1, + "threshold_b": -1, + "guidance_start": 0, + "guidance_end": 1, + "pixel_perfect": False, + "control_mode": "Balanced", + "is_ui": True, + "input_mode": "simple", + "batch_images": "", + "output_dir": "", + "loopback": False + }, + { + "enabled": False, + "module": "none", + "model": "None", + "weight": 1, + "image": None, + "resize_mode": "Crop and Resize", + "low_vram": False, + "processor_res": -1, + "threshold_a": -1, + "threshold_b": -1, + "guidance_start": 0, + "guidance_end": 1, + "pixel_perfect": False, + "control_mode": "Balanced", + "is_ui": True, + "input_mode": "simple", + "batch_images": "", + "output_dir": "", + 
"loopback": False + }, + { + "enabled": False, + "module": "none", + "model": "None", + "weight": 1, + "image": None, + "resize_mode": "Crop and Resize", + "low_vram": False, + "processor_res": -1, + "threshold_a": -1, + "threshold_b": -1, + "guidance_start": 0, + "guidance_end": 1, + "pixel_perfect": False, + "control_mode": "Balanced", + "is_ui": True, + "input_mode": "simple", + "batch_images": "", + "output_dir": "", + "loopback": False + } + ] + }, + "segment anything": { + "args": [ + False, + False, + 0, + None, + [], + 0, + False, + [], + [], + False, + 0, + 1, + False, + False, + 0, + None, + [], + -2, + False, + [], + False, + 0, + None, + None + ] + }, + "extra options": { + "args": [] + } + } + } + + json_string = json.dumps(api_params) + response = requests.put(s3_url, data=json_string) + response.raise_for_status() + return response + + +def generate_llm_image(initial_prompt: str, llm_prompt: bool = True): + negative = "" + models = default_models + if llm_prompt is True: + st.write("Wait for LLM to process the prompt...") positive_prompt, negative_prompt, model_list = get_llm_processed_prompts(prompt) - # Assuming the first model is chosen for simplicity - # chosen_model = model_list.split('\n')[0] - - # Generate the detailed prompt - response = generate_image(positive_prompt, negative_prompt, model_list) - - # Display image (placeholder for actual image retrieval logic) - st.image("https://picsum.photos/200", caption=positive_prompt) - - if response.status_code == 200: - inference_id = response.json()['inference_id'] - # Check the status periodically - with st.empty(): - while True: - status_response = check_image_status(inference_id) - if status_response['status'] == 'succeeded': - image_url = get_image_url(inference_id)['url'] - st.image(image_url) - break - elif status_response['status'] == 'failed': - st.error("Image generation failed.") - break - else: - st.text("Waiting for the image to be generated...") - time.sleep(5) # Sleep for a while before checking the status again - else: - st.error("Failed to start the image generation process.") + + # if prompt is empty, use default + if positive_prompt != "": + initial_prompt = positive_prompt + st.write("positive_prompt:") + st.info("{}".format(positive_prompt)) + + if negative_prompt != "": + negative = negative_prompt + st.write("negative_prompt:") + st.info("{}".format(negative_prompt)) + + if len(model_list) > 0 and model_list[0] != "": + models = model_list + st.write("model_list:") + st.info("{}".format(model_list)) + + inference_id = generate_image(initial_prompt, negative, models) + + return inference_id + # main entry point for debugging +# python -m streamlit run image_generation.py --server.port 8088 if __name__ == "__main__": - # deploy_sagemaker_endpoint() - # upload_model() - positive_prompt, negative_prompt, model_list = get_llm_processed_prompts("a cute dog") - # The endpoint fixed for now, since the deploy_sagemaker_endpoint() won't return the endpoint name - response = generate_image("default-endpoint-for-llm-bot", positive_prompt, negative_prompt, model_list) - logger.info("generate image response: {}".format(response)) - - # python -m streamlit run image-generation.py --server.port 8088 - # streamlit() + try: + # Streamlit layout + st.title("Image Generation Application") + + # User input + prompt = st.text_input("Enter a prompt for the image:", "A cute dog") + + button_disabled = False + if not button_disabled: + if st.button('Generate Image'): + button_disabled = True + st.empty() + + 
st.subheader("Image without LLM") + generate_llm_image(prompt, False) + + st.subheader("Image with LLM") + generate_llm_image(prompt) + except Exception as e: + logger.exception(e) + st.error("Failed to start the image generation process.") + raise e diff --git a/source/panel/requirements.txt b/source/panel/requirements.txt index 04af6ac9..bd5208b4 100644 --- a/source/panel/requirements.txt +++ b/source/panel/requirements.txt @@ -1,6 +1,7 @@ -python-dotenv +python-dotenv PyPDF2 -streamlit -langchain +streamlit +langchain openai -tiktoken \ No newline at end of file +tiktoken +boto3 From ec8fd9a7b1e8d39d21e226b2c9afe0560b30fd85 Mon Sep 17 00:00:00 2001 From: Jingyi Date: Fri, 17 Nov 2023 18:18:14 +0800 Subject: [PATCH 02/16] improved image generation logic --- .gitignore | 2 + source/panel/image_generation.py | 522 +++++++++++++++++++++++-------- source/panel/requirements.txt | 9 +- 3 files changed, 406 insertions(+), 127 deletions(-) diff --git a/.gitignore b/.gitignore index 97bc48be..7701115c 100644 --- a/.gitignore +++ b/.gitignore @@ -65,3 +65,5 @@ package-lock.json **/model/embedding/model/models--BAAI--bge-large-zh-v1.5 **/model/embedding/model/models--csdc-atl--buffer-embedding-002 **/model/instruct/model/models--csdc-atl--buffer-instruct-InternLM-001 +.env_sd +venv diff --git a/source/panel/image_generation.py b/source/panel/image_generation.py index b0bb8624..884066f5 100644 --- a/source/panel/image_generation.py +++ b/source/panel/image_generation.py @@ -1,37 +1,51 @@ -import os -import boto3 import json import logging +import os import time -import json +from datetime import datetime +from typing import List -from langchain import PromptTemplate +from langchain.prompts import PromptTemplate from langchain.llms.bedrock import Bedrock from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple from langchain.docstore.document import Document -from langchain.embeddings import OpenAIEmbeddings +from langchain.chains import ConversationChain from langchain.vectorstores import FAISS +from langchain.memory import ConversationBufferMemory +import boto3 +import requests +import streamlit as st from dotenv import load_dotenv +from langchain import PromptTemplate + # load .env file with specific name load_dotenv(dotenv_path='.env_sd') logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -import streamlit as st -import requests -import time - # load the URL from .env file -API_KEY = os.getenv("API_KEY") # 'https://xxxx.execute-api.us-west-2.amazonaws.com/prod/' COMMAND_API_URL = os.getenv("COMMON_API_URL") -GENERATE_API_URL = COMMAND_API_URL + "inference-api/inference" +API_KEY = os.getenv("API_KEY") +API_USERNAME = os.getenv("API_USERNAME") +# The service support varies in different regions +BEDROCK_REGION = os.getenv("BEDROCK_REGION") + +GENERATE_API_URL = COMMAND_API_URL + "inference/v2" STATUS_API_URL = COMMAND_API_URL + "inference/get-inference-job" -IMAGE_API_URL = COMMAND_API_URL + "inference/get-inference-job-param-output" +PARAM_API_URL = COMMAND_API_URL + "inference/get-inference-job-param-output" +IMAGE_API_URL = COMMAND_API_URL + "inference/get-inference-job-image-output" + +support_model_list = ["sd_xl_base_1.0.safetensors", "majicmixRealistic_v7.safetensors", + "x2AnimeFinal_gzku.safetensors", "v1-5-pruned-emaonly.safetensors"] -def deploy_sagemaker_endpoint(instance_type: str = "ml.g4dn.4xlarge", initial_instance_count: int = 1, endpoint_name: str = "default-endpoint-for-llm-bot"): +default_models = ["v1-5-pruned-emaonly.safetensors"] + + 
+def deploy_sagemaker_endpoint(instance_type: str = "ml.g4dn.4xlarge", initial_instance_count: int = 1, + endpoint_name: str = "default-endpoint-for-llm-bot"): headers = { 'Content-Type': 'application/json', 'Accept': 'application/json', @@ -41,29 +55,37 @@ def deploy_sagemaker_endpoint(instance_type: str = "ml.g4dn.4xlarge", initial_in "instance_type": instance_type, "initial_instance_count": initial_instance_count, "endpoint_name": endpoint_name - } + } # https://.execute-api..amazonaws.com/{basePath}/inference/deploy-sagemaker-endpoint - res = requests.post(COMMAND_API_URL + 'inference/deploy-sagemaker-endpoint', headers = headers, json = inputBody) + res = requests.post(COMMAND_API_URL + 'inference/deploy-sagemaker-endpoint', headers=headers, json=inputBody) logger.info("deploy_sagemaker_endpoint: {}".format(res.json())) + def upload_model(): pass -def get_bedrock_client(): + +def get_bedrock_llm(): # specify the profile_name to call the bedrock api if needed - bedrock_client = boto3.client('bedrock-runtime') - return bedrock_client - -def claude_template(initial_prompt: str, placeholder: str): - sd_prompt = PromptTemplate( - input_variables=["initial_prompt", "placeholder"], - template=""" - - Transform the input prompt {initial_prompt} into a detailed prompt for an image generation model, describing the scene with vivid and specific attributes that enhance the original concept, only adjective and noun are allowed, verb and adverb are not allowed, each words speperated by comma. + bedrock_client = boto3.client('bedrock-runtime', region_name=BEDROCK_REGION) + + model_id = "anthropic.claude-v2" + cl_llm = Bedrock( + model_id=model_id, + client=bedrock_client, + model_kwargs={"max_tokens_to_sample": 1000}, + ) + return cl_llm + +sd_prompt = PromptTemplate.from_template( + """ + Human: + - Transform the input prompt {input} into a detailed prompt for an image generation model, describing the scene with vivid and specific attributes that enhance the original concept, only adjective and noun are allowed, verb and adverb are not allowed, each words speperated by comma. - Generate a negative prompt that specifies what should be avoided in the image, including any elements that contradict the desired style or tone. - Recommend a list of suitable models from the stable diffusion lineup that best match the style and content described in the detailed prompt. - - Other notes please refer to {placeholder} + - Other notes please refer to the following example: - The output should be a plain text in Python List format shown follows, no extra content added beside Positive Prompt, Negative Prompt and Recommended Model List. The model list can only be chosen from the fixed list: "sd_xl_base_1.0.safetensors", "majicmixRealistic_v7.safetensors", "x2AnimeFinal_gzku.safetensors": + The output should be a plain text in Python List format shown follows, no extra content added beside Positive Prompt, Negative Prompt and Recommended Model List. 
The model list can only be chosen from the fixed list: ["sd_xl_base_1.0.safetensors", "majicmixRealistic_v7.safetensors", "x2AnimeFinal_gzku.safetensors"]: [Positive Prompt: , Negative Prompt: , @@ -74,130 +96,384 @@ def claude_template(initial_prompt: str, placeholder: str): [Positive Prompt: "visually appealing, high-quality image of a cute dog in a vibrant, cartoon style, adorable appearance, expressive eyes, friendly demeanor, colorful and lively, reminiscent of popular animation studios, artwork.", Negative Prompt: "realism, dark or dull colors, scary or aggressive dog depictions, overly simplistic, stick figure drawings, blurry or distorted images, inappropriate or NSFW content.", Recommended Model List: ["Stable-diffusion: LahCuteCartoonSDXL_alpha.safetensors", "Other model recommended..."]] - - {initial_prompt} - - """ - ) - # Pass in values to the input variables - prompt = sd_prompt.format(initial_prompt="a cute dog", placeholder="") - return prompt + + Current conversation: + + {history} + + + Here is the human's next reply: + + {input} + + + Assistant: + """) def get_llm_processed_prompts(initial_prompt): - # get the bedrock client - bedrock_client = get_bedrock_client() - - prompt = claude_template(initial_prompt, '') - prompt = "\n\nHuman:{}".format(prompt) + "\n\nAssistant:" - logger.debug("final prompt: {}".format(prompt)) - body = json.dumps({ - "prompt": prompt, - "temperature": 0.7, - "top_p": 1, - "top_k": 0, - "max_tokens_to_sample": 500, - "stop_sequences": ["\n\nHuman:"] - }) - # note v2 is not output chinese characters - modelId = "anthropic.claude-v2" - accept = "*/*" - contentType = "application/json" - response = bedrock_client.invoke_model( - body=body, modelId=modelId, accept=accept, contentType=contentType + cl_llm = get_bedrock_llm() + memory = ConversationBufferMemory() + conversation = ConversationChain( + llm=cl_llm, verbose=False, memory=memory ) - response_body = json.loads(response.get("body").read()) - raw_completion = response_body.get("completion").split('\n') - logger.info("raw_completion: {}".format(raw_completion)) - # TODO: extract positive prompt, negative prompt and model list from the raw_completion + conversation.prompt = sd_prompt + response = conversation.predict(input=initial_prompt) + logger.info("the first invoke: {}".format(response)) + # logger.info("the second invoke: {}".format(conversation.predict(input="change to realist style"))) - logger.info("positive_prompt: {}".format(positive_prompt)) - logger.info("negative_prompt: {}".format(negative_prompt)) - logger.info("model_list: {}".format(model_list)) + """ + [Positive Prompt: visually appealing, high-quality image of a big, large, muscular horse with powerful body, majestic stance, flowing mane, detailed texture, vivid color, striking photography., + Negative Prompt: ugly, distorted, inappropriate or NSFW content, + Recommended Model List: ["sd_xl_base_1.0.safetensors"]] + """ + positive_prompt = response.split('Positive Prompt: ')[1].split('Negative Prompt: ')[0].strip() + negative_prompt = response.split('Negative Prompt: ')[1].split('Recommended Model List: ')[0].strip() + model_list = response.split('Recommended Model List: ')[1].strip().replace('[', '').replace(']', '').replace('"', '').split(',') + logger.info("positive_prompt: {}\n negative_prompt: {}\n model_list: {}".format(positive_prompt, negative_prompt, model_list)) return positive_prompt, negative_prompt, model_list -def generate_image(endpoint_name: str, positive_prompt: str, negative_prompt: str, model: List[str]): 
- # Construct the API request (this is a placeholder) + +def generate_image(positive_prompts: str, negative_prompts: str, model: List[str]): + st.write("Generate Image Process:") + + # set progress bar for user experience + progess = 5 + bar = st.progress(progess) + + job = create_inference_job(model) + progess += 5 + bar.progress(progess) + + inference = job["inference"] + + upload_inference_job_api_params(inference["api_params_s3_upload_url"], positive_prompts, negative_prompts) + progess += 5 + bar.progress(progess) + + run_inference_job(inference["id"]) + progess += 5 + bar.progress(progess) + + while True: + status_response = get_inference_job(inference["id"]) + if progess < 90: + progess += 10 + bar.progress(progess) + if status_response['status'] == 'succeed': + image_url = get_inference_image_output(inference["id"])[0] + st.image(image_url, caption=positive_prompts, use_column_width=True) + break + elif status_response['status'] == 'failed': + st.error("Image generation failed.") + break + else: + time.sleep(1) + + bar.progress(100) + return inference["id"] + + +def get_inference_job(inference_id: str): + headers = { + "Content-Type": "application/json", + "Accept": "application/json", + 'x-api-key': API_KEY + } + + job = requests.get(STATUS_API_URL, headers=headers, params={"jobID": inference_id}) + + return job.json() + + +def get_inference_param_output(inference_id: str): + headers = { + "Content-Type": "application/json", + "Accept": "application/json", + 'x-api-key': API_KEY + } + + job = requests.get(PARAM_API_URL, headers=headers, params={"jobID": inference_id}) + + return job.json() + + +def get_inference_image_output(inference_id: str): headers = { "Content-Type": "application/json", "Accept": "application/json", 'x-api-key': API_KEY } + + job = requests.get(IMAGE_API_URL, headers=headers, params={"jobID": inference_id}) + + return job.json() + + +def create_inference_job(models: List[str]): + if len(models) == 0: + models = default_models + + if not set(models).issubset(set(support_model_list)): + models = default_models + st.warning("use default model {} because LLM recommend not in support list".format(models)) + + headers = { + "Content-Type": "application/json", + "Accept": "application/json", + 'x-api-key': API_KEY + } + + # todo use default api params body = { + "user_id": API_USERNAME, "task_type": "txt2img", "models": { - model + "Stable-diffusion": models, + "embeddings": [] }, - "sagemaker_endpoint_name": endpoint_name, - "prompt": positive_prompt, - "negative_prompt": negative_prompt, - "denoising_strength": 0.75 + "filters": { + "createAt": datetime.now().timestamp(), + "creator": "sd-webui" + } } - response = requests.post(COMMAND_API_URL + "inference-api/inference", headers = headers, json = body) - return response.json() + job = requests.post(GENERATE_API_URL, headers=headers, json=body) + user_models = [] + for model in job.json()['inference']['models']: + user_models.append(model['name'][0]) + st.write("use models: {}".format(user_models)) + return job.json() + -def check_image_status(inference_id: str): - """Check the status of the image generation.""" +def run_inference_job(inference_id: str): headers = { - 'Accept': 'application/json', + "Content-Type": "application/json", + "Accept": "application/json", 'x-api-key': API_KEY } - # TODO, the schema is not completed according to the API document - response = requests.get(GENERATE_API_URL, headers = headers) - return response.json() -def get_image_url(inference_id): - """Get the URL of the generated 
image.""" - response = requests.get(f"{IMAGE_API_URL}/{inference_id}") - return response.json() + job = requests.put(COMMAND_API_URL + 'inference/v2/' + inference_id + '/run', headers=headers) + + return job.json() + + +def upload_inference_job_api_params(s3_url, positive: str, negative: str): + # todo use default api params + api_params = { + "prompt": positive, + "negative_prompt": negative, + "styles": [], + "seed": -1, + "subseed": -1, + "subseed_strength": 0.0, + "seed_resize_from_h": -1, + "seed_resize_from_w": -1, + "sampler_name": "DPM++ 2M Karras", + "batch_size": 1, + "n_iter": 1, + "steps": 20, + "cfg_scale": 7.0, + "width": 512, + "height": 512, + "restore_faces": None, + "tiling": None, + "do_not_save_samples": False, + "do_not_save_grid": False, + "eta": None, + "denoising_strength": None, + "s_min_uncond": 0.0, + "s_churn": 0.0, + "s_tmax": "Infinity", + "s_tmin": 0.0, + "s_noise": 1.0, + "override_settings": {}, + "override_settings_restore_afterwards": True, + "refiner_checkpoint": None, + "refiner_switch_at": None, + "disable_extra_networks": False, + "comments": {}, + "enable_hr": False, + "firstphase_width": 0, + "firstphase_height": 0, + "hr_scale": 2.0, + "hr_upscaler": "Latent", + "hr_second_pass_steps": 0, + "hr_resize_x": 0, + "hr_resize_y": 0, + "hr_checkpoint_name": None, + "hr_sampler_name": None, + "hr_prompt": "", + "hr_negative_prompt": "", + "sampler_index": "DPM++ 2M Karras", + "script_name": None, + "script_args": [], + "send_images": True, + "save_images": False, + "alwayson_scripts": { + "refiner": { + "args": [False, "", 0.8] + }, + "seed": { + "args": [-1, False, -1, 0, 0, 0] + }, + "controlnet": { + "args": [ + { + "enabled": False, + "module": "none", + "model": "None", + "weight": 1, + "image": None, + "resize_mode": "Crop and Resize", + "low_vram": False, + "processor_res": -1, + "threshold_a": -1, + "threshold_b": -1, + "guidance_start": 0, + "guidance_end": 1, + "pixel_perfect": False, + "control_mode": "Balanced", + "is_ui": True, + "input_mode": "simple", + "batch_images": "", + "output_dir": "", + "loopback": False + }, + { + "enabled": False, + "module": "none", + "model": "None", + "weight": 1, + "image": None, + "resize_mode": "Crop and Resize", + "low_vram": False, + "processor_res": -1, + "threshold_a": -1, + "threshold_b": -1, + "guidance_start": 0, + "guidance_end": 1, + "pixel_perfect": False, + "control_mode": "Balanced", + "is_ui": True, + "input_mode": "simple", + "batch_images": "", + "output_dir": "", + "loopback": False + }, + { + "enabled": False, + "module": "none", + "model": "None", + "weight": 1, + "image": None, + "resize_mode": "Crop and Resize", + "low_vram": False, + "processor_res": -1, + "threshold_a": -1, + "threshold_b": -1, + "guidance_start": 0, + "guidance_end": 1, + "pixel_perfect": False, + "control_mode": "Balanced", + "is_ui": True, + "input_mode": "simple", + "batch_images": "", + "output_dir": "", + "loopback": False + } + ] + }, + "segment anything": { + "args": [ + False, + False, + 0, + None, + [], + 0, + False, + [], + [], + False, + 0, + 1, + False, + False, + 0, + None, + [], + -2, + False, + [], + False, + 0, + None, + None + ] + }, + "extra options": { + "args": [] + } + } + } -def streamlit(): - # Streamlit layout - st.title("Image Generation Application") + json_string = json.dumps(api_params) + response = requests.put(s3_url, data=json_string) + response.raise_for_status() + return response - # User input - prompt = st.text_input("Enter a prompt for the image:", "A cute dog") - # Button to start 
the image generation process - if st.button('Generate Image'): +def generate_llm_image(initial_prompt: str, llm_prompt: bool = True): + negative = "" + models = default_models + if llm_prompt is True: + st.write("Wait for LLM to process the prompt...") positive_prompt, negative_prompt, model_list = get_llm_processed_prompts(prompt) - # Assuming the first model is chosen for simplicity - # chosen_model = model_list.split('\n')[0] - - # Generate the detailed prompt - response = generate_image(positive_prompt, negative_prompt, model_list) - - # Display image (placeholder for actual image retrieval logic) - st.image("https://picsum.photos/200", caption=positive_prompt) - - if response.status_code == 200: - inference_id = response.json()['inference_id'] - # Check the status periodically - with st.empty(): - while True: - status_response = check_image_status(inference_id) - if status_response['status'] == 'succeeded': - image_url = get_image_url(inference_id)['url'] - st.image(image_url) - break - elif status_response['status'] == 'failed': - st.error("Image generation failed.") - break - else: - st.text("Waiting for the image to be generated...") - time.sleep(5) # Sleep for a while before checking the status again - else: - st.error("Failed to start the image generation process.") + + # if prompt is empty, use default + if positive_prompt != "": + initial_prompt = positive_prompt + st.write("positive_prompt:") + st.info("{}".format(positive_prompt)) + + if negative_prompt != "": + negative = negative_prompt + st.write("negative_prompt:") + st.info("{}".format(negative_prompt)) + + if len(model_list) > 0 and model_list[0] != "": + models = model_list + st.write("model_list:") + st.info("{}".format(model_list)) + + inference_id = generate_image(initial_prompt, negative, models) + + return inference_id + # main entry point for debugging +# python -m streamlit run image_generation.py --server.port 8088 if __name__ == "__main__": - # deploy_sagemaker_endpoint() - # upload_model() - positive_prompt, negative_prompt, model_list = get_llm_processed_prompts("a cute dog") - # The endpoint fixed for now, since the deploy_sagemaker_endpoint() won't return the endpoint name - response = generate_image("default-endpoint-for-llm-bot", positive_prompt, negative_prompt, model_list) - logger.info("generate image response: {}".format(response)) - - # python -m streamlit run image-generation.py --server.port 8088 - # streamlit() + try: + # Streamlit layout + st.title("Image Generation Application") + + # User input + prompt = st.text_input("Enter a prompt for the image:", "A cute dog") + + button_disabled = False + if not button_disabled: + if st.button('Generate Image'): + button_disabled = True + st.empty() + + st.subheader("Image without LLM") + generate_llm_image(prompt, False) + + st.subheader("Image with LLM") + generate_llm_image(prompt) + except Exception as e: + logger.exception(e) + st.error("Failed to start the image generation process.") + raise e diff --git a/source/panel/requirements.txt b/source/panel/requirements.txt index 04af6ac9..bd5208b4 100644 --- a/source/panel/requirements.txt +++ b/source/panel/requirements.txt @@ -1,6 +1,7 @@ -python-dotenv +python-dotenv PyPDF2 -streamlit -langchain +streamlit +langchain openai -tiktoken \ No newline at end of file +tiktoken +boto3 From c83dd09da870a2cdd623ddd292d5d627a38443cb Mon Sep 17 00:00:00 2001 From: Jingyi Date: Fri, 17 Nov 2023 20:46:34 +0800 Subject: [PATCH 03/16] improved image generation logic --- source/panel/image_generation.py | 2 +- 1 file 
changed, 1 insertion(+), 1 deletion(-) diff --git a/source/panel/image_generation.py b/source/panel/image_generation.py index 884066f5..4f96529b 100644 --- a/source/panel/image_generation.py +++ b/source/panel/image_generation.py @@ -43,7 +43,7 @@ default_models = ["v1-5-pruned-emaonly.safetensors"] - +# todo will update api def deploy_sagemaker_endpoint(instance_type: str = "ml.g4dn.4xlarge", initial_instance_count: int = 1, endpoint_name: str = "default-endpoint-for-llm-bot"): headers = { From 74873b97baa4a68efd59173b031721e3b2d81e75 Mon Sep 17 00:00:00 2001 From: Jingyi Date: Fri, 17 Nov 2023 20:50:10 +0800 Subject: [PATCH 04/16] reduce diff --- source/panel/image_generation.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/source/panel/image_generation.py b/source/panel/image_generation.py index 4f96529b..4e4ced2b 100644 --- a/source/panel/image_generation.py +++ b/source/panel/image_generation.py @@ -57,21 +57,19 @@ def deploy_sagemaker_endpoint(instance_type: str = "ml.g4dn.4xlarge", initial_in "endpoint_name": endpoint_name } # https://.execute-api..amazonaws.com/{basePath}/inference/deploy-sagemaker-endpoint - res = requests.post(COMMAND_API_URL + 'inference/deploy-sagemaker-endpoint', headers=headers, json=inputBody) + res = requests.post(COMMAND_API_URL + 'inference/deploy-sagemaker-endpoint', headers = headers, json = inputBody) logger.info("deploy_sagemaker_endpoint: {}".format(res.json())) - def upload_model(): pass - def get_bedrock_llm(): # specify the profile_name to call the bedrock api if needed bedrock_client = boto3.client('bedrock-runtime', region_name=BEDROCK_REGION) - model_id = "anthropic.claude-v2" + modelId = "anthropic.claude-v2" cl_llm = Bedrock( - model_id=model_id, + model_id=modelId, client=bedrock_client, model_kwargs={"max_tokens_to_sample": 1000}, ) From 3b45608eb3092df70b33e60763520d85bcbb3feb Mon Sep 17 00:00:00 2001 From: Jingyi Date: Fri, 17 Nov 2023 20:54:05 +0800 Subject: [PATCH 05/16] recover template --- source/panel/image_generation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/panel/image_generation.py b/source/panel/image_generation.py index 4e4ced2b..746a6a9a 100644 --- a/source/panel/image_generation.py +++ b/source/panel/image_generation.py @@ -83,7 +83,7 @@ def get_bedrock_llm(): - Recommend a list of suitable models from the stable diffusion lineup that best match the style and content described in the detailed prompt. - Other notes please refer to the following example: - The output should be a plain text in Python List format shown follows, no extra content added beside Positive Prompt, Negative Prompt and Recommended Model List. The model list can only be chosen from the fixed list: ["sd_xl_base_1.0.safetensors", "majicmixRealistic_v7.safetensors", "x2AnimeFinal_gzku.safetensors"]: + The output should be a plain text in Python List format shown follows, no extra content added beside Positive Prompt, Negative Prompt and Recommended Model List. 
The model list can only be chosen from the fixed list: "sd_xl_base_1.0.safetensors", "majicmixRealistic_v7.safetensors", "x2AnimeFinal_gzku.safetensors": [Positive Prompt: , Negative Prompt: , From f571248cc4e995060c659a147b6d9923267e053c Mon Sep 17 00:00:00 2001 From: Jingyi Date: Fri, 17 Nov 2023 20:54:05 +0800 Subject: [PATCH 06/16] recover template --- source/panel/image_generation.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/source/panel/image_generation.py b/source/panel/image_generation.py index 4e4ced2b..dcb618d9 100644 --- a/source/panel/image_generation.py +++ b/source/panel/image_generation.py @@ -83,7 +83,7 @@ def get_bedrock_llm(): - Recommend a list of suitable models from the stable diffusion lineup that best match the style and content described in the detailed prompt. - Other notes please refer to the following example: - The output should be a plain text in Python List format shown follows, no extra content added beside Positive Prompt, Negative Prompt and Recommended Model List. The model list can only be chosen from the fixed list: ["sd_xl_base_1.0.safetensors", "majicmixRealistic_v7.safetensors", "x2AnimeFinal_gzku.safetensors"]: + The output should be a plain text in Python List format shown follows, no extra content added beside Positive Prompt, Negative Prompt and Recommended Model List. The model list can only be chosen from the fixed list: "sd_xl_base_1.0.safetensors", "majicmixRealistic_v7.safetensors", "x2AnimeFinal_gzku.safetensors": [Positive Prompt: , Negative Prompt: , @@ -116,15 +116,16 @@ def get_llm_processed_prompts(initial_prompt): ) conversation.prompt = sd_prompt - response = conversation.predict(input=initial_prompt) - logger.info("the first invoke: {}".format(response)) - # logger.info("the second invoke: {}".format(conversation.predict(input="change to realist style"))) - """ + Example: [Positive Prompt: visually appealing, high-quality image of a big, large, muscular horse with powerful body, majestic stance, flowing mane, detailed texture, vivid color, striking photography., Negative Prompt: ugly, distorted, inappropriate or NSFW content, Recommended Model List: ["sd_xl_base_1.0.safetensors"]] """ + response = conversation.predict(input=initial_prompt) + logger.info("the first invoke: {}".format(response)) + # logger.info("the second invoke: {}".format(conversation.predict(input="change to realist style"))) + positive_prompt = response.split('Positive Prompt: ')[1].split('Negative Prompt: ')[0].strip() negative_prompt = response.split('Negative Prompt: ')[1].split('Recommended Model List: ')[0].strip() model_list = response.split('Recommended Model List: ')[1].strip().replace('[', '').replace(']', '').replace('"', '').split(',') From d219dc81e12f9546b6a15c9fbb7e7a27db05af73 Mon Sep 17 00:00:00 2001 From: Jingyi Date: Fri, 17 Nov 2023 21:01:25 +0800 Subject: [PATCH 07/16] improved support list tip --- source/panel/image_generation.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/source/panel/image_generation.py b/source/panel/image_generation.py index dcb618d9..69a160d9 100644 --- a/source/panel/image_generation.py +++ b/source/panel/image_generation.py @@ -215,7 +215,8 @@ def create_inference_job(models: List[str]): if not set(models).issubset(set(support_model_list)): models = default_models - st.warning("use default model {} because LLM recommend not in support list".format(models)) + st.warning( + "use default model {}\nbecause LLM recommend not in support list:\n{}".format(models, 
support_model_list)) headers = { "Content-Type": "application/json", From a37c0171a5494ab637a9336b2ff7eb9bad4d952c Mon Sep 17 00:00:00 2001 From: Jingyi Date: Fri, 17 Nov 2023 21:03:31 +0800 Subject: [PATCH 08/16] resort wait tip --- source/panel/image_generation.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/source/panel/image_generation.py b/source/panel/image_generation.py index 69a160d9..d4eb68a2 100644 --- a/source/panel/image_generation.py +++ b/source/panel/image_generation.py @@ -116,12 +116,16 @@ def get_llm_processed_prompts(initial_prompt): ) conversation.prompt = sd_prompt + """ Example: [Positive Prompt: visually appealing, high-quality image of a big, large, muscular horse with powerful body, majestic stance, flowing mane, detailed texture, vivid color, striking photography., Negative Prompt: ugly, distorted, inappropriate or NSFW content, Recommended Model List: ["sd_xl_base_1.0.safetensors"]] """ + + st.write("Wait for LLM to process the prompt...") + response = conversation.predict(input=initial_prompt) logger.info("the first invoke: {}".format(response)) # logger.info("the second invoke: {}".format(conversation.predict(input="change to realist style"))) @@ -428,7 +432,7 @@ def generate_llm_image(initial_prompt: str, llm_prompt: bool = True): negative = "" models = default_models if llm_prompt is True: - st.write("Wait for LLM to process the prompt...") + positive_prompt, negative_prompt, model_list = get_llm_processed_prompts(prompt) # if prompt is empty, use default From 655cefbe3f2c97a342768705c441898db18ba312 Mon Sep 17 00:00:00 2001 From: Jingyi Date: Fri, 17 Nov 2023 21:05:30 +0800 Subject: [PATCH 09/16] remove example because ui will display --- source/panel/image_generation.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/source/panel/image_generation.py b/source/panel/image_generation.py index d4eb68a2..661ebbe1 100644 --- a/source/panel/image_generation.py +++ b/source/panel/image_generation.py @@ -117,13 +117,6 @@ def get_llm_processed_prompts(initial_prompt): conversation.prompt = sd_prompt - """ - Example: - [Positive Prompt: visually appealing, high-quality image of a big, large, muscular horse with powerful body, majestic stance, flowing mane, detailed texture, vivid color, striking photography., - Negative Prompt: ugly, distorted, inappropriate or NSFW content, - Recommended Model List: ["sd_xl_base_1.0.safetensors"]] - """ - st.write("Wait for LLM to process the prompt...") response = conversation.predict(input=initial_prompt) @@ -220,7 +213,7 @@ def create_inference_job(models: List[str]): if not set(models).issubset(set(support_model_list)): models = default_models st.warning( - "use default model {}\nbecause LLM recommend not in support list:\n{}".format(models, support_model_list)) + "Use default model {}\nbecause LLM recommend not in support list:\n{}".format(models, support_model_list)) headers = { "Content-Type": "application/json", From 3014f9f82584e7581c9a01abf7a0ba3b2b6296a0 Mon Sep 17 00:00:00 2001 From: Jingyi Date: Fri, 17 Nov 2023 23:09:54 +0800 Subject: [PATCH 10/16] improved select checkpoints --- source/panel/image_generation.py | 82 +++++++++++++++++++++++--------- 1 file changed, 59 insertions(+), 23 deletions(-) diff --git a/source/panel/image_generation.py b/source/panel/image_generation.py index 661ebbe1..f9467d9e 100644 --- a/source/panel/image_generation.py +++ b/source/panel/image_generation.py @@ -5,42 +5,40 @@ from datetime import datetime from typing import List -from 
langchain.prompts import PromptTemplate -from langchain.llms.bedrock import Bedrock -from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple -from langchain.docstore.document import Document -from langchain.chains import ConversationChain -from langchain.vectorstores import FAISS -from langchain.memory import ConversationBufferMemory - import boto3 import requests import streamlit as st from dotenv import load_dotenv from langchain import PromptTemplate - -# load .env file with specific name -load_dotenv(dotenv_path='.env_sd') +from langchain.chains import ConversationChain +from langchain.llms.bedrock import Bedrock +from langchain.memory import ConversationBufferMemory logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -# load the URL from .env file -# 'https://xxxx.execute-api.us-west-2.amazonaws.com/prod/' +# load .env file with specific name +load_dotenv(dotenv_path='.env_sd') + +# Your ApiGatewayUrl in Extension for Stable Diffusion +# Example: https://xxxx.execute-api.us-west-2.amazonaws.com/prod/ COMMAND_API_URL = os.getenv("COMMON_API_URL") +# Your ApiGatewayUrlToken in Extension for Stable Diffusion API_KEY = os.getenv("API_KEY") +# Your username in Extension for Stable Diffusion +# Some resources are limited to specific users API_USERNAME = os.getenv("API_USERNAME") # The service support varies in different regions BEDROCK_REGION = os.getenv("BEDROCK_REGION") +# API URL GENERATE_API_URL = COMMAND_API_URL + "inference/v2" STATUS_API_URL = COMMAND_API_URL + "inference/get-inference-job" PARAM_API_URL = COMMAND_API_URL + "inference/get-inference-job-param-output" IMAGE_API_URL = COMMAND_API_URL + "inference/get-inference-job-image-output" +CHECKPOINTS_API_URL = COMMAND_API_URL + "checkpoints" -support_model_list = ["sd_xl_base_1.0.safetensors", "majicmixRealistic_v7.safetensors", - "x2AnimeFinal_gzku.safetensors", "v1-5-pruned-emaonly.safetensors"] - +support_model_list = [] default_models = ["v1-5-pruned-emaonly.safetensors"] # todo will update api @@ -75,6 +73,7 @@ def get_bedrock_llm(): ) return cl_llm +# todo template use dynamic checkpoints sd_prompt = PromptTemplate.from_template( """ Human: @@ -206,14 +205,37 @@ def get_inference_image_output(inference_id: str): return job.json() +def get_checkpoints(): + headers = { + "Content-Type": "application/json", + "Accept": "application/json", + 'x-api-key': API_KEY + } + + params = { + "username": API_USERNAME, + "status": "Active", + } + + job = requests.get(CHECKPOINTS_API_URL, headers=headers, params=params) + + checkpoints = [] + if 'checkpoints' in job.json(): + for checkpoint in job.json()['checkpoints']: + checkpoints.append(checkpoint['name'][0]) + + if len(checkpoints) == 0: + raise Exception("No checkpoint available.") + + global support_model_list + support_model_list = checkpoints + logger.info("support_model_list: {}".format(support_model_list)) + return support_model_list + + def create_inference_job(models: List[str]): - if len(models) == 0: - models = default_models - if not set(models).issubset(set(support_model_list)): - models = default_models - st.warning( - "Use default model {}\nbecause LLM recommend not in support list:\n{}".format(models, support_model_list)) + models = select_checkpoint(models) headers = { "Content-Type": "application/json", @@ -422,6 +444,7 @@ def upload_inference_job_api_params(s3_url, positive: str, negative: str): def generate_llm_image(initial_prompt: str, llm_prompt: bool = True): + global support_model_list negative = "" models = 
default_models if llm_prompt is True: @@ -449,6 +472,18 @@ def generate_llm_image(initial_prompt: str, llm_prompt: bool = True): return inference_id +def select_checkpoint(user_list: List[str]): + global support_model_list + user_list = [item.strip() for item in user_list] + intersection = list(set(user_list).intersection(set(support_model_list))) + if len(intersection) == 0: + intersection = default_models + st.warning( + "Use default model {}\nwhen LLM recommend not in support list:\n{}".format(intersection, + support_model_list)) + return intersection + + # main entry point for debugging # python -m streamlit run image_generation.py --server.port 8088 if __name__ == "__main__": @@ -462,6 +497,8 @@ def generate_llm_image(initial_prompt: str, llm_prompt: bool = True): button_disabled = False if not button_disabled: if st.button('Generate Image'): + get_checkpoints() + button_disabled = True st.empty() @@ -472,5 +509,4 @@ def generate_llm_image(initial_prompt: str, llm_prompt: bool = True): generate_llm_image(prompt) except Exception as e: logger.exception(e) - st.error("Failed to start the image generation process.") raise e From 7c523d596b4e28d3b6e97264c592bd3307503901 Mon Sep 17 00:00:00 2001 From: Jingyi Date: Sat, 18 Nov 2023 01:02:32 +0800 Subject: [PATCH 11/16] improved UI --- source/panel/image_generation.py | 106 +++++++++++++++++-------------- 1 file changed, 60 insertions(+), 46 deletions(-) diff --git a/source/panel/image_generation.py b/source/panel/image_generation.py index f9467d9e..c0a6d68f 100644 --- a/source/panel/image_generation.py +++ b/source/panel/image_generation.py @@ -116,8 +116,6 @@ def get_llm_processed_prompts(initial_prompt): conversation.prompt = sd_prompt - st.write("Wait for LLM to process the prompt...") - response = conversation.predict(input=initial_prompt) logger.info("the first invoke: {}".format(response)) # logger.info("the second invoke: {}".format(conversation.predict(input="change to realist style"))) @@ -129,43 +127,57 @@ def get_llm_processed_prompts(initial_prompt): return positive_prompt, negative_prompt, model_list -def generate_image(positive_prompts: str, negative_prompts: str, model: List[str]): - st.write("Generate Image Process:") - - # set progress bar for user experience - progess = 5 - bar = st.progress(progess) +def generate_image(positive_prompts: str, negative_prompts: str, model: List[str], current_col, progress_bar): job = create_inference_job(model) - progess += 5 - bar.progress(progess) + st.session_state.progress += 5 + progress_bar.progress(st.session_state.progress) inference = job["inference"] upload_inference_job_api_params(inference["api_params_s3_upload_url"], positive_prompts, negative_prompts) - progess += 5 - bar.progress(progess) + st.session_state.progress += 5 + progress_bar.progress(st.session_state.progress) run_inference_job(inference["id"]) - progess += 5 - bar.progress(progess) + st.session_state.progress += 5 + progress_bar.progress(st.session_state.progress) while True: status_response = get_inference_job(inference["id"]) - if progess < 90: - progess += 10 - bar.progress(progess) + if st.session_state.progress < 90: + st.session_state.progress += 10 + progress_bar.progress(st.session_state.progress) if status_response['status'] == 'succeed': + progress_bar.progress(100) image_url = get_inference_image_output(inference["id"])[0] - st.image(image_url, caption=positive_prompts, use_column_width=True) + current_col.image(image_url, caption=positive_prompts, use_column_width=True) break elif 
status_response['status'] == 'failed': - st.error("Image generation failed.") + current_col.error("Image generation failed.") break else: time.sleep(1) - bar.progress(100) + for item in st.session_state.warning: + current_col.warning(item) + + api_params = get_inference_param_output(inference["id"]) + params = requests.get(api_params[0]).json() + info = json.loads(params['info']) + + if info["prompt"] != "": + current_col.write("prompt:") + current_col.info(info["prompt"]) + + if info["negative_prompt"] != "": + current_col.write("negative_prompt:") + current_col.info(info["negative_prompt"]) + + if info["sd_model_name"] != "": + current_col.write("sd_model_name:") + current_col.info(info["sd_model_name"]) + return inference["id"] @@ -258,10 +270,6 @@ def create_inference_job(models: List[str]): } job = requests.post(GENERATE_API_URL, headers=headers, json=body) - user_models = [] - for model in job.json()['inference']['models']: - user_models.append(model['name'][0]) - st.write("use models: {}".format(user_models)) return job.json() @@ -443,31 +451,33 @@ def upload_inference_job_api_params(s3_url, positive: str, negative: str): return response -def generate_llm_image(initial_prompt: str, llm_prompt: bool = True): +def generate_llm_image(initial_prompt: str, llm_prompt: bool, col, title: str): + col.empty() + col.subheader(title) + st.spinner() + st.session_state.progress = 5 + progress_bar = col.progress(st.session_state.progress) + global support_model_list negative = "" models = default_models if llm_prompt is True: positive_prompt, negative_prompt, model_list = get_llm_processed_prompts(prompt) + st.session_state.progress += 15 + progress_bar.progress(st.session_state.progress) # if prompt is empty, use default if positive_prompt != "": initial_prompt = positive_prompt - st.write("positive_prompt:") - st.info("{}".format(positive_prompt)) if negative_prompt != "": negative = negative_prompt - st.write("negative_prompt:") - st.info("{}".format(negative_prompt)) if len(model_list) > 0 and model_list[0] != "": models = model_list - st.write("model_list:") - st.info("{}".format(model_list)) - inference_id = generate_image(initial_prompt, negative, models) + inference_id = generate_image(initial_prompt, negative, models, col, progress_bar) return inference_id @@ -478,9 +488,9 @@ def select_checkpoint(user_list: List[str]): intersection = list(set(user_list).intersection(set(support_model_list))) if len(intersection) == 0: intersection = default_models - st.warning( - "Use default model {}\nwhen LLM recommend not in support list:\n{}".format(intersection, - support_model_list)) + st.session_state.warning.append("Use default model {}\nwhen LLM recommend not in support list:\n{}".format( + intersection, support_model_list)) + return intersection @@ -488,25 +498,29 @@ def select_checkpoint(user_list: List[str]): # python -m streamlit run image_generation.py --server.port 8088 if __name__ == "__main__": try: - # Streamlit layout + st.set_page_config(layout="wide", page_title="Image Generation Application") + st.title("Image Generation Application") # User input prompt = st.text_input("Enter a prompt for the image:", "A cute dog") - button_disabled = False - if not button_disabled: - if st.button('Generate Image'): - get_checkpoints() + button = st.button('Generate Image') + + col1, col2 = st.columns(2) - button_disabled = True - st.empty() + # col1.subheader("Without LLM") + # col1.image(Image.open("./zebra.jpg")) + # + # col2.subheader("With LLM") + # col2.image(Image.open("./zebra.jpg")) - 
st.subheader("Image without LLM") - generate_llm_image(prompt, False) + if button: + get_checkpoints() + st.session_state.warning = [] + generate_llm_image(prompt, False, col1, "Without LLM") + generate_llm_image(prompt, True, col2, "With LLM") - st.subheader("Image with LLM") - generate_llm_image(prompt) except Exception as e: logger.exception(e) raise e From 0816ccc0a9d7989fc7db437e073de39769658b1d Mon Sep 17 00:00:00 2001 From: Jingyi Date: Sat, 18 Nov 2023 01:02:32 +0800 Subject: [PATCH 12/16] improved UI --- source/panel/image_generation.py | 106 +++++++++++++++++-------------- 1 file changed, 60 insertions(+), 46 deletions(-) diff --git a/source/panel/image_generation.py b/source/panel/image_generation.py index f9467d9e..8d91ec6c 100644 --- a/source/panel/image_generation.py +++ b/source/panel/image_generation.py @@ -116,8 +116,6 @@ def get_llm_processed_prompts(initial_prompt): conversation.prompt = sd_prompt - st.write("Wait for LLM to process the prompt...") - response = conversation.predict(input=initial_prompt) logger.info("the first invoke: {}".format(response)) # logger.info("the second invoke: {}".format(conversation.predict(input="change to realist style"))) @@ -129,43 +127,57 @@ def get_llm_processed_prompts(initial_prompt): return positive_prompt, negative_prompt, model_list -def generate_image(positive_prompts: str, negative_prompts: str, model: List[str]): - st.write("Generate Image Process:") - - # set progress bar for user experience - progess = 5 - bar = st.progress(progess) +def generate_image(positive_prompts: str, negative_prompts: str, model: List[str], current_col, progress_bar): job = create_inference_job(model) - progess += 5 - bar.progress(progess) + st.session_state.progress += 5 + progress_bar.progress(st.session_state.progress) inference = job["inference"] upload_inference_job_api_params(inference["api_params_s3_upload_url"], positive_prompts, negative_prompts) - progess += 5 - bar.progress(progess) + st.session_state.progress += 5 + progress_bar.progress(st.session_state.progress) run_inference_job(inference["id"]) - progess += 5 - bar.progress(progess) + st.session_state.progress += 5 + progress_bar.progress(st.session_state.progress) while True: status_response = get_inference_job(inference["id"]) - if progess < 90: - progess += 10 - bar.progress(progess) + if st.session_state.progress < 80: + st.session_state.progress += 10 + progress_bar.progress(st.session_state.progress) if status_response['status'] == 'succeed': + progress_bar.progress(100) image_url = get_inference_image_output(inference["id"])[0] - st.image(image_url, caption=positive_prompts, use_column_width=True) + current_col.image(image_url, caption=positive_prompts, use_column_width=True) break elif status_response['status'] == 'failed': - st.error("Image generation failed.") + current_col.error("Image generation failed.") break else: time.sleep(1) - bar.progress(100) + for item in st.session_state.warning: + current_col.warning(item) + + api_params = get_inference_param_output(inference["id"]) + params = requests.get(api_params[0]).json() + info = json.loads(params['info']) + + if info["prompt"] != "": + current_col.write("prompt:") + current_col.info(info["prompt"]) + + if info["negative_prompt"] != "": + current_col.write("negative_prompt:") + current_col.info(info["negative_prompt"]) + + if info["sd_model_name"] != "": + current_col.write("sd_model_name:") + current_col.info(info["sd_model_name"]) + return inference["id"] @@ -258,10 +270,6 @@ def create_inference_job(models: List[str]): 
} job = requests.post(GENERATE_API_URL, headers=headers, json=body) - user_models = [] - for model in job.json()['inference']['models']: - user_models.append(model['name'][0]) - st.write("use models: {}".format(user_models)) return job.json() @@ -443,31 +451,33 @@ def upload_inference_job_api_params(s3_url, positive: str, negative: str): return response -def generate_llm_image(initial_prompt: str, llm_prompt: bool = True): +def generate_llm_image(initial_prompt: str, llm_prompt: bool, col, title: str): + col.empty() + col.subheader(title) + st.spinner() + st.session_state.progress = 5 + progress_bar = col.progress(st.session_state.progress) + global support_model_list negative = "" models = default_models if llm_prompt is True: positive_prompt, negative_prompt, model_list = get_llm_processed_prompts(prompt) + st.session_state.progress += 15 + progress_bar.progress(st.session_state.progress) # if prompt is empty, use default if positive_prompt != "": initial_prompt = positive_prompt - st.write("positive_prompt:") - st.info("{}".format(positive_prompt)) if negative_prompt != "": negative = negative_prompt - st.write("negative_prompt:") - st.info("{}".format(negative_prompt)) if len(model_list) > 0 and model_list[0] != "": models = model_list - st.write("model_list:") - st.info("{}".format(model_list)) - inference_id = generate_image(initial_prompt, negative, models) + inference_id = generate_image(initial_prompt, negative, models, col, progress_bar) return inference_id @@ -478,9 +488,9 @@ def select_checkpoint(user_list: List[str]): intersection = list(set(user_list).intersection(set(support_model_list))) if len(intersection) == 0: intersection = default_models - st.warning( - "Use default model {}\nwhen LLM recommend not in support list:\n{}".format(intersection, - support_model_list)) + st.session_state.warning.append("Use default model {}\nwhen LLM recommends {} not in support list:\n{}".format( + default_models, user_list, support_model_list)) + return intersection @@ -488,25 +498,29 @@ def select_checkpoint(user_list: List[str]): # python -m streamlit run image_generation.py --server.port 8088 if __name__ == "__main__": try: - # Streamlit layout + st.set_page_config(layout="wide", page_title="Image Generation Application") + st.title("Image Generation Application") # User input prompt = st.text_input("Enter a prompt for the image:", "A cute dog") - button_disabled = False - if not button_disabled: - if st.button('Generate Image'): - get_checkpoints() + button = st.button('Generate Image') + + col1, col2 = st.columns(2) - button_disabled = True - st.empty() + # col1.subheader("Without LLM") + # col1.image(Image.open("./zebra.jpg")) + # + # col2.subheader("With LLM") + # col2.image(Image.open("./zebra.jpg")) - st.subheader("Image without LLM") - generate_llm_image(prompt, False) + if button: + get_checkpoints() + st.session_state.warning = [] + generate_llm_image(prompt, False, col1, "Without LLM") + generate_llm_image(prompt, True, col2, "With LLM") - st.subheader("Image with LLM") - generate_llm_image(prompt) except Exception as e: logger.exception(e) raise e From cd81ac70b1be22a75e30e6256bd46123e6d5d8e6 Mon Sep 17 00:00:00 2001 From: Jingyi Date: Sat, 18 Nov 2023 01:10:04 +0800 Subject: [PATCH 13/16] improved image layout --- source/panel/image_generation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/panel/image_generation.py b/source/panel/image_generation.py index 8d91ec6c..e6801ff0 100644 --- a/source/panel/image_generation.py +++ 
b/source/panel/image_generation.py @@ -151,7 +151,7 @@ def generate_image(positive_prompts: str, negative_prompts: str, model: List[str if status_response['status'] == 'succeed': progress_bar.progress(100) image_url = get_inference_image_output(inference["id"])[0] - current_col.image(image_url, caption=positive_prompts, use_column_width=True) + current_col.image(image_url, use_column_width=True) break elif status_response['status'] == 'failed': current_col.error("Image generation failed.") From c349880e73c24b628857d9a1320f50a1cca900e8 Mon Sep 17 00:00:00 2001 From: Jingyi Date: Sat, 18 Nov 2023 10:20:38 +0800 Subject: [PATCH 14/16] improved warnings --- source/panel/image_generation.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/source/panel/image_generation.py b/source/panel/image_generation.py index e6801ff0..13381099 100644 --- a/source/panel/image_generation.py +++ b/source/panel/image_generation.py @@ -159,9 +159,6 @@ def generate_image(positive_prompts: str, negative_prompts: str, model: List[str else: time.sleep(1) - for item in st.session_state.warning: - current_col.warning(item) - api_params = get_inference_param_output(inference["id"]) params = requests.get(api_params[0]).json() info = json.loads(params['info']) @@ -178,6 +175,9 @@ def generate_image(positive_prompts: str, negative_prompts: str, model: List[str current_col.write("sd_model_name:") current_col.info(info["sd_model_name"]) + for warning in st.session_state.warnings: + current_col.warning(warning) + return inference["id"] @@ -488,7 +488,7 @@ def select_checkpoint(user_list: List[str]): intersection = list(set(user_list).intersection(set(support_model_list))) if len(intersection) == 0: intersection = default_models - st.session_state.warning.append("Use default model {}\nwhen LLM recommends {} not in support list:\n{}".format( + st.session_state.warnings.append("Use default model {}\nwhen LLM recommends {} not in support list:\n{}".format( default_models, user_list, support_model_list)) return intersection @@ -498,7 +498,7 @@ def select_checkpoint(user_list: List[str]): # python -m streamlit run image_generation.py --server.port 8088 if __name__ == "__main__": try: - st.set_page_config(layout="wide", page_title="Image Generation Application") + st.set_page_config(page_title="Image Generation Application") st.title("Image Generation Application") @@ -509,15 +509,9 @@ def select_checkpoint(user_list: List[str]): col1, col2 = st.columns(2) - # col1.subheader("Without LLM") - # col1.image(Image.open("./zebra.jpg")) - # - # col2.subheader("With LLM") - # col2.image(Image.open("./zebra.jpg")) - if button: get_checkpoints() - st.session_state.warning = [] + st.session_state.warnings = [] generate_llm_image(prompt, False, col1, "Without LLM") generate_llm_image(prompt, True, col2, "With LLM") From 0584159f6c8a6a1a7bc13e51b6b8189fbdac0c2d Mon Sep 17 00:00:00 2001 From: Jingyi Date: Sat, 18 Nov 2023 10:28:52 +0800 Subject: [PATCH 15/16] fixed typos --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 0aa299a9..e6c4b7d6 100644 --- a/README.md +++ b/README.md @@ -46,7 +46,7 @@ arn:aws:cloudformation:us-east-1:xx:stack/llm-bot-dev/xx Use Postman/cURL to test the API connection, the API endpoint is the output of CloudFormation Stack with prefix 'embedding' or 'llm', the sample URL will be like "https://xxxx.execute-api.us-east-1.amazonaws.com/v1/embedding", the API request body is as follows: -**Offline process to pre-process file 
specificed in S3 bucket and prefix, POST https://xxxx.execute-api.us-east-1.amazonaws.com/v1/etl** +**Offline process to pre-process file specified in S3 bucket and prefix, POST https://xxxx.execute-api.us-east-1.amazonaws.com/v1/etl** ```bash BODY { @@ -134,7 +134,7 @@ You should see output like this: } ``` -**Delete intial index in AOS, POST https://xxxx.execute-api.us-east-1.amazonaws.com/v1/embedding for debugging purpose** +**Delete initial index in AOS, POST https://xxxx.execute-api.us-east-1.amazonaws.com/v1/embedding for debugging purpose** ```bash { "aos_index": "chatbot-index", @@ -143,7 +143,7 @@ You should see output like this: } ``` -**Create intial index in AOS, POST https://xxxx.execute-api.us-east-1.amazonaws.com/v1/embedding for debugging purpose** +**Create initial index in AOS, POST https://xxxx.execute-api.us-east-1.amazonaws.com/v1/embedding for debugging purpose** ```bash { "aos_index": "chatbot-index", From 2622161f00554b976f51b472cbaf7bedf56bb4d1 Mon Sep 17 00:00:00 2001 From: Jingyi Date: Sun, 19 Nov 2023 16:02:24 +0800 Subject: [PATCH 16/16] improved warnings --- source/panel/image_generation.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/source/panel/image_generation.py b/source/panel/image_generation.py index 13381099..4242984f 100644 --- a/source/panel/image_generation.py +++ b/source/panel/image_generation.py @@ -41,6 +41,7 @@ support_model_list = [] default_models = ["v1-5-pruned-emaonly.safetensors"] + # todo will update api def deploy_sagemaker_endpoint(instance_type: str = "ml.g4dn.4xlarge", initial_instance_count: int = 1, endpoint_name: str = "default-endpoint-for-llm-bot"): @@ -55,12 +56,14 @@ def deploy_sagemaker_endpoint(instance_type: str = "ml.g4dn.4xlarge", initial_in "endpoint_name": endpoint_name } # https://.execute-api..amazonaws.com/{basePath}/inference/deploy-sagemaker-endpoint - res = requests.post(COMMAND_API_URL + 'inference/deploy-sagemaker-endpoint', headers = headers, json = inputBody) + res = requests.post(COMMAND_API_URL + 'inference/deploy-sagemaker-endpoint', headers=headers, json=inputBody) logger.info("deploy_sagemaker_endpoint: {}".format(res.json())) + def upload_model(): pass + def get_bedrock_llm(): # specify the profile_name to call the bedrock api if needed bedrock_client = boto3.client('bedrock-runtime', region_name=BEDROCK_REGION) @@ -73,6 +76,7 @@ def get_bedrock_llm(): ) return cl_llm + # todo template use dynamic checkpoints sd_prompt = PromptTemplate.from_template( """ @@ -107,6 +111,7 @@ def get_bedrock_llm(): Assistant: """) + def get_llm_processed_prompts(initial_prompt): cl_llm = get_bedrock_llm() memory = ConversationBufferMemory() @@ -122,13 +127,15 @@ def get_llm_processed_prompts(initial_prompt): positive_prompt = response.split('Positive Prompt: ')[1].split('Negative Prompt: ')[0].strip() negative_prompt = response.split('Negative Prompt: ')[1].split('Recommended Model List: ')[0].strip() - model_list = response.split('Recommended Model List: ')[1].strip().replace('[', '').replace(']', '').replace('"', '').split(',') - logger.info("positive_prompt: {}\n negative_prompt: {}\n model_list: {}".format(positive_prompt, negative_prompt, model_list)) + model_list = response.split('Recommended Model List: ')[1].strip().replace('[', '').replace(']', '').replace('"', + '').split( + ',') + logger.info("positive_prompt: {}\n negative_prompt: {}\n model_list: {}".format(positive_prompt, negative_prompt, + model_list)) return positive_prompt, negative_prompt, model_list def 
generate_image(positive_prompts: str, negative_prompts: str, model: List[str], current_col, progress_bar): - job = create_inference_job(model) st.session_state.progress += 5 progress_bar.progress(st.session_state.progress) @@ -246,7 +253,6 @@ def get_checkpoints(): def create_inference_job(models: List[str]): - models = select_checkpoint(models) headers = {
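
Taken together, this patch series moves the panel from one sequential flow to a side-by-side comparison backed by the inference API, and a few of its pieces are easier to follow as standalone sketches. First, the checkpoint fallback: select_checkpoint keeps only the LLM-recommended models that are actually in the supported list and falls back to the default checkpoint otherwise (the warning text is corrected in PATCH 12 and the session-state key renamed in PATCH 14). The sketch below mirrors that logic in plain Python; SUPPORTED_MODELS and DEFAULT_MODELS are stand-ins for the module-level support_model_list (populated at runtime by get_checkpoints()) and default_models, and the warning is printed instead of appended to st.session_state.warnings.

```python
from typing import List

# Stand-ins for the panel's module-level lists; support_model_list is really
# filled in at runtime by get_checkpoints().
SUPPORTED_MODELS = ["sd_xl_base_1.0.safetensors", "majicmixRealistic_v7.safetensors",
                    "x2AnimeFinal_gzku.safetensors", "v1-5-pruned-emaonly.safetensors"]
DEFAULT_MODELS = ["v1-5-pruned-emaonly.safetensors"]


def select_checkpoint(recommended: List[str]) -> List[str]:
    """Keep only the recommended checkpoints that are actually deployed;
    fall back to the default checkpoint when nothing matches."""
    cleaned = [item.strip() for item in recommended]
    usable = sorted(set(cleaned) & set(SUPPORTED_MODELS))
    if not usable:
        # In the panel this message goes to st.session_state.warnings and is
        # rendered in the result column; printing keeps the sketch standalone.
        print("Use default model {} when LLM recommends {} not in support list: {}".format(
            DEFAULT_MODELS, cleaned, SUPPORTED_MODELS))
        return DEFAULT_MODELS
    return usable


if __name__ == "__main__":
    print(select_checkpoint([" x2AnimeFinal_gzku.safetensors "]))  # matches support list
    print(select_checkpoint(["unknown.safetensors"]))              # falls back to default
```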
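
Second, the prompt post-processing from PATCH 16: the Claude completion is plain text, so get_llm_processed_prompts slices it on the three labels baked into the prompt template. Below is a minimal sketch of that slicing, separated from any Bedrock call so it can be exercised with a canned reply; the sample string is invented purely for the test.

```python
from typing import List, Tuple


def parse_llm_reply(response: str) -> Tuple[str, str, List[str]]:
    """Split the completion into (positive_prompt, negative_prompt, model_list).

    Assumes the reply follows the template's 'Positive Prompt: / Negative Prompt: /
    Recommended Model List:' layout; a malformed reply raises IndexError, which the
    caller can treat as 'keep the user's original prompt and the default model'."""
    positive = response.split("Positive Prompt: ")[1].split("Negative Prompt: ")[0].strip()
    negative = response.split("Negative Prompt: ")[1].split("Recommended Model List: ")[0].strip()
    raw_list = response.split("Recommended Model List: ")[1].strip()
    models = [m.strip() for m in raw_list.strip("[]").replace('"', "").split(",") if m.strip()]
    return positive, negative, models


if __name__ == "__main__":
    sample = ("Positive Prompt: a cute dog, fluffy fur, golden retriever, soft lighting\n"
              "Negative Prompt: blurry, low quality, extra limbs\n"
              'Recommended Model List: ["majicmixRealistic_v7.safetensors"]')
    print(parse_llm_reply(sample))
```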
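
Third, the job polling inside generate_image: the progress bar keeps advancing in fixed steps but is capped below 100 until get_inference_job reports succeed or failed. The sketch below factors that loop out of Streamlit; get_status and on_progress are injected, so in the panel they would wrap get_inference_job(inference["id"]) and progress_bar.progress, while the __main__ block drives the loop with a fake job.

```python
import time
from typing import Callable, Dict


def wait_for_inference(get_status: Callable[[], Dict], on_progress: Callable[[int], None],
                       poll_interval: float = 1.0, start: int = 20, cap: int = 80) -> Dict:
    """Poll an inference job until it reports 'succeed' or 'failed'.

    get_status  -- returns the latest job record, e.g. {'status': 'succeed', ...}
    on_progress -- receives an integer 0..100; in the panel this is progress_bar.progress
    The bar advances in fixed steps but stays capped until the job actually
    finishes, mirroring the loop in generate_image().
    """
    progress = start
    while True:
        status = get_status()
        if progress < cap:
            progress += 10
            on_progress(progress)
        if status.get("status") == "succeed":
            on_progress(100)
            return status
        if status.get("status") == "failed":
            return status
        time.sleep(poll_interval)


if __name__ == "__main__":
    # Fake job that succeeds on the third poll, for a quick local check.
    polls = iter(["inprogress", "inprogress", "succeed"])
    result = wait_for_inference(lambda: {"status": next(polls)}, print, poll_interval=0.1)
    print(result)
```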
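
Fourth, the per-column details added in PATCH 11 and reordered in PATCH 14: once the image is rendered, the panel fetches the job's parameter JSON from the presigned URL returned by get_inference_param_output and echoes prompt, negative_prompt, and sd_model_name, followed by any warnings collected during the run. A compact helper for that step, assuming current_col is a Streamlit column and param_url is the first presigned URL returned by the API:

```python
import json
from typing import List, Optional

import requests


def show_generation_details(current_col, param_url: str,
                            warnings: Optional[List[str]] = None) -> None:
    """Fetch the inference parameter JSON from its presigned S3 URL and surface
    the fields the panel reports, then replay any warnings for this run."""
    params = requests.get(param_url, timeout=30).json()
    info = json.loads(params["info"])
    for field in ("prompt", "negative_prompt", "sd_model_name"):
        if info.get(field):
            current_col.write(f"{field}:")
            current_col.info(info[field])
    for warning in warnings or []:
        current_col.warning(warning)
```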
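
Finally, the reworked entry point (PATCHES 11 through 14) drives both columns from a single button instead of rendering two sections in sequence. Below is a stripped-down layout sketch, with render_result standing in for generate_llm_image so it runs without the inference backend; PATCH 14 drops layout="wide" from set_page_config, so the sketch omits it as well.

```python
# Run with: python -m streamlit run layout_sketch.py --server.port 8088
import streamlit as st


def render_result(col, title: str, prompt: str) -> None:
    # Placeholder for generate_llm_image(): it only fills the column so the
    # two-column comparison can be exercised without the inference backend.
    col.subheader(title)
    progress_bar = col.progress(5)
    progress_bar.progress(100)   # a real run advances this while polling the job
    col.info(prompt)             # stands in for col.image(image_url, use_column_width=True)


st.set_page_config(page_title="Image Generation Application")
st.title("Image Generation Application")

prompt = st.text_input("Enter a prompt for the image:", "A cute dog")
button = st.button("Generate Image")
col1, col2 = st.columns(2)

if button:
    st.session_state.warnings = []   # warnings collected per run, shown per column
    render_result(col1, "Without LLM", prompt)
    render_result(col2, "With LLM", prompt)  # the real flow passes the Claude-rewritten prompt
```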