[V0][Fix] structured decoding compatibility with speculative decoding #13823
Conversation
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a limited subset of checks runs automatically. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can either: Add 🚀 ...
Thanks for the contribution! Can you please make sure that your commits are signed off (instructions here)? Also, some of the entrypoint tests are failing with an error that appears to be relevant.
Oh, this is interesting; I had not considered the need for rollback until this case. Thanks for your work. I think it is crucial to add a test that uses guided decoding and speculation together, since AFAIK we haven't exercised these together before.
I will look into this for V1. But thanks for making the PR.
Force-pushed from fb8d0f3 to 083ea16.
Ok, thank you for pointing out the issues with the tests. We tried using speculative decoding and guided decoding together and were surprised that it didn't work. Anyway, I'm happy if this code helps you. At least, it works well for our case.
@@ -12,21 +12,26 @@
from vllm.outputs import RequestOutput
from vllm.sampling_params import GuidedDecodingParams, SamplingParams

MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
Can you keep this model? Is there a specific reason for using a larger model for CI?
Sorry, my bad.
I thought the 7B model was used initially. Reverting to the 1.5B.
 @pytest.fixture(scope="module", params=["autoregressive", "speculative"])
 def llm(request):

     def get_llm_kwargs(mode: str):
-        if mode == "regular":
-            return {}
-        return {
-            # the model with fixed vocabulary size
-            "speculative_model": "tugstugi/Qwen2.5-Coder-0.5B-QwQ-draft",
-            "num_speculative_tokens": 3,
-        }
+        if mode == "autoregressive":
+            llm_kwargs = {}
+        elif mode == "speculative":
+            llm_kwargs = {
+                # the model with fixed vocabulary size
+                "speculative_model": "tugstugi/Qwen2.5-Coder-0.5B-QwQ-draft",
+                "num_speculative_tokens": 3,
+            }
+        else:
+            raise ValueError(f"Unsupported LLM mode: {mode}")
+
+        return llm_kwargs
Let's try to reduce the amount of change that is not relevant to the PR. I don't see the need for this diff here. Given that it is a pytest fixture, there is no need to raise an exception: if the test uses wrong params, it won't even run.
@@ -306,18 +318,28 @@ def __call__(self, input_ids: list[int],
        self._ensure_ctx()

        if len(self.matchers) == 0:
            max_rollback_tokens = (self.config.num_lookahead_slots +
Quick question: why do we need to add 1 to num_lookahead_slots here?
I double-checked this place; it's safe to remove this +1. Initially, it was a workaround for the bonus token.
Sounds good, thanks for the explanation.
    def get_llm_kwargs(mode: str):
        if mode == "autoregressive":
            return {}
        return {
            # the model with fixed vocabulary size
            "speculative_model": "Qwen/Qwen2.5-0.5B-Instruct",
            "num_speculative_tokens": 3,
        }
change this to be an elif mode == "speculative":
and raise an exception in the else case saying unsupported mode
The opposite suggestion was made in this comment: #13823 (comment) :)
-@pytest.fixture(scope="module")
-def llm():
+@pytest.fixture(scope="module", params=["autoregressive", "speculative"])
+def llm(request):
+
+    def get_llm_kwargs(mode: str):
+        if mode == "autoregressive":
+            return {}
+        return {
+            # the model with fixed vocabulary size
+            "speculative_model": "Qwen/Qwen2.5-0.5B-Instruct",
+            "num_speculative_tokens": 3,
+        }
+
+    test_llm_kwargs = get_llm_kwargs(request.param)
     # pytest caches the fixture so we use weakref.proxy to
     # enable garbage collection
-    llm = LLM(model=MODEL_NAME, max_model_len=1024)
+    llm = LLM(model=MODEL_NAME, max_model_len=1024, **test_llm_kwargs)
I'm not sure I understand how this works: will this fixture now run all of the tests in this file for each entry in params?
It first runs all tests with "autoregressive" and then loads the "speculative" model and runs all the tests in the file with it
Actually, for a pytest fixture this would run both cases, so there is no need for the exception. With

def test_use_llms(llm):
    ...

there would be two tests, test_use_llms[llm_autoregressive] and test_use_llms[llm_speculative]; a fuller sketch follows below. See https://docs.pytest.org/en/6.2.x/fixture.html#parametrizing-fixtures
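For reference, a minimal self-contained sketch of this parametrization pattern; the fixture body and the draft model name here are placeholders, not the PR's actual fixture:

import pytest

# Each entry in params yields one fixture instance; every test that
# requests `llm` is collected once per instance, so an unsupported value
# can never reach the fixture body.
@pytest.fixture(scope="module", params=["autoregressive", "speculative"])
def llm(request):
    mode = request.param
    # Placeholder kwargs standing in for the real engine construction.
    kwargs = {} if mode == "autoregressive" else {
        "speculative_model": "some-draft-model",  # placeholder name
        "num_speculative_tokens": 3,
    }
    return {"mode": mode, "kwargs": kwargs}


def test_uses_llm(llm):
    # Collected twice: test_uses_llm[autoregressive] and test_uses_llm[speculative].
    assert llm["mode"] in ("autoregressive", "speculative")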
Force-pushed from a299146 to 845a47f.
Hi @mgoin @aarnphm, do I need to do anything else on my end? Are you waiting for me to make any changes to the code?
Force-pushed from 25a877b to dd71c5f.
Hmm, it seems like the test failure is not related?
Can you rename the PR title accordingly? To ...
Seems like the test failures are not related, but the tests for this pass, so LGTM.
@benchislett you might be interested in this PR.
Signed-off-by: southfreebird <[email protected]>
Force-pushed from dd71c5f to 6cb4eb3.
I have concerns about this PR in its current state. It seems incompatible with a number of core features (ngram speculative decoding, MQA scoring/enforce_eager), and the correctness of the rollback logic seems fragile.
@@ -71,6 +71,12 @@ def _update_sampling_metadata(self, sampling_metadata, num_seqs,
        assert seq_group.prompt_logprob_indices == []  # No prompt
        assert seq_group.sample_indices == [i]  # Simple

        for seq_id in seq_group.seq_ids:
I assume the purpose of this code is to pythonize the token ids so that they are available to the guided decoding logits processor. This seems like an expensive serialization. Have you profiled any performance impact of this code?
As a safeguard, could this code be executed only when guided decoding is enabled for one or more requests in the batch? Could the copy even be limited to only apply to requests with guided decoding enabled?
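One possible shape of that safeguard is sketched below. This is a hypothetical illustration, not vLLM's actual metadata layout: apart from SamplingParams.guided_decoding and seq_group.seq_ids, the attribute names are assumptions.

# Hypothetical sketch of the suggested safeguard: only materialize token ids
# for sequence groups that actually use guided decoding.
def maybe_pythonize_token_ids(seq_groups):
    for seq_group in seq_groups:
        params = getattr(seq_group, "sampling_params", None)
        if params is None or params.guided_decoding is None:
            # Skip the expensive copy for requests without guided decoding.
            continue
        for seq_id in seq_group.seq_ids:
            seq_data = seq_group.seq_data[seq_id]  # assumed per-sequence data access
            # Materialize the token ids into a plain Python list so the
            # guided-decoding logits processor can read them.
            seq_data.output_token_ids = list(seq_data.output_token_ids)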
-@pytest.fixture(scope="module")
-def llm():
+@pytest.fixture(scope="module", params=["autoregressive", "speculative"])
+def llm(request):
I received a single test failure when running this test file on your branch. Below is the full log. It seems to indicate a guidance failure.
_________________________________________________________________________________________________________________________________________ test_json_with_any_whitespace_disabled[speculative] _________________________________________________________________________________________________________________________________________
llm = <weakproxy at 0x7f57833a1850 to LLM at 0x7f57919e88f0>
@pytest.mark.skip_global_cleanup
def test_json_with_any_whitespace_disabled(llm):
class ResponseSchema(BaseModel):
clarifying_question: str
cost_per_serving: str
calories: str
type_dish_ids: str
type_meal_ids: str
product_ids: list[str]
exclude_product_ids: list[str]
allergen_ids: list[str]
total_cooking_time: str
kitchen_ids: str
holiday_ids: str
# Note: Without this setting, the response is sometimes full of `\n`
# for some models. This option prevents that.
guided_decoding_backend = 'xgrammar:disable-any-whitespace'
schema = ResponseSchema.model_json_schema()
guided_params = GuidedDecodingParams(json=schema,
backend=\
guided_decoding_backend)
sampling_params = SamplingParams(max_tokens=2000,
frequency_penalty=0,
presence_penalty=-1.1,
repetition_penalty=1.3,
guided_decoding=guided_params)
prompt = ("<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You"
"are a helpful assistant.<|im_end|>\n<|im_start|>user\nI want a "
"quick launch fast with $10.<|im_end|>\n<|im_start|>assistant\n")
outputs = llm.generate(prompts=prompt,
sampling_params=sampling_params,
use_tqdm=True)
assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
generated_text = output.outputs[0].text
assert generated_text is not None
assert "\n" not in generated_text
# Parse to verify it is valid JSON
> parsed_json = json.loads(generated_text)
tests/entrypoints/llm/test_guided_generate.py:387:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
/usr/lib/python3.12/json/__init__.py:346: in loads
return _default_decoder.decode(s)
/usr/lib/python3.12/json/decoder.py:337: in decode
obj, end = self.raw_decode(s, idx=_w(s, 0).end())
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <json.decoder.JSONDecoder object at 0x7f5923552570>, s = '{"clarifying_question": "", "cost_per_serving": ">$5", "calories": "}} Can you provide more context or clarify your q...onTemArray<MultipartUploadRequest<TemData,Versions.MLVersionInfo,urlMap>", "kitchen_ids": "$nil", "holiday_ids": "{}"}', idx = 0
def raw_decode(self, s, idx=0):
"""Decode a JSON document from ``s`` (a ``str`` beginning with
a JSON document) and return a 2-tuple of the Python
representation and the index in ``s`` where the document ended.
This can be used to decode a JSON document from a string that may
have extraneous data at the end.
"""
try:
> obj, end = self.scan_once(s, idx)
E json.decoder.JSONDecodeError: Invalid control character at: line 1 column 580 (char 579)
/usr/lib/python3.12/json/decoder.py:353: JSONDecodeError
-------------------------------------------------------------------------------------------------------------------------------------------------------- Captured stderr call ---------------------------------------------------------------------------------------------------------------------------------------------------------
Processed prompts: 100%|██████████| 1/1 [00:04<00:00, 4.55s/it, est. speed input: 8.79 toks/s, output: 136.43 toks/s]
========================================================================================================================================================== warnings summary ===========================================================================================================================================================
tests/entrypoints/llm/test_guided_generate.py::test_guided_grammar[autoregressive-outlines]
tests/entrypoints/llm/test_guided_generate.py::test_guided_grammar[speculative-outlines]
/home/benchislett/Repos/vllm/.venv/lib/python3.12/site-packages/outlines/fsm/guide.py:110: UserWarning: Outlines' public *community-contributed* CFG structured generation is experimental. Please review https://dottxt-ai.github.io/outlines/latest/reference/generation/cfg#disclaimer
warnings.warn(
tests/entrypoints/llm/test_guided_generate.py::test_validation_against_both_guided_decoding_options[autoregressive]
tests/entrypoints/llm/test_guided_generate.py::test_validation_against_both_guided_decoding_options[speculative]
/home/benchislett/Repos/vllm/vllm/entrypoints/llm.py:462: DeprecationWarning: guided_options_request is deprecated, use SamplingParams.guided_decoding instead
self._validate_and_add_requests(
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
======================================================================================================================================================= short test summary info =======================================================================================================================================================
FAILED tests/entrypoints/llm/test_guided_generate.py::test_json_with_any_whitespace_disabled[speculative] - json.decoder.JSONDecodeError: Invalid control character at: line 1 column 580 (char 579)
        if self.num_processed_tokens[i] >= len(input_ids):
            diff = self.num_processed_tokens[i] - len(input_ids) + 1
            self.num_processed_tokens[i] -= diff
            self.matchers[i].rollback(diff)
Could you please explain your justification for the correctness of this code? Even after some studying I cannot convince myself that it will always work in the common case (batch expansion with a well-behaved guided draft model)
My understanding is that rolling back a number of states until the length of the matcher's accepted sequence matches the length of input_ids will ensure that the "sampled_token" is always being accepted at the right position in the state machine. However, it is not clear that this avoids cases where the wrong prefix sequence is taken. The correctness of this code seems to rely on a number of external factors such as the order of the sequences following batch expansion. For example, if the shortest sequence in the expanded batch is not processed first, then at least one of the draft tokens will never be rolled back from the matcher state and will corrupt the subsequent guidance if it is rejected.
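To make the rule under discussion concrete, here is a toy, self-contained simulation of the single-sequence arithmetic. ToyMatcher is a stand-in, not xgrammar, the numbers are illustrative, and it deliberately ignores the batch-expansion ordering issue raised above.

# Toy simulation of "roll back until the matcher length matches len(input_ids)".
class ToyMatcher:
    def __init__(self):
        self.accepted = []                 # tokens the grammar state has consumed

    def accept_token(self, token_id):
        self.accepted.append(token_id)

    def rollback(self, n):
        del self.accepted[len(self.accepted) - n:]

matcher = ToyMatcher()
for t in [1, 2, 3, 4]:                     # 3 verified tokens plus 1 draft token
    matcher.accept_token(t)

input_ids = [1, 2, 3]                      # the draft token 4 was rejected
if len(matcher.accepted) >= len(input_ids):
    diff = len(matcher.accepted) - len(input_ids) + 1
    matcher.rollback(diff)                 # removes tokens 4 and 3; token 3 is
                                           # presumably re-accepted by the caller
assert matcher.accepted == [1, 2]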
This code seems to be incompatible with the MQA scorer, which is used when enforce_eager is True. I'm not sure if a small tweak could make this compatible or if a larger rework is required.
To test this, I added enforce_eager=True to the test_guided_generate.py LLM fixture. Over 50% of the tests in that file became failures. (I believe that corresponds to nearly all of the speculative tests.) Below are the output logs from the failed tests:
FAILED tests/entrypoints/llm/test_guided_generate.py::test_guided_definition_json_completion[autoregressive-xgrammar] - json.decoder.JSONDecodeError: Expecting ',' delimiter: line 1497 column 1 (char 4786)
FAILED tests/entrypoints/llm/test_guided_generate.py::test_guided_json_object[autoregressive-xgrammar] - AssertionError: assert False
FAILED tests/entrypoints/llm/test_guided_generate.py::test_guided_regex[speculative-outlines] - AssertionError: assert None is not None
FAILED tests/entrypoints/llm/test_guided_generate.py::test_guided_regex[speculative-lm-format-enforcer] - AssertionError: assert None is not None
FAILED tests/entrypoints/llm/test_guided_generate.py::test_guided_regex[speculative-xgrammar] - AssertionError: assert None is not None
FAILED tests/entrypoints/llm/test_guided_generate.py::test_guided_json_completion[speculative-outlines] - json.decoder.JSONDecodeError: Expecting ',' delimiter: line 1 column 19 (char 18)
FAILED tests/entrypoints/llm/test_guided_generate.py::test_guided_json_completion[speculative-lm-format-enforcer] - json.decoder.JSONDecodeError: Unterminated string starting at: line 2 column 1 (char 3)
FAILED tests/entrypoints/llm/test_guided_generate.py::test_guided_json_completion[speculative-xgrammar] - json.decoder.JSONDecodeError: Expecting ',' delimiter: line 1 column 17 (char 16)
FAILED tests/entrypoints/llm/test_guided_generate.py::test_guided_complex_json_completion[speculative-outlines] - json.decoder.JSONDecodeError: Unterminated string starting at: line 1 column 2 (char 1)
FAILED tests/entrypoints/llm/test_guided_generate.py::test_guided_complex_json_completion[speculative-lm-format-enforcer] - json.decoder.JSONDecodeError: Unterminated string starting at: line 1 column 2 (char 1)
FAILED tests/entrypoints/llm/test_guided_generate.py::test_guided_complex_json_completion[speculative-xgrammar] - json.decoder.JSONDecodeError: Unterminated string starting at: line 1 column 2 (char 1)
FAILED tests/entrypoints/llm/test_guided_generate.py::test_guided_definition_json_completion[speculative-outlines] - json.decoder.JSONDecodeError: Unterminated string starting at: line 1 column 3 (char 2)
FAILED tests/entrypoints/llm/test_guided_generate.py::test_guided_definition_json_completion[speculative-lm-format-enforcer] - json.decoder.JSONDecodeError: Extra data: line 2 column 1 (char 4)
FAILED tests/entrypoints/llm/test_guided_generate.py::test_guided_definition_json_completion[speculative-xgrammar] - AssertionError
FAILED tests/entrypoints/llm/test_guided_generate.py::test_guided_enum_json_completion[speculative-outlines] - RuntimeError: [11:38:43] /workspace/cpp/grammar_matcher.cc:705: Check failed: (num_tokens <= static_cast<int>(token_length_history.size())) is false: Intended to rollback 1 tokens, but only the last 0 steps of history are saved
FAILED tests/entrypoints/llm/test_guided_generate.py::test_guided_enum_json_completion[speculative-lm-format-enforcer] - AssertionError
FAILED tests/entrypoints/llm/test_guided_generate.py::test_guided_enum_json_completion[speculative-xgrammar] - RuntimeError: [11:38:43] /workspace/cpp/grammar_matcher.cc:705: Check failed: (num_tokens <= static_cast<int>(token_length_history.size())) is false: Intended to rollback 1 tokens, but only the last 0 steps of history are saved
FAILED tests/entrypoints/llm/test_guided_generate.py::test_guided_choice_completion[speculative-outlines] - AssertionError
FAILED tests/entrypoints/llm/test_guided_generate.py::test_guided_choice_completion[speculative-lm-format-enforcer] - RuntimeError: [11:38:44] /workspace/cpp/grammar_matcher.cc:705: Check failed: (num_tokens <= static_cast<int>(token_length_history.size())) is false: Intended to rollback 1 tokens, but only the last 0 steps of history are saved
FAILED tests/entrypoints/llm/test_guided_generate.py::test_guided_choice_completion[speculative-xgrammar] - AssertionError
FAILED tests/entrypoints/llm/test_guided_generate.py::test_guided_grammar[speculative-outlines] - RuntimeError: [11:38:48] /workspace/cpp/grammar_matcher.cc:705: Check failed: (num_tokens <= static_cast<int>(token_length_history.size())) is false: Intended to rollback 1 tokens, but only the last 0 steps of history are saved
FAILED tests/entrypoints/llm/test_guided_generate.py::test_guided_grammar[speculative-lm-format-enforcer] - AssertionError
FAILED tests/entrypoints/llm/test_guided_generate.py::test_guided_grammar[speculative-xgrammar] - RuntimeError: [11:38:49] /workspace/cpp/grammar_matcher.cc:705: Check failed: (num_tokens <= static_cast<int>(token_length_history.size())) is false: Intended to rollback 1 tokens, but only the last 0 steps of history are saved
FAILED tests/entrypoints/llm/test_guided_generate.py::test_guided_options_request_deprecation_warning[speculative] - AssertionError
FAILED tests/entrypoints/llm/test_guided_generate.py::test_guided_json_object[speculative-outlines] - RuntimeError: [11:38:49] /workspace/cpp/grammar_matcher.cc:705: Check failed: (num_tokens <= static_cast<int>(token_length_history.size())) is false: Intended to rollback 1 tokens, but only the last 0 steps of history are saved
FAILED tests/entrypoints/llm/test_guided_generate.py::test_guided_json_object[speculative-lm-format-enforcer] - AssertionError
FAILED tests/entrypoints/llm/test_guided_generate.py::test_guided_json_object[speculative-xgrammar] - RuntimeError: [11:38:49] /workspace/cpp/grammar_matcher.cc:705: Check failed: (num_tokens <= static_cast<int>(token_length_history.size())) is false: Intended to rollback 1 tokens, but only the last 0 steps of history are saved
FAILED tests/entrypoints/llm/test_guided_generate.py::test_json_with_any_whitespace_disabled[speculative] - AssertionError
I believe this code is also incompatible with ngram speculative decoding, as the proposed tokens in that case do not always follow guidance and therefore need to be manually verified before being passed into accept_token. This will also need some additional special handling to be processed correctly. For additional details, see this comment on the guided decoding PR for V1.
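A rough sketch of the kind of per-token check that would be needed; the function and variable names are illustrative, and it assumes the matcher's accept_token returns False when a token is not allowed by the grammar.

# Illustrative sketch of verifying grammar-unaware draft tokens (e.g. from
# ngram proposals) before they are treated as accepted by the matcher.
def accept_verified_prefix(matcher, draft_token_ids):
    accepted = 0
    for token_id in draft_token_ids:
        # accept_token is assumed to return False when the grammar rejects
        # the token; in that case stop so the remaining tokens never touch
        # the matcher state.
        if not matcher.accept_token(token_id):
            break
        accepted += 1
    return accepted  # only this prefix may be counted as processed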
This PR was created by the Nebius team.
The main focus of this PR is to fix guided generation for speculative decoding. We found that when using the xGrammar backend with speculative decoding, vLLM crashes here. This PR addresses the issue by using a rollback mechanism in xGrammar.
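For readers unfamiliar with the mechanism, below is a minimal sketch of the accept/rollback cycle in xgrammar that the fix relies on. This is toy driver code, not the PR's implementation, and it assumes the matcher was created with a sufficient max_rollback_tokens budget.

import xgrammar as xgr

# Toy driver showing the rollback pattern: the matcher consumes draft tokens
# speculatively and is rewound past whatever the verifier rejects.
def apply_speculative_step(matcher: "xgr.GrammarMatcher",
                           draft_token_ids: list[int],
                           num_accepted: int) -> None:
    for token_id in draft_token_ids:
        matcher.accept_token(token_id)       # advance the grammar state
    num_rejected = len(draft_token_ids) - num_accepted
    if num_rejected > 0:
        matcher.rollback(num_rejected)       # rewind the rejected suffix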