
Commit 4cb36a7

increase max line length to 200
AlexCheema committed Jul 28, 2024
1 parent d94e3f9 commit 4cb36a7
Showing 22 changed files with 100 additions and 328 deletions.
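The diff below is the mechanical result of reflowing the codebase to the wider line limit. The formatter itself is not part of this diff; purely as a hedged illustration (yapf and the style string below are assumptions, not taken from the repository), a 200-column reflow of one of the wrapped calls could be reproduced like this:

```python
# Hedged sketch: the repo's actual formatter is not shown in this commit; yapf is an assumption.
from yapf.yapflib.yapf_api import FormatCode

SRC = (
  'shard = Shard(\n'
  '    model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32\n'
  ')\n'
)

result = FormatCode(SRC, style_config="{based_on_style: pep8, column_limit: 200}")
formatted = result[0] if isinstance(result, tuple) else result  # yapf's return type has varied across versions
print(formatted)  # the call now fits on a single line under the 200-column limit
```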
86 changes: 20 additions & 66 deletions exo/api/chatgpt_api.py
@@ -15,44 +15,28 @@
shard_mappings = {
### llama
"llama-3.1-8b": {
"MLXDynamicShardInferenceEngine": Shard(
model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32
),
"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
},
"llama-3.1-70b": {
"MLXDynamicShardInferenceEngine": Shard(
model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80
),
"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
},
"llama-3.1-405b": {
"MLXDynamicShardInferenceEngine": Shard(
model_id="/Users/alex/405b-instruct-4bit", start_layer=0, end_layer=0, n_layers=126
),
"MLXDynamicShardInferenceEngine": Shard(model_id="/Users/alex/405b-instruct-4bit", start_layer=0, end_layer=0, n_layers=126),
},
"llama-3-8b": {
"MLXDynamicShardInferenceEngine": Shard(
model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32
),
"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
"TinygradDynamicShardInferenceEngine": Shard(model_id="llama3-8b-sfr", start_layer=0, end_layer=0, n_layers=32),
},
"llama-3-70b": {
"MLXDynamicShardInferenceEngine": Shard(
model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80
),
"TinygradDynamicShardInferenceEngine": Shard(
model_id="llama3-70b-sfr", start_layer=0, end_layer=0, n_layers=80
),
"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
"TinygradDynamicShardInferenceEngine": Shard(model_id="llama3-70b-sfr", start_layer=0, end_layer=0, n_layers=80),
},
### mistral
"mistral-nemo": {
"MLXDynamicShardInferenceEngine": Shard(
model_id="mlx-community/Mistral-Nemo-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=40
),
"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Nemo-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=40),
},
"mistral-large": {
"MLXDynamicShardInferenceEngine": Shard(
model_id="mlx-community/Mistral-Large-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=88
),
"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Large-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=88),
},
}
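For orientation, this mapping is keyed first by model name and then by inference-engine class name. A minimal sketch of the lookup pattern (the helper name `resolve_shard` is illustrative only; the real logic lives in `handle_post_chat_completions` further down):

```python
# Illustrative only: mirrors the lookup pattern used in handle_post_chat_completions below.
def resolve_shard(model: str, engine_classname: str):
  if not model or model not in shard_mappings:
    model = "llama-3.1-8b"  # the handler falls back to this default model
  shard = shard_mappings[model].get(engine_classname)
  if shard is None:
    supported = [m for m, engines in shard_mappings.items() if engine_classname in engines]
    raise ValueError(f"{model} is not supported by {engine_classname}; supported models: {supported}")
  return shard
```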

@@ -76,9 +60,7 @@ def resolve_tinygrad_tokenizer(model_id: str):
elif model_id == "llama3-70b-sfr":
return AutoTokenizer.from_pretrained("TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R")
else:
raise ValueError(
f"tinygrad doesnt currently support arbitrary model downloading. unsupported model: {model_id}"
)
raise ValueError(f"tinygrad doesnt currently support arbitrary model downloading. unsupported model: {model_id}")


async def resolve_tokenizer(model_id: str):
@@ -184,12 +166,8 @@ def __init__(self, node: Node, inference_engine_classname: str, response_timeout
allow_headers="*",
allow_methods="*",
)
cors.add(
self.app.router.add_post("/v1/chat/completions", self.handle_post_chat_completions), {"*": cors_options}
)
cors.add(
self.app.router.add_post("/v1/chat/token/encode", self.handle_post_chat_token_encode), {"*": cors_options}
)
cors.add(self.app.router.add_post("/v1/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
cors.add(self.app.router.add_post("/v1/chat/token/encode", self.handle_post_chat_token_encode), {"*": cors_options})
self.static_dir = Path(__file__).parent.parent.parent / "tinychat/examples/tinychat"
self.app.router.add_get("/", self.handle_root)
self.app.router.add_static("/", self.static_dir, name="static")
@@ -220,22 +198,16 @@ async def handle_post_chat_completions(self, request):
if DEBUG >= 2: print(f"Handling chat completions request from {request.remote}: {data}")
stream = data.get("stream", False)
chat_request = parse_chat_request(data)
if chat_request.model and chat_request.model.startswith(
"gpt-"
): # to be compatible with ChatGPT tools, point all gpt- model requests to llama instead
if chat_request.model and chat_request.model.startswith("gpt-"): # to be compatible with ChatGPT tools, point all gpt- model requests to llama instead
chat_request.model = "llama-3.1-8b"
if not chat_request.model or chat_request.model not in shard_mappings:
if DEBUG >= 1: print(f"Invalid model: {chat_request.model}. Supported: {list(shard_mappings.keys())}. Defaulting to llama-3.1-8b")
chat_request.model = "llama-3.1-8b"
shard = shard_mappings[chat_request.model].get(self.inference_engine_classname, None)
if not shard:
supported_models = [
model for model, engines in shard_mappings.items() if self.inference_engine_classname in engines
]
supported_models = [model for model, engines in shard_mappings.items() if self.inference_engine_classname in engines]
return web.json_response(
{
"detail": f"Unsupported model: {chat_request.model} with inference engine {self.inference_engine_classname}. Supported models for this engine: {supported_models}"
},
{"detail": f"Unsupported model: {chat_request.model} with inference engine {self.inference_engine_classname}. Supported models for this engine: {supported_models}"},
status=400,
)
request_id = str(uuid.uuid4())
@@ -255,9 +227,7 @@ async def handle_post_chat_completions(self, request):
import traceback

traceback.print_exc()
return web.json_response(
{"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500
)
return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)

try:
if DEBUG >= 2: print(f"Waiting for response to finish. timeout={self.response_timeout_secs}s")
@@ -278,11 +248,7 @@ async def stream_result(request_id: str, tokens: List[int], is_finished: bool):
self.prev_token_lens[request_id] = max(prev_last_tokens_len, len(tokens))
new_tokens = tokens[prev_last_tokens_len:]
finish_reason = None
eos_token_id = (
tokenizer.special_tokens_map.get("eos_token_id")
if isinstance(tokenizer._tokenizer, AutoTokenizer)
else tokenizer.eos_token_id
)
eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if isinstance(tokenizer._tokenizer, AutoTokenizer) else tokenizer.eos_token_id
if len(new_tokens) > 0 and new_tokens[-1] == eos_token_id:
new_tokens = new_tokens[:-1]
if is_finished:
@@ -309,9 +275,7 @@ def on_result(_request_id: str, tokens: List[int], is_finished: bool):
return _request_id == request_id and is_finished

_, tokens, _ = await callback.wait(on_result, timeout=self.response_timeout_secs)
if (
request_id in self.stream_tasks
): # in case there is still a stream task running, wait for it to complete
if request_id in self.stream_tasks: # in case there is still a stream task running, wait for it to complete
if DEBUG >= 2: print("Pending stream task. Waiting for stream task to complete.")
try:
await asyncio.wait_for(self.stream_tasks[request_id], timeout=30)
@@ -326,21 +290,13 @@ def on_result(_request_id: str, tokens: List[int], is_finished: bool):
)

finish_reason = "length"
eos_token_id = (
tokenizer.special_tokens_map.get("eos_token_id")
if isinstance(tokenizer._tokenizer, AutoTokenizer)
else tokenizer.eos_token_id
)
eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if isinstance(tokenizer._tokenizer, AutoTokenizer) else tokenizer.eos_token_id
if DEBUG >= 2: print(f"Checking if end of tokens result {tokens[-1]=} is {eos_token_id=}")
if tokens[-1] == eos_token_id:
tokens = tokens[:-1]
finish_reason = "stop"

return web.json_response(
generate_completion(
chat_request, tokenizer, prompt, request_id, tokens, stream, finish_reason, "chat.completion"
)
)
return web.json_response(generate_completion(chat_request, tokenizer, prompt, request_id, tokens, stream, finish_reason, "chat.completion"))
except asyncio.TimeoutError:
return web.json_response({"detail": "Response generation timed out"}, status=408)
finally:
@@ -353,7 +309,5 @@ async def run(self, host: str = "0.0.0.0", port: int = 8000):
site = web.TCPSite(runner, host, port)
await site.start()
if DEBUG >= 0:
print(
f"Chat interface started. Open this link in your browser: {terminal_link(f'http://localhost:{port}')}"
)
print(f"Chat interface started. Open this link in your browser: {terminal_link(f'http://localhost:{port}')}")
print(f"ChatGPT API endpoint served at {terminal_link(f'http://localhost:{port}/v1/chat/completions')}")
12 changes: 3 additions & 9 deletions exo/inference/debug_inference_engine.py
@@ -6,28 +6,22 @@


# An inference engine should work the same for any number of Shards, as long as the Shards are continuous.
async def test_inference_engine(
inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str
):
async def test_inference_engine(inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str):
from exo.inference.tinygrad.inference import Tokenizer
from pathlib import Path

_tokenizer = Tokenizer(str(Path(model_id) / "tokenizer.model"))

prompt = "In a single word only, what is the last name of the president of the United States? "
resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt(
"A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt
)
resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt)
next_resp_full, _next_inference_state_full, _ = await inference_engine_1.infer_tensor(
"A",
shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32),
input_data=resp_full,
inference_state=inference_state_full,
)

resp1, inference_state_1, _ = await inference_engine_1.infer_prompt(
"B", shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32), prompt=prompt
)
resp1, inference_state_1, _ = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32), prompt=prompt)
resp2, inference_state_2, _ = await inference_engine_2.infer_tensor(
"B",
shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32),
8 changes: 2 additions & 6 deletions exo/inference/inference_engine.py
@@ -7,13 +7,9 @@

class InferenceEngine(ABC):
@abstractmethod
async def infer_tensor(
self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None
) -> Tuple[np.ndarray, str, bool]:
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
pass

@abstractmethod
async def infer_prompt(
self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[str] = None
) -> Tuple[np.ndarray, str, bool]:
async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
pass
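A minimal sketch of a class conforming to the abstract signatures above; the `EchoEngine` name is hypothetical and the import paths are assumed to match the repository layout shown in this diff:

```python
# Hypothetical minimal implementation, only to illustrate the reflowed abstract signatures above.
from typing import Optional, Tuple

import numpy as np

from exo.inference.inference_engine import InferenceEngine  # paths assumed from the file names in this diff
from exo.inference.shard import Shard


class EchoEngine(InferenceEngine):
  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
    # Return (output tensor, serialized inference state, is_finished), mirroring the real engines.
    return np.array([len(prompt)]), inference_state or "", True

  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
    return input_data, inference_state or "", True
```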
10 changes: 2 additions & 8 deletions exo/inference/mlx/models/sharded_llama.py
@@ -108,19 +108,13 @@ def compute_llama3_base_freq(self):
return new_base_freqs.mean().item()

def extra_repr(self):
return (
f"{self.dims}, traditional={self.traditional}, "
f"max_position_embeddings={self.max_position_embeddings}, "
f"scaling_factor={self.scale}, rope_type={self.rope_type}"
)
return f"{self.dims}, traditional={self.traditional}, " f"max_position_embeddings={self.max_position_embeddings}, " f"scaling_factor={self.scale}, rope_type={self.rope_type}"

def __call__(self, x, offset: int = 0):
seq_len = x.shape[1] + offset
base = self.base
if self.max_position_embeddings and seq_len > self.max_position_embeddings:
base *= ((self.scale * seq_len / self.max_position_embeddings) - (self.scale - 1)) ** (
self.dims / (self.dims - 2)
)
base *= ((self.scale * seq_len / self.max_position_embeddings) - (self.scale - 1)) ** (self.dims / (self.dims - 2))

return mx.fast.rope(
x,
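The reflowed expression above scales the RoPE base only when the sequence grows past max_position_embeddings. A standalone restatement of that arithmetic, with illustrative values:

```python
# Standalone restatement of the base-scaling arithmetic above; the example values are illustrative.
def scaled_rope_base(base: float, scale: float, seq_len: int, max_position_embeddings: int, dims: int) -> float:
  if max_position_embeddings and seq_len > max_position_embeddings:
    base *= ((scale * seq_len / max_position_embeddings) - (scale - 1)) ** (dims / (dims - 2))
  return base

# A sequence twice the trained context raises the effective RoPE base:
print(scaled_rope_base(base=500000.0, scale=8.0, seq_len=16384, max_position_embeddings=8192, dims=128))
```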
12 changes: 3 additions & 9 deletions exo/inference/mlx/sharded_inference_engine.py
@@ -11,18 +11,12 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
def __init__(self):
self.shard = None

async def infer_prompt(
self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[str] = None
) -> (np.ndarray, str, bool):
async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
await self.ensure_shard(shard)
output_data: np.ndarray = np.array(
self.stateful_sharded_model.step(request_id, mx.array(self.tokenizer.encode(prompt)))
)
output_data: np.ndarray = np.array(self.stateful_sharded_model.step(request_id, mx.array(self.tokenizer.encode(prompt))))
return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id

async def infer_tensor(
self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None
) -> (np.ndarray, str, bool):
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
await self.ensure_shard(shard)
output_data: np.ndarray = np.array(self.stateful_sharded_model.step(request_id, mx.array(input_data)))
return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
6 changes: 1 addition & 5 deletions exo/inference/mlx/sharded_model.py
@@ -61,9 +61,5 @@ def __call__(
return self.step(x, temp, top_p, logit_bias)

def init_cache(self, request_id: str):
kv_heads = (
[self.model.n_kv_heads] * len(self.model.layers)
if isinstance(self.model.n_kv_heads, int)
else self.model.n_kv_heads
)
kv_heads = [self.model.n_kv_heads] * len(self.model.layers) if isinstance(self.model.n_kv_heads, int) else self.model.n_kv_heads
self.request_cache[request_id] = [KVCache(self.model.head_dim, n) for n in kv_heads]
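The reflowed conditional above broadcasts an integer n_kv_heads to one entry per layer so every layer gets its own KVCache. A tiny illustration with made-up values:

```python
# Illustration of the normalization above: an int n_kv_heads becomes a per-layer list.
# The values here are illustrative, not taken from any model config.
n_kv_heads = 8
num_layers = 32
kv_heads = [n_kv_heads] * num_layers if isinstance(n_kv_heads, int) else n_kv_heads
assert kv_heads == [8] * 32
```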
8 changes: 1 addition & 7 deletions exo/inference/mlx/sharded_utils.py
@@ -116,13 +116,7 @@ def load_model_shard(
for wf in weight_files:
weights_dict = mx.load(wf)
all_weights_keys.update(weights_dict.keys())
weights.update(
{
k: v
for k, v in weights_dict.items()
if not k.startswith("model.layers.") or shard.start_layer <= int(k.split(".")[2]) <= shard.end_layer
}
)
weights.update({k: v for k, v in weights_dict.items() if not k.startswith("model.layers.") or shard.start_layer <= int(k.split(".")[2]) <= shard.end_layer})

model_class, model_args_class = _get_classes(config=config)

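The reflowed dict comprehension above keeps all non-layer weights plus only the layers inside the shard's range. A small illustration with made-up weight keys:

```python
# Illustration of the shard-filtering predicate above, using hypothetical weight keys.
weights_dict = {
  "model.embed_tokens.weight": 0,   # kept: not a per-layer weight
  "model.layers.3.mlp.weight": 0,   # kept only if layer 3 falls inside [start_layer, end_layer]
  "model.layers.31.mlp.weight": 0,
}
start_layer, end_layer = 0, 15
kept = {k: v for k, v in weights_dict.items() if not k.startswith("model.layers.") or start_layer <= int(k.split(".")[2]) <= end_layer}
assert "model.layers.3.mlp.weight" in kept and "model.layers.31.mlp.weight" not in kept
```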
12 changes: 3 additions & 9 deletions exo/inference/test_inference_engine.py
@@ -6,23 +6,17 @@


# An inference engine should work the same for any number of Shards, as long as the Shards are continuous.
async def test_inference_engine(
inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str
):
async def test_inference_engine(inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str):
prompt = "In a single word only, what is the last name of the current president of the USA?"
resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt(
"A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt
)
resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt)
next_resp_full, _next_inference_state_full, _ = await inference_engine_1.infer_tensor(
"A",
shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32),
input_data=resp_full,
inference_state=inference_state_full,
)

resp1, inference_state_1, _ = await inference_engine_1.infer_prompt(
"B", shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32), prompt=prompt
)
resp1, inference_state_1, _ = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32), prompt=prompt)
resp2, inference_state_2, _ = await inference_engine_2.infer_tensor(
"B",
shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32),
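The test above relies on the comment's invariant that shards must be continuous. A small sketch of a valid two-shard split of a 32-layer model, mirroring the Shard arguments used in the test (the import path is assumed, and Shard is assumed to expose its constructor arguments as attributes):

```python
# Two shards that together cover all 32 layers with no gap, as the test above constructs them.
from exo.inference.shard import Shard  # import path assumed from the repository layout

first = Shard(model_id="llama3-8b-sfr", start_layer=0, end_layer=30, n_layers=32)    # layers 0..30
second = Shard(model_id="llama3-8b-sfr", start_layer=31, end_layer=31, n_layers=32)  # layer 31
assert first.end_layer + 1 == second.start_layer  # the "continuous" requirement from the comment
```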
(The remaining changed files are not shown here.)
