diff --git a/vllm/worker/spyre_model_runner.py b/vllm/worker/spyre_model_runner.py
index 92bda786..24227165 100644
--- a/vllm/worker/spyre_model_runner.py
+++ b/vllm/worker/spyre_model_runner.py
@@ -52,102 +52,6 @@ def from_broadcasted_tensor_dict(
         tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
         return cls(**tensor_dict)
 
-'''
-class ModelInputForSpyreBuilder(ModelRunnerInputBuilderBase[ModelInputForSpyre]):
-
-    def __init__(self,
-                 runner: "SpyreModelRunner",
-                 finished_requests_ids: Optional[List[str]] = None) -> None:
-        super().__init__()
-        self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
-        self.runner = runner
-        self.model_input_cls = self.runner._model_input_cls
-
-    def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
-        self.seq_group_metadata_list.append(seq_group_metadata)
-
-    def build(self) -> ModelInputForSpyre:
-        is_prompt = seq_group_metadata_list[0].is_prompt
-        if is_prompt:
-            (input_tokens, input_positions, input_masks,
-             _) = self._prepare_prompt(self.seq_group_metadata_list)
-        else:
-            (input_tokens, input_positions,
-             input_masks) = self._prepare_decode(self.seq_group_metadata_list)
-
-        return ModelInputForSpyre(input_tokens=input_tokens,
-                                  input_positions=input_positions,
-                                  input_masks=input_masks,
-                                  is_prompt=is_prompt)
-
-    def _prepare_prompt(
-        self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int]]:
-        assert len(seq_group_metadata_list) > 0
-        input_token_list: List[torch.Tensor] = []
-
-        # find warmup shape to be used for padding and batching
-        applicable_spyre_warmup_shapes = [
-            shape for shape in self.scheduler_config.spyre_warmup_shapes
-            if len(seq_group_metadata_list) <= shape['batch_size']
-        ]
-        for seq_group_metadata in seq_group_metadata_list:
-            seq_data = seq_group_metadata.seq_data[list(
-                seq_group_metadata.seq_data.keys())[0]]
-            # retrieve initial (unpadded) tokens
-            prompt_tokens = seq_data.get_token_ids()
-            new_tokens = seq_group_metadata.sampling_params.max_tokens\
-                if seq_group_metadata.sampling_params is not None else 0
-
-            updated_spyre_warmup_shapes = [
-                shape for shape in applicable_spyre_warmup_shapes
-                if len(prompt_tokens) <= shape['prompt_length']
-                and new_tokens <= shape['new_tokens']
-            ]
-            applicable_spyre_warmup_shapes = updated_spyre_warmup_shapes
-
-        assert applicable_spyre_warmup_shapes
-
-        # If multiple warmup shapes apply, the first one is selected.
-        # For improving performance, the warmup shapes in scheduler_config
-        # are ordered by "processing speed".
-        min_pad_length_batch = applicable_spyre_warmup_shapes[0][
-            'prompt_length']
-        padded_batch_size = applicable_spyre_warmup_shapes[0]['batch_size']
-
-        for seq_group_metadata in seq_group_metadata_list:
-            assert seq_group_metadata.is_prompt
-            seq_ids = list(seq_group_metadata.seq_data.keys())
-            assert len(seq_ids) == 1
-            seq_id = seq_ids[0]
-
-            seq_data = seq_group_metadata.seq_data[seq_id]
-            # retrieve initial (unpadded) tokens
-            prompt_tokens = seq_data.get_token_ids()
-
-            input_token_list.append(
-                torch.tensor(prompt_tokens,
-                             dtype=torch.long,
-                             device=torch.device("cpu")))
-
-        # set number of added padding sequences used for computing logits
-        self.model.num_padded_sequences = padded_batch_size - len(
-            input_token_list)
-
-        # padding to compiled batch size
-        while len(input_token_list) < padded_batch_size:
-            input_token_list.append(
-                torch.zeros(min_pad_length_batch,
-                            dtype=torch.long,
-                            device=torch.device("cpu")))
-
-        # get position ids and attention mask
-        input_tokens, self._position_ids, self._mask = self.pad_input_ids(
-            input_token_list, min_pad_length=min_pad_length_batch)
-
-        return input_tokens, self._position_ids, self._mask
-'''
 
 
 class SpyreModelRunner(ModelRunnerBase):