
Commit 2d6fba4

Dynamic shapes work too
Signed-off-by: Thomas Parnell <[email protected]>
tdoublep committed Jan 17, 2025
1 parent 8b2e26b commit 2d6fba4
Showing 2 changed files with 0 additions and 37 deletions.
8 changes: 0 additions & 8 deletions vllm/model_executor/model_loader/spyre.py
@@ -85,12 +85,10 @@ def forward(
         self.past_key_value_states = past_key_value_states
 
         # mark dynamic
-        '''
         if self.past_key_value_states is not None:
             for layer in self.past_key_value_states:
                 for tensor in layer:
                     torch._dynamo.mark_dynamic(tensor, 2)
-        '''
 
         # removing batch padding sequences to compute logits
         batch_size = input_ids.shape[0]
@@ -192,12 +190,6 @@ def load_weights(self, model_config: ModelConfig, max_prompt_length: int,
             f"accommodate prompt size of {max_prompt_length} and "
             f"decode tokens of {max_decode_length}")
 
-        if envs.VLLM_SPYRE_DYNAMO_BACKEND == "sendnn_decoder":
-            torch._dynamo.config.assume_static_by_default = True
-            torch._dynamo.config.dynamic_shapes = False
-            torch._dynamo.config.automatic_dynamic_shapes = False
-
-
         if envs.VLLM_SPYRE_DYNAMO_BACKEND in BACKEND_LIST:
             self.model = torch.compile(self.model,
                                        mode=compile_mode,
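
For context: `torch._dynamo.mark_dynamic(tensor, 2)` tells TorchDynamo to treat dimension 2 (the sequence length) of each KV-cache tensor as dynamic, so one compiled graph can be reused as the cache grows instead of recompiling for every new length. With dynamic shapes now working on the `sendnn_decoder` backend, the static-shape overrides (`assume_static_by_default = True`, `dynamic_shapes = False`, `automatic_dynamic_shapes = False`) are no longer needed. A minimal sketch of the pattern, using a toy function and hypothetical shapes rather than the actual Spyre model:

```python
# Minimal sketch (toy function and shapes, not the actual Spyre model):
# marking dim 2 dynamic lets one compiled graph serve a growing KV cache.
import torch

def toy_forward(x, cache):
    # cache: (batch, heads, seq_len, head_dim); seq_len grows each decode step
    return x + cache.sum(dim=2)

compiled = torch.compile(toy_forward)

x = torch.randn(1, 8, 64)
for seq_len in (16, 17, 18):
    cache = torch.randn(1, 8, seq_len, 64)
    # Mark the sequence dimension dynamic before the compiled call so that
    # a changing seq_len does not trigger a recompilation.
    torch._dynamo.mark_dynamic(cache, 2)
    out = compiled(x, cache)
```
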
29 changes: 0 additions & 29 deletions vllm/worker/spyre_worker.py
@@ -255,12 +255,10 @@ def _warmup_model_forward_pass(self, warmup_tokens_tensor,
         self.model_runner._update_mask()
         self.model_runner._update_position_ids()
 
-        '''
         if past_key_value_states is not None:
             for layer in past_key_value_states:
                 for tensor in layer:
                     torch._dynamo.mark_dynamic(tensor, 2)
-        '''
 
         logits, past_key_value_states = self.model_runner.\
             _raw_model_forward(
@@ -301,33 +299,6 @@ def initialize_cache(self, num_gpu_blocks: int,
         self.cache_config.num_gpu_blocks = num_gpu_blocks
         self.cache_config.num_cpu_blocks = num_cpu_blocks
 
-    '''
-    # TODO: why not inference mode?
-    #@torch.inference_mode()
-    def execute_model(
-        self,
-        execute_model_req: Optional[ExecuteModelRequest] = None
-    ) -> Optional[List[SamplerOutput]]:
-        torch.set_grad_enabled(False)
-        if execute_model_req is None:
-            return None
-        finished_requests_ids = execute_model_req.finished_requests_ids
-        seq_group_metadata_list = execute_model_req.seq_group_metadata_list
-        num_seq_groups = len(seq_group_metadata_list)
-        # If there is no input, we don't need to execute the model.
-        if num_seq_groups == 0:
-            return []
-        output = self.model_runner.execute_model(seq_group_metadata_list,
-                                                 finished_requests_ids)
-        # Spyre worker only supports single-step output. Wrap the output in a
-        # list to conform to interface.
-        return [output]
-    '''
-
     def get_cache_block_size_bytes(self) -> int:
         """Determine the size in bytes of a cache block.
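
The same marking loop is un-commented in the warmup path above. A hedged sketch of it as a standalone helper, assuming `past_key_value_states` is a per-layer list of (key, value) tensors shaped (batch, heads, seq_len, head_dim); the structure mirrors the diff, but the shapes are illustrative rather than taken from the actual Spyre model runner:

```python
# Hedged sketch of the un-commented warmup logic; tensor shapes are assumed.
import torch

def mark_kv_cache_dynamic(past_key_value_states):
    """Mark the sequence-length dim (2) of every cached tensor as dynamic."""
    if past_key_value_states is not None:
        for layer in past_key_value_states:
            for tensor in layer:
                torch._dynamo.mark_dynamic(tensor, 2)

# Example: a 2-layer cache with seq_len 32, marked before a compiled forward.
kv = [(torch.randn(1, 8, 32, 64), torch.randn(1, 8, 32, 64)) for _ in range(2)]
mark_kv_cache_dynamic(kv)
```
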
