Skip to content

Commit

Permalink
More robust message parsing (fixes #81)
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexCheema committed Jul 25, 2024
1 parent 9420125 commit 03fe7a0
Showing 1 changed file with 15 additions and 5 deletions.
20 changes: 15 additions & 5 deletions exo/api/chatgpt_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
},
"llama-3.1-405b": {
"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-405B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=126),
"MLXDynamicShardInferenceEngine": Shard(model_id="/Users/alex/405b-instruct-4bit", start_layer=0, end_layer=0, n_layers=126),
},
"llama-3-70b": {
"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
Expand Down Expand Up @@ -124,6 +124,17 @@ def build_prompt(tokenizer, messages: List[Message]):
messages, tokenize=False, add_generation_prompt=True
)

def parse_message(data: dict):
    """Build a ``Message`` from one entry of a chat request's ``messages`` list.

    Args:
        data: Decoded JSON object expected to carry ``'role'`` and
            ``'content'`` keys.

    Returns:
        ``Message(data['role'], data['content'])``.

    Raises:
        ValueError: If ``data`` is not a dict or lacks either required key.
    """
    # Entries come from a client-supplied JSON body, so they may be strings,
    # lists, numbers or null. Without the isinstance guard the ``in`` checks
    # would do substring/element matching (or raise TypeError) instead of
    # reporting a malformed message.
    if not isinstance(data, dict) or 'role' not in data or 'content' not in data:
        raise ValueError(f"Invalid message: {data}. Must have 'role' and 'content'")
    return Message(data['role'], data['content'])

def parse_chat_request(data: dict):
    """Turn a decoded chat-completions JSON body into a ``ChatCompletionRequest``.

    ``model`` falls back to ``'llama-3.1-8b'`` and ``temperature`` to ``0.0``
    when absent. ``messages`` is required (a missing key raises ``KeyError``)
    and every entry is converted with ``parse_message``.
    """
    model = data.get('model', 'llama-3.1-8b')
    messages = list(map(parse_message, data['messages']))
    temperature = data.get('temperature', 0.0)
    return ChatCompletionRequest(model, messages, temperature)

class ChatGPTAPI:
def __init__(self, node: Node, inference_engine_classname: str, response_timeout_secs: int = 90):
Expand Down Expand Up @@ -156,15 +167,14 @@ async def handle_root(self, request):
async def handle_post_chat_token_encode(self, request):
    """Respond with the length of the prompt built from the posted messages.

    Reads the JSON request body, looks up the shard for the requested model
    (default ``'llama-3.1-8b'``) and this API's inference engine class, builds
    the prompt with that shard's tokenizer, and returns
    ``{'length': len(build_prompt(...))}`` as JSON.

    Returns a 400 JSON response when no shard is registered for the requested
    model / inference engine combination.
    """
    data = await request.json()
    model = data.get('model', 'llama-3.1-8b')
    shard = shard_mappings.get(model, {}).get(self.inference_engine_classname)
    if shard is None:
        # The chained .get() yields None for an unknown model or engine;
        # dereferencing shard.model_id below would raise AttributeError.
        return web.json_response({'detail': f"Unsupported model: {model}"}, status=400)
    # Parse once; the earlier raw `data.get('messages', [])` assignment was
    # immediately overwritten and has been dropped.
    messages = [parse_message(msg) for msg in data.get('messages', [])]
    tokenizer = await resolve_tokenizer(shard.model_id)
    return web.json_response({'length': len(build_prompt(tokenizer, messages))})

async def handle_post_chat_completions(self, request):
data = await request.json()
stream = data.get('stream', False)
messages = [Message(**msg) for msg in data['messages']]
chat_request = ChatCompletionRequest(data.get('model', 'llama-3.1-8b'), messages, data.get('temperature', 0.0))
chat_request = parse_chat_request(data)
if chat_request.model and chat_request.model.startswith("gpt-"): # to be compatible with ChatGPT tools, point all gpt- model requests to llama instead
chat_request.model = "llama-3.1-8b"
shard = shard_mappings.get(chat_request.model, {}).get(self.inference_engine_classname)
Expand All @@ -175,7 +185,7 @@ async def handle_post_chat_completions(self, request):
tokenizer = await resolve_tokenizer(shard.model_id)
if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}")

prompt = build_prompt(tokenizer, messages)
prompt = build_prompt(tokenizer, chat_request.messages)
callback_id = f"chatgpt-api-wait-response-{request_id}"
callback = self.node.on_token.register(callback_id)

Expand Down

0 comments on commit 03fe7a0

Please sign in to comment.