diff --git a/source/panel/image_generation.py b/source/panel/image_generation.py
index 13381099..4242984f 100644
--- a/source/panel/image_generation.py
+++ b/source/panel/image_generation.py
@@ -41,6 +41,7 @@
 support_model_list = []
 default_models = ["v1-5-pruned-emaonly.safetensors"]
 
+
 # todo will update api
 def deploy_sagemaker_endpoint(instance_type: str = "ml.g4dn.4xlarge", initial_instance_count: int = 1, endpoint_name: str = "default-endpoint-for-llm-bot"):
@@ -55,12 +56,14 @@ def deploy_sagemaker_endpoint(instance_type: str = "ml.g4dn.4xlarge", initial_in
         "endpoint_name": endpoint_name
     }
     # https://<api_id>.execute-api.<region>.amazonaws.com/{basePath}/inference/deploy-sagemaker-endpoint
-    res = requests.post(COMMAND_API_URL + 'inference/deploy-sagemaker-endpoint', headers = headers, json = inputBody)
+    res = requests.post(COMMAND_API_URL + 'inference/deploy-sagemaker-endpoint', headers=headers, json=inputBody)
     logger.info("deploy_sagemaker_endpoint: {}".format(res.json()))
 
+
 def upload_model():
     pass
 
+
 def get_bedrock_llm():
     # specify the profile_name to call the bedrock api if needed
     bedrock_client = boto3.client('bedrock-runtime', region_name=BEDROCK_REGION)
@@ -73,6 +76,7 @@ def get_bedrock_llm():
     )
     return cl_llm
 
+
 # todo template use dynamic checkpoints
 sd_prompt = PromptTemplate.from_template(
     """
@@ -107,6 +111,7 @@ def get_bedrock_llm():
     Assistant:
     """)
 
+
 def get_llm_processed_prompts(initial_prompt):
     cl_llm = get_bedrock_llm()
     memory = ConversationBufferMemory()
@@ -122,13 +127,15 @@ def get_llm_processed_prompts(initial_prompt):
     positive_prompt = response.split('Positive Prompt: ')[1].split('Negative Prompt: ')[0].strip()
     negative_prompt = response.split('Negative Prompt: ')[1].split('Recommended Model List: ')[0].strip()
-    model_list = response.split('Recommended Model List: ')[1].strip().replace('[', '').replace(']', '').replace('"', '').split(',')
-    logger.info("positive_prompt: {}\n negative_prompt: {}\n model_list: {}".format(positive_prompt, negative_prompt, model_list))
+    model_list = response.split('Recommended Model List: ')[1].strip().replace('[', '').replace(']', '').replace('"',
+                                                                                                                 '').split(
+        ',')
+    logger.info("positive_prompt: {}\n negative_prompt: {}\n model_list: {}".format(positive_prompt, negative_prompt,
+                                                                                    model_list))
     return positive_prompt, negative_prompt, model_list
 
 
 def generate_image(positive_prompts: str, negative_prompts: str, model: List[str], current_col, progress_bar):
-    job = create_inference_job(model)
     st.session_state.progress += 5
     progress_bar.progress(st.session_state.progress)
@@ -246,7 +253,6 @@ def get_checkpoints():
 
 def create_inference_job(models: List[str]):
-    models = select_checkpoint(models)
     headers = {
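
Note on the reflowed lines in get_llm_processed_prompts: the extraction logic assumes the model reply contains the literal markers "Positive Prompt: ", "Negative Prompt: ", and "Recommended Model List: ". A minimal, self-contained sketch of that split-based parsing follows; the sample reply is hypothetical, not output captured from the repo.

    # Illustrative sketch of the split-based parsing in get_llm_processed_prompts.
    # sample_reply is a made-up LLM completion, not real repo output.
    sample_reply = (
        "Positive Prompt: a serene mountain lake at sunrise, highly detailed\n"
        "Negative Prompt: blurry, low quality, watermark\n"
        'Recommended Model List: ["v1-5-pruned-emaonly.safetensors"]'
    )

    # Take the text between each pair of markers.
    positive = sample_reply.split('Positive Prompt: ')[1].split('Negative Prompt: ')[0].strip()
    negative = sample_reply.split('Negative Prompt: ')[1].split('Recommended Model List: ')[0].strip()
    # Strip the JSON-ish list syntax and split on commas.
    models = (sample_reply.split('Recommended Model List: ')[1]
              .strip().replace('[', '').replace(']', '').replace('"', '')
              .split(','))

    print(positive)  # a serene mountain lake at sunrise, highly detailed
    print(negative)  # blurry, low quality, watermark
    print(models)    # ['v1-5-pruned-emaonly.safetensors']

Each split(...)[1] raises IndexError when the model omits a marker, so callers may want to guard or retry on malformed replies.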