Skip to content

Commit

Permalink
Merge pull request #23 from Snowflake-Labs/wy/swiftkv-refactor
Browse files Browse the repository at this point in the history
Wy/swiftkv refactor
  • Loading branch information
sfc-gh-yewang authored Dec 3, 2024
2 parents 90c8b9c + fb9cfdf commit 304cfb4
Show file tree
Hide file tree
Showing 8 changed files with 541 additions and 8 deletions.
7 changes: 7 additions & 0 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1370,6 +1370,13 @@ def maybe_create_spec_config(
)

draft_hf_config = draft_model_config.hf_config

if enable_chunked_prefill and \
not draft_hf_config.model_type in 'mlp_speculator':

raise ValueError(
"Speculative decoding and chunked prefill are currently "
f"mutually exclusive ({enable_chunked_prefill=}).")

if (num_speculative_tokens is not None
and hasattr(draft_hf_config, "num_lookahead_tokens")):
Expand Down
5 changes: 5 additions & 0 deletions vllm/model_executor/model_loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from vllm.model_executor.model_loader.tensorizer import (
TensorizerConfig, is_vllm_tensorized, load_with_tensorizer,
serialize_vllm_model, tensorizer_weights_iterator)
from vllm.model_executor.models.interfaces import supports_lora_exemption_for_speculator
from vllm.model_executor.model_loader.utils import (get_model_architecture,
set_default_torch_dtype)
from vllm.model_executor.model_loader.weight_utils import (
Expand Down Expand Up @@ -119,6 +120,10 @@ def _initialize_model(vllm_config: VllmConfig, prefix: str = "") -> nn.Module:
kwargs["quant_config"] = vllm_config.quant_config
if "lora_config" in all_params:
kwargs["lora_config"] = vllm_config.lora_config
if supports_lora_exemption_for_speculator(model_class):
logger.warning(f"Model {model_class} does not support LoRA and"
"speculator will be turned off dynamically if input request"
"requires LoRA. ")
if "scheduler_config" in all_params:
kwargs["scheduler_config"] = vllm_config.scheduler_config
return model_class(**kwargs)
Expand Down
4 changes: 3 additions & 1 deletion vllm/model_executor/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .interfaces import (HasInnerState, SupportsLoRA, SupportsMultiModal,
SupportsPP, has_inner_state, supports_lora,
supports_multimodal, supports_pp)
supports_multimodal, supports_pp,
supports_lora_exemption_for_speculator)
from .interfaces_base import (VllmModelForEmbedding,
VllmModelForTextGeneration, is_embedding_model,
is_text_generation_model)
Expand All @@ -18,6 +19,7 @@
"supports_lora",
"SupportsMultiModal",
"supports_multimodal",
"supports_lora_exemption_for_speculator",
"SupportsPP",
"supports_pp",
]
6 changes: 6 additions & 0 deletions vllm/model_executor/models/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,12 @@ def _supports_pp_inspect(model: Union[Type[object], object]) -> bool:

return supports_kw(model_forward, "intermediate_tensors")

@runtime_checkable
class LoRAExemptionForSpeculator(Protocol):
lora_exemption: ClassVar[Literal[True]] = True

def supports_lora_exemption_for_speculator(model: object) -> TypeIs[LoRAExemptionForSpeculator]:
return isinstance(model, LoRAExemptionForSpeculator)

@runtime_checkable
class HasInnerState(Protocol):
Expand Down
3 changes: 2 additions & 1 deletion vllm/model_executor/models/mlp_speculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from .interfaces import LoRAExemptionForSpeculator

SQRT2 = 2**0.5

Expand Down Expand Up @@ -54,7 +55,7 @@ def forward(self, x):
return x


class MLPSpeculator(nn.Module):
class MLPSpeculator(nn.Module, LoRAExemptionForSpeculator):
"""
An implementation of the speculative models introduced in
"Accelerating Production LLMs with Combined Token/Embedding
Expand Down
Loading

0 comments on commit 304cfb4

Please sign in to comment.