Skip to content

Commit

Permalink
fix: fix chat and optimize deploy
Browse files Browse the repository at this point in the history
  • Loading branch information
IcyKallen committed Jan 14, 2025
1 parent 812212d commit 0462682
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 29 deletions.
4 changes: 2 additions & 2 deletions source/infrastructure/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"clobber": "npx projen clobber",
"compile": "npx projen compile",
"default": "npx projen default",
"deploy": "npx cdk deploy",
"deploy": "npx cdk deploy --all",
"destroy": "npx projen destroy",
"diff": "npx projen diff",
"eject": "npx projen eject",
Expand Down Expand Up @@ -100,4 +100,4 @@
}
},
"//": "~~ Generated by projen. To modify, edit .projenrc.js and run \"npx projen\"."
}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""
chat models build in command pattern
"""
from common_logic.common_utils.constant import LLMModelType,ModelProvider

from common_logic.common_utils.constant import LLMModelType, ModelProvider

from ..model_config import MODEL_CONFIGS


Expand All @@ -10,7 +12,7 @@ class ModeMixins:
def convert_messages_role(messages: list[dict], role_map: dict):
"""
Args:
messages (list[dict]):
messages (list[dict]):
role_map (dict): {"current_role":"target_role"}
Returns:
Expand All @@ -20,18 +22,22 @@ def convert_messages_role(messages: list[dict], role_map: dict):
new_messages = []
for message in messages:
message = {**message}
role = message['role']
role = message["role"]
assert role in valid_roles, (role, valid_roles, messages)
message['role'] = role_map[role]
message["role"] = role_map[role]
new_messages.append(message)
return new_messages


class ModelMeta(type):
    """Metaclass that auto-registers concrete model classes in ``model_map``.

    Classes named ``Model``, classes whose name ends with ``BaseModel`` (any
    casing), and classes with ``model_id`` unset are treated as abstract and
    are NOT registered; every other class is stored in ``model_map`` under
    the key returned by ``get_model_id()``.
    """

    def __new__(cls, name, bases, attrs):
        new_cls = type.__new__(cls, name, bases, attrs)
        # Skip registration for abstract/base classes.
        if (
            name == "Model"
            or new_cls.model_id is None
            or name.endswith("BaseModel")
            or name.lower().endswith("basemodel")
        ):
            return new_cls
        # Concrete model: register under "<model_id>__<model_provider>".
        new_cls.model_map[new_cls.get_model_id()] = new_cls
        return new_cls
Expand All @@ -48,38 +54,45 @@ class Model(ModeMixins, metaclass=ModelMeta):
@classmethod
def create_model(cls, model_kwargs=None, **kwargs):
    """Instantiate this model; concrete subclasses must override.

    Args:
        model_kwargs: Optional keyword arguments forwarded to the
            underlying model constructor.
        **kwargs: Additional provider-specific options.

    Raises:
        NotImplementedError: Always, on this abstract base.
    """
    raise NotImplementedError

@classmethod
def get_model_id(cls, model_id=None, model_provider=None):
    """Build the registry key ``"<model_id>__<model_provider>"``.

    Args:
        model_id: Explicit model id; defaults to ``cls.model_id``.
        model_provider: Explicit provider; defaults to ``cls.model_provider``.

    Returns:
        str: The key used to index ``model_map``.
    """
    if model_id is None:
        model_id = cls.model_id
    if model_provider is None:
        model_provider = cls.model_provider
    return f"{model_id}__{model_provider}"

@classmethod
def get_model(cls, model_id, model_kwargs=None, **kwargs):
    """Resolve and instantiate the model registered for (model_id, provider).

    Args:
        model_id: Identifier of the model to load.
        model_kwargs: Optional kwargs forwarded to ``create_model``.
        **kwargs: Must include ``provider``; everything is forwarded on.

    Raises:
        KeyError: If ``provider`` is missing from ``kwargs`` or the
            (model_id, provider) pair is not in ``model_map``.
    """
    model_provider = kwargs["provider"]
    # Dynamically import the defining module so the model class registers
    # itself (via ModelMeta) before the map lookup.
    _load_module(model_id)
    key = cls.get_model_id(model_id=model_id, model_provider=model_provider)
    return cls.model_map[key].create_model(model_kwargs=model_kwargs, **kwargs)

@classmethod
def model_id_to_class_name(cls, model_id: str) -> str:
    """Convert a model ID to a valid Python class name.

    Examples:
        anthropic.claude-3-haiku-20240307-v1:0 -> Claude3Haiku20240307V1Model
    """
    # Strip the version suffix (after ":") and the vendor prefix (before ".").
    name = str(model_id).split(":")[0]
    name = name.split(".")[-1]
    parts = name.replace("_", "-").split("-")

    cleaned_parts = []
    for part in parts:
        if any(c.isdigit() for c in part):
            # Alphanumeric segment (e.g. "3", "20240307", "v1"): uppercase
            # only the first character, preserve the rest as-is.
            cleaned = "".join(
                c.upper() if i == 0 or part[i - 1] in "- " else c
                for i, c in enumerate(part)
            )
        else:
            cleaned = part.capitalize()
        cleaned_parts.append(cleaned)

    return "".join(cleaned_parts) + "Model"

@classmethod
def create_for_model(cls, model_id: str):
Expand All @@ -95,13 +108,14 @@ def create_for_model(cls, model_id: str):
"default_model_kwargs": config.default_model_kwargs,
"enable_any_tool_choice": config.enable_any_tool_choice,
"enable_prefill": config.enable_prefill,
}
},
)
return model_class


def _import_bedrock_models():
from .bedrock_models import model_classes

# from .bedrock_models import (
# Claude2,
# ClaudeInstance,
Expand All @@ -119,20 +133,15 @@ def _import_bedrock_models():


def _import_openai_models():
    """Lazily import OpenAI model classes.

    The import is for its side effect: presumably each class registers
    itself in ``Model.model_map`` via ``ModelMeta`` on import — the names
    themselves are intentionally unused.
    """
    from .openai_models import ChatGPT4o, ChatGPT4Turbo, ChatGPT35


def _import_dmaa_models():
    """Lazily import the DMAA models module for its import-time side effects.

    NOTE(review): presumably the module's classes register themselves in
    ``Model.model_map`` via ``ModelMeta`` on import — confirm in dmaa_models.
    """
    from . import dmaa_models



def _load_module(model_id):
    """Invoke the lazy-import function registered for ``model_id``.

    Args:
        model_id: Key into ``MODEL_MODULE_LOAD_FN_MAP``.

    Raises:
        AssertionError: If ``model_id`` has no registered loader.
            NOTE(review): assert is stripped under ``python -O``; consider
            raising ValueError if this is real input validation.
    """
    assert model_id in MODEL_MODULE_LOAD_FN_MAP, (model_id, MODEL_MODULE_LOAD_FN_MAP)
    MODEL_MODULE_LOAD_FN_MAP[model_id]()


Expand Down Expand Up @@ -167,5 +176,5 @@ def _load_module(model_id):
LLMModelType.CLAUDE_3_5_SONNET_APAC: _import_bedrock_models,
LLMModelType.CLAUDE_3_HAIKU_APAC: _import_bedrock_models,
LLMModelType.LLAMA3_1_70B_INSTRUCT_US: _import_bedrock_models,
LLMModelType.QWEN25_INSTRUCT_72B_AWQ: _import_dmaa_models
LLMModelType.QWEN25_INSTRUCT_72B_AWQ: _import_dmaa_models,
}

0 comments on commit 0462682

Please sign in to comment.