batched inference #214

Draft · wants to merge 5 commits into base: main
1 change: 0 additions & 1 deletion README.md
@@ -205,4 +205,3 @@ exo supports the following inference engines:
- ✅ [GRPC](exo/networking/grpc)
- 🚧 [Radio](TODO)
- 🚧 [Bluetooth](TODO)

56 changes: 48 additions & 8 deletions exo/api/chatgpt_api.py
@@ -153,7 +153,7 @@ def __init__(self, request_id: str, timestamp: int, prompt: str):


class ChatGPTAPI:
def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None):
def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, max_batch_size: int = 4, max_batch_wait: int = 10, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None):
self.node = node
self.inference_engine_classname = inference_engine_classname
self.response_timeout = response_timeout
@@ -183,6 +183,45 @@ def __init__(self, node: Node, inference_engine_classname: str, response_timeout
# Add middleware to log every request
self.app.middlewares.append(self.log_request)

self.queue = asyncio.Queue(maxsize=100)
self.app.on_startup.append(self.on_startup)
self.app.on_cleanup.append(self.on_cleanup)

self.max_batch_size = max_batch_size
self.max_batch_wait = max_batch_wait

async def on_startup(self, _):
self.processor_future = asyncio.ensure_future(self.batch_processor())

async def on_cleanup(self, _):
self.processor_future.cancel()
await self.processor_future

async def batch_processor(self):
while True:
shard = None
prompts, image_strs, request_ids, futures = [], [], [], []
last_infer = time.time()
while len(prompts) < self.max_batch_size and (time.time()-last_infer) < self.max_batch_wait:
try:
shard, prompt, image_str, request_id, future = self.queue.get_nowait() # TODO: batch according to shard
self.queue.task_done()
prompts.append(prompt)
image_strs.append(image_str)
request_ids.append(request_id)
futures.append(future)
except asyncio.QueueEmpty:
await asyncio.sleep(0.1)

if not prompts:
await asyncio.sleep(1)
continue

image_strs = image_strs if any(image_strs) else None
await self.node.process_prompt(shard, prompts, image_strs, request_ids=request_ids)
for future in futures:
future.set_result(None)

async def log_request(self, app, handler):
async def middleware(request):
if DEBUG >= 2: print(f"Received request: {request.method} {request.path}")
@@ -249,7 +288,9 @@ async def handle_post_chat_completions(self, request):

if DEBUG >= 2: print(f"Sending prompt from ChatGPT api {request_id=} {shard=} {prompt=} {image_str=}")
try:
await self.node.process_prompt(shard, prompt, image_str, request_id=request_id)
future = asyncio.Future()
await self.queue.put((shard, prompt, image_str, request_id, future))
await future
except Exception as e:
if DEBUG >= 2: traceback.print_exc()
return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)
@@ -273,10 +314,9 @@ async def stream_result(request_id: str, tokens: List[int], is_finished: bool):
self.prev_token_lens[request_id] = max(prev_last_tokens_len, len(tokens))
new_tokens = tokens[prev_last_tokens_len:]
finish_reason = None
eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if hasattr(tokenizer, "_tokenizer") and isinstance(tokenizer._tokenizer,
AutoTokenizer) else getattr(tokenizer, "eos_token_id", None)
if len(new_tokens) > 0 and new_tokens[-1] == eos_token_id:
new_tokens = new_tokens[:-1]
eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if hasattr(tokenizer, "_tokenizer") and isinstance(tokenizer._tokenizer, AutoTokenizer) else getattr(tokenizer, "eos_token_id", None)
if len(new_tokens) > 0 and eos_token_id in new_tokens:
new_tokens = new_tokens[:new_tokens.index(eos_token_id)]
if is_finished:
finish_reason = "stop"
if is_finished and not finish_reason:
@@ -322,8 +362,8 @@ def on_result(_request_id: str, tokens: List[int], is_finished: bool):
finish_reason = "length"
eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if isinstance(getattr(tokenizer, "_tokenizer", None), AutoTokenizer) else tokenizer.eos_token_id
if DEBUG >= 2: print(f"Checking if end of tokens result {tokens[-1]=} is {eos_token_id=}")
if tokens[-1] == eos_token_id:
tokens = tokens[:-1]
if eos_token_id in tokens:
tokens = tokens[:tokens.index(eos_token_id)]
finish_reason = "stop"

return web.json_response(generate_completion(chat_request, tokenizer, prompt, request_id, tokens, stream, finish_reason, "chat.completion"))
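
The new request flow in this file is a producer/consumer pattern: handle_post_chat_completions puts (shard, prompt, image_str, request_id, future) on an asyncio.Queue and awaits the future, while batch_processor drains the queue until either max_batch_size prompts are collected or max_batch_wait expires, then submits the whole batch through node.process_prompt and resolves the futures. The sketch below is a minimal standalone illustration of that pattern; fake_process_batch, handle_request, and the constants are hypothetical stand-ins rather than code from this PR, and the futures here carry results directly, whereas the PR resolves them with None and delivers tokens through the node's result callbacks.

```python
import asyncio
import time

MAX_BATCH_SIZE = 4
MAX_BATCH_WAIT = 1.0  # seconds to wait for a batch to fill


async def fake_process_batch(prompts):
    # Stand-in for node.process_prompt(shard, prompts, ...) in the PR.
    await asyncio.sleep(0.2)
    return [p.upper() for p in prompts]


async def batch_processor(queue: asyncio.Queue):
    while True:
        prompts, futures = [], []
        deadline = time.time() + MAX_BATCH_WAIT
        # Collect requests until the batch is full or the wait window closes.
        while len(prompts) < MAX_BATCH_SIZE and time.time() < deadline:
            try:
                prompt, future = queue.get_nowait()
                queue.task_done()
                prompts.append(prompt)
                futures.append(future)
            except asyncio.QueueEmpty:
                await asyncio.sleep(0.05)
        if not prompts:
            continue
        results = await fake_process_batch(prompts)
        for future, result in zip(futures, results):
            future.set_result(result)


async def handle_request(queue: asyncio.Queue, prompt: str) -> str:
    # Mirrors handle_post_chat_completions: enqueue the prompt and wait for the batch.
    future = asyncio.get_running_loop().create_future()
    await queue.put((prompt, future))
    return await future


async def main():
    queue = asyncio.Queue(maxsize=100)
    processor = asyncio.ensure_future(batch_processor(queue))
    answers = await asyncio.gather(*(handle_request(queue, f"prompt {i}") for i in range(6)))
    print(answers)  # six responses: the first four batched together, the last two in a second batch
    processor.cancel()


asyncio.run(main())
```
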
2 changes: 1 addition & 1 deletion exo/download/hf/hf_helpers.py
@@ -406,4 +406,4 @@ def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]:
else:
shard_specific_patterns = ["*.safetensors"]
if DEBUG >= 2: print(f"get_allow_patterns {weight_map=} {shard=} {shard_specific_patterns=}")
return list(default_patterns | shard_specific_patterns)
return list(default_patterns | shard_specific_patterns)
22 changes: 15 additions & 7 deletions exo/helpers.py
@@ -1,6 +1,6 @@
import os
import asyncio
from typing import Callable, TypeVar, Optional, Dict, Generic, Tuple, List
from typing import Callable, TypeVar, Optional, Dict, Generic, Tuple, List, Deque
import socket
import random
import platform
@@ -9,6 +9,7 @@
import netifaces
from pathlib import Path
import tempfile
from collections import deque

DEBUG = int(os.getenv("DEBUG", default="0"))
DEBUG_DISCOVERY = int(os.getenv("DEBUG_DISCOVERY", default="0"))
@@ -94,20 +95,27 @@ def terminal_link(uri, label=None):
class AsyncCallback(Generic[T]):
def __init__(self) -> None:
self.condition: asyncio.Condition = asyncio.Condition()
self.result: Optional[Tuple[T, ...]] = None
self.result: Deque[Tuple[T, ...]] = deque()
self.observers: list[Callable[..., None]] = []

async def wait(self, check_condition: Callable[..., bool], timeout: Optional[float] = None) -> Tuple[T, ...]:
async with self.condition:
await asyncio.wait_for(self.condition.wait_for(lambda: self.result is not None and check_condition(*self.result)), timeout)
assert self.result is not None # for type checking
return self.result
async def wait_for_valid_result():
while True:
while self.result:
if self.result[0] and check_condition(*self.result[0]):
return True
self.result.popleft()
await self.condition.wait()

await asyncio.wait_for(wait_for_valid_result(), timeout)
return self.result.popleft()

def on_next(self, callback: Callable[..., None]) -> None:
self.observers.append(callback)

def set(self, *args: T) -> None:
self.result = args
self.result.append(args)
for observer in self.observers:
observer(*args)
asyncio.create_task(self.notify())
@@ -233,4 +241,4 @@ def get_all_ip_addresses():
return list(set(ip_addresses))
except:
if DEBUG >= 1: print("Failed to get all IP addresses. Defaulting to localhost.")
return ["localhost"]
return ["localhost"]
6 changes: 3 additions & 3 deletions exo/inference/inference_engine.py
@@ -1,18 +1,18 @@
import numpy as np
import os

from typing import Tuple, Optional
from typing import Tuple, Optional, List
from abc import ABC, abstractmethod
from .shard import Shard


class InferenceEngine(ABC):
@abstractmethod
async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
async def infer_prompt(self, request_ids: List[str], shard: Shard, prompts: List[str], image_strs: Optional[List[str]] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
pass

@abstractmethod
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
async def infer_tensor(self, request_ids: List[str], shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
pass
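
With the interface change above, engines receive parallel lists of request ids and prompts instead of a single request. Below is a minimal sketch of a dummy engine that satisfies the new signatures; DummyBatchedEngine and its byte-level "tokenisation" are purely illustrative assumptions, not part of this PR or of exo's real engines.

```python
import numpy as np
from typing import List, Optional, Tuple

from exo.inference.inference_engine import InferenceEngine
from exo.inference.shard import Shard


class DummyBatchedEngine(InferenceEngine):
    """Toy engine: encodes each prompt to byte values and pads them into one batch."""

    async def infer_prompt(self, request_ids: List[str], shard: Shard, prompts: List[str],
                           image_strs: Optional[List[str]] = None,
                           inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
        encoded = [np.frombuffer(p.encode("utf-8"), dtype=np.uint8) for p in prompts]
        max_len = max(len(e) for e in encoded)
        batch = np.zeros((len(prompts), max_len), dtype=np.int64)
        for i, e in enumerate(encoded):
            batch[i, :len(e)] = e  # left-aligned, zero-padded batch
        return await self.infer_tensor(request_ids, shard, batch, inference_state)

    async def infer_tensor(self, request_ids: List[str], shard: Shard, input_data: np.ndarray,
                           inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
        # Echo back one "token" per request; a real engine would run the model here.
        output = input_data[:, -1:].astype(np.int64)
        return output, inference_state or "", True
```
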


211 changes: 116 additions & 95 deletions exo/inference/mlx/models/llava.py
@@ -488,98 +488,119 @@ def __call__(self, x: mx.array) -> mx.array:


class Model(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.config = config
self.model_type = config.model_type
if config.vision_config:
self.vision_tower = VisionModel(config.vision_config)
self.multi_modal_projector = LlavaMultiModalProjector(config)
self.vision_feature_layer = config.vision_feature_layer
self.vision_feature_select_strategy = config.vision_feature_select_strategy
self.language_model = LanguageModel(config.text_config, config.shard)

def get_input_embeddings(
self,
input_ids: Optional[mx.array] = None,
pixel_values: Optional[mx.array] = None,
):
if pixel_values is None:
return self.language_model(input_ids)

# Get the input embeddings from the language model
inputs_embeds = self.language_model.model.embed_tokens(input_ids)

    # Get the output hidden states from the vision model
*_, hidden_states = self.vision_tower(pixel_values.transpose(0, 2, 3, 1), output_hidden_states=True)

# Select the hidden states from the desired layer
selected_image_feature = hidden_states[self.vision_feature_layer]

if self.vision_feature_select_strategy == "default":
selected_image_feature = selected_image_feature[:, 1:]
elif self.vision_feature_select_strategy == "full":
selected_image_feature = selected_image_feature
else:
raise ValueError("Unexpected feature selection strategy: "
f"{self.vision_feature_select_strategy}")

# Pass image features through the multi-modal projector
image_features = self.multi_modal_projector(selected_image_feature)

# Insert special image tokens in the input_ids
final_inputs_embeds = self._merge_input_ids_with_image_features(image_features, inputs_embeds, input_ids)
return final_inputs_embeds

def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids):
image_token_index = self.config.image_token_index
num_images, num_image_patches, embed_dim = image_features.shape

# Positions of <image> tokens in input_ids, assuming batch size is 1
image_positions = np.where(input_ids[0] == image_token_index)[0].tolist()

if len(image_positions) != num_images:
raise ValueError(f"The number of image tokens ({len(image_positions)}) does not "
f" match the number of image inputs ({num_images}).")

text_segments = []
start_idx = 0

for position in image_positions:
text_segments.append(inputs_embeds[:, start_idx:position])
start_idx = position + 1

image_embeddings = mx.split(image_features, image_features.shape[0])
final_embeddings = [v for p in zip(text_segments, image_embeddings) for v in p]
final_embeddings += [inputs_embeds[:, start_idx:]]

# Create a final embedding of shape
# (1, num_image_patches*num_images + sequence_len, embed_dim)
return mx.concatenate(final_embeddings, axis=1)

def __call__(self, input_ids: mx.array, pixel_values: mx.array = None, cache=None):
input_embddings = None
if pixel_values is not None:
input_embddings = self.get_input_embeddings(input_ids, pixel_values)
logits = self.language_model(input_ids, cache=cache, inputs_embeds=input_embddings)
return logits

def sanitize(self, weights):
if self.config.vision_config:
weights = self.vision_tower.sanitize(weights)
else:
weights = {k: v for k, v in weights.items() if not k.startswith(('vision_tower', 'multi_modal_projector', 'vision_feature_layer', 'vision_feature_select_strategy'))}
weights = self.language_model.sanitize(weights)
return weights

@property
def layers(self):
return self.language_model.model.layers

@property
def head_dim(self):
return (self.language_model.model.head_dim or self.language_model.model.hidden_size // self.language_model.model.num_attention_heads)

@property
def n_kv_heads(self):
return self.language_model.model.num_key_value_heads
def __init__(self, config: ModelArgs):
super().__init__()
self.config = config
self.model_type = config.model_type
if config.vision_config:
self.vision_tower = VisionModel(config.vision_config)
self.multi_modal_projector = LlavaMultiModalProjector(config)
self.vision_feature_layer = config.vision_feature_layer
self.vision_feature_select_strategy = config.vision_feature_select_strategy
self.language_model = LanguageModel(config.text_config, config.shard)

def get_input_embeddings(
self,
input_ids: Optional[mx.array] = None,
pixel_values: Optional[mx.array] = None,
):
if pixel_values is None:
return self.language_model(input_ids)

# Get the input embeddings from the language model
inputs_embeds = self.language_model.model.embed_tokens(input_ids)

      # Get the output hidden states from the vision model
*_, hidden_states = self.vision_tower(
pixel_values.transpose(0, 2, 3, 1), output_hidden_states=True
)

# Select the hidden states from the desired layer
selected_image_feature = hidden_states[self.vision_feature_layer]

if self.vision_feature_select_strategy == "default":
selected_image_feature = selected_image_feature[:, 1:]
elif self.vision_feature_select_strategy == "full":
selected_image_feature = selected_image_feature
else:
raise ValueError(
"Unexpected feature selection strategy: "
f"{self.vision_feature_select_strategy}"
)

# Pass image features through the multi-modal projector
image_features = self.multi_modal_projector(selected_image_feature)

# Insert special image tokens in the input_ids
final_inputs_embeds = self._merge_input_ids_with_image_features(
image_features, inputs_embeds, input_ids
)
return final_inputs_embeds

def _merge_input_ids_with_image_features(
self, image_features, inputs_embeds, input_ids
):
image_features_list = mx.split(image_features, image_features.shape[0])
inputs_embeds_list = mx.split(inputs_embeds, inputs_embeds.shape[0])
input_ids_list = mx.split(input_ids, input_ids.shape[0])

stack = []
for input_ids, inputs_embeds, image_features in zip(input_ids_list, inputs_embeds_list, image_features_list):
image_token_index = self.config.image_token_index
num_images, num_image_patches, embed_dim = image_features.shape

# Positions of <image> tokens in input_ids, assuming batch size is 1
image_positions = np.where(input_ids[0] == image_token_index)[0].tolist()

if len(image_positions) != num_images:
raise ValueError(
f"The number of image tokens ({len(image_positions)}) does not "
f" match the number of image inputs ({num_images})."
)

text_segments = []
start_idx = 0

for position in image_positions:
text_segments.append(inputs_embeds[:, start_idx:position])
start_idx = position + 1

image_embeddings = mx.split(image_features, image_features.shape[0])
final_embeddings = [v for p in zip(text_segments, image_embeddings) for v in p]
final_embeddings += [inputs_embeds[:, start_idx:]]

# Create a final embedding of shape
# (1, num_image_patches*num_images + sequence_len, embed_dim)
stack.append(mx.concatenate(final_embeddings, axis=1))
return mx.concatenate(stack, axis=0)

def __call__(self, input_ids: mx.array, pixel_values: mx.array = None, cache=None):
input_embddings = None
if pixel_values is not None:
input_embddings = self.get_input_embeddings(input_ids, pixel_values)
logits = self.language_model(
input_ids, cache=cache, inputs_embeds=input_embddings
)
return logits

def sanitize(self, weights):
if self.config.vision_config:
weights = self.vision_tower.sanitize(weights)
else:
weights = {k: v for k, v in weights.items() if not k.startswith(('vision_tower', 'multi_modal_projector', 'vision_feature_layer', 'vision_feature_select_strategy'))}
weights = self.language_model.sanitize(weights)
return weights

@property
def layers(self):
return self.language_model.model.layers

@property
def head_dim(self):
return (
self.language_model.model.head_dim or self.language_model.model.hidden_size // self.language_model.model.num_attention_heads
)

@property
def n_kv_heads(self):
return self.language_model.model.num_key_value_heads
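
The rewritten _merge_input_ids_with_image_features splits the batch into per-example slices, interleaves each example's text embeddings with its projected image patches at the <image> token positions, and concatenates the per-example results back along the batch axis. The following small NumPy sketch illustrates the per-example interleaving; the image token id and shapes are made up for the example, and the real code operates on mx.array with per-batch mx.split.

```python
import numpy as np

IMAGE_TOKEN_INDEX = 32000  # illustrative <image> token id
embed_dim, num_patches = 8, 3

# One example: tokens "hello <image> world ." with a single image
input_ids = np.array([[1, 2, IMAGE_TOKEN_INDEX, 3]])
inputs_embeds = np.random.randn(1, input_ids.shape[1], embed_dim)
image_features = np.random.randn(1, num_patches, embed_dim)

# Positions of <image> tokens in this example
image_positions = np.where(input_ids[0] == IMAGE_TOKEN_INDEX)[0].tolist()

# Slice the text embeddings around each <image> token
text_segments, start_idx = [], 0
for position in image_positions:
    text_segments.append(inputs_embeds[:, start_idx:position])
    start_idx = position + 1

# Interleave text segments with the image patch embeddings, then append the tail
image_embeddings = np.split(image_features, image_features.shape[0])
final = [v for p in zip(text_segments, image_embeddings) for v in p]
final += [inputs_embeds[:, start_idx:]]

merged = np.concatenate(final, axis=1)
print(merged.shape)  # (1, 6, 8): 3 text tokens interleaved with 3 image patches
```
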