Skip to content

Commit

Permalink
improved warnings
Browse files Browse the repository at this point in the history
  • Loading branch information
Jingyi committed Nov 19, 2023
1 parent 0584159 commit 2622161
Showing 1 changed file with 11 additions and 5 deletions.
16 changes: 11 additions & 5 deletions source/panel/image_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
support_model_list = []
default_models = ["v1-5-pruned-emaonly.safetensors"]


# todo will update api
def deploy_sagemaker_endpoint(instance_type: str = "ml.g4dn.4xlarge", initial_instance_count: int = 1,
endpoint_name: str = "default-endpoint-for-llm-bot"):
Expand All @@ -55,12 +56,14 @@ def deploy_sagemaker_endpoint(instance_type: str = "ml.g4dn.4xlarge", initial_in
"endpoint_name": endpoint_name
}
# https://<Your API Gateway ID>.execute-api.<Your AWS Account Region>.amazonaws.com/{basePath}/inference/deploy-sagemaker-endpoint
res = requests.post(COMMAND_API_URL + 'inference/deploy-sagemaker-endpoint', headers = headers, json = inputBody)
res = requests.post(COMMAND_API_URL + 'inference/deploy-sagemaker-endpoint', headers=headers, json=inputBody)
logger.info("deploy_sagemaker_endpoint: {}".format(res.json()))


def upload_model():
    """Upload a model artifact to the deployment backend.

    Not yet implemented — currently a no-op placeholder that returns None.
    """
    return None


def get_bedrock_llm():
# specify the profile_name to call the bedrock api if needed
bedrock_client = boto3.client('bedrock-runtime', region_name=BEDROCK_REGION)
Expand All @@ -73,6 +76,7 @@ def get_bedrock_llm():
)
return cl_llm


# todo template use dynamic checkpoints
sd_prompt = PromptTemplate.from_template(
"""
Expand Down Expand Up @@ -107,6 +111,7 @@ def get_bedrock_llm():
Assistant:
""")


def get_llm_processed_prompts(initial_prompt):
cl_llm = get_bedrock_llm()
memory = ConversationBufferMemory()
Expand All @@ -122,13 +127,15 @@ def get_llm_processed_prompts(initial_prompt):

positive_prompt = response.split('Positive Prompt: ')[1].split('Negative Prompt: ')[0].strip()
negative_prompt = response.split('Negative Prompt: ')[1].split('Recommended Model List: ')[0].strip()
model_list = response.split('Recommended Model List: ')[1].strip().replace('[', '').replace(']', '').replace('"', '').split(',')
logger.info("positive_prompt: {}\n negative_prompt: {}\n model_list: {}".format(positive_prompt, negative_prompt, model_list))
model_list = response.split('Recommended Model List: ')[1].strip().replace('[', '').replace(']', '').replace('"',
'').split(
',')
logger.info("positive_prompt: {}\n negative_prompt: {}\n model_list: {}".format(positive_prompt, negative_prompt,
model_list))
return positive_prompt, negative_prompt, model_list


def generate_image(positive_prompts: str, negative_prompts: str, model: List[str], current_col, progress_bar):

job = create_inference_job(model)
st.session_state.progress += 5
progress_bar.progress(st.session_state.progress)
Expand Down Expand Up @@ -246,7 +253,6 @@ def get_checkpoints():


def create_inference_job(models: List[str]):

models = select_checkpoint(models)

headers = {
Expand Down

0 comments on commit 2622161

Please sign in to comment.