[V0][Fix] structured decoding compatibility with speculative decoding #13823

Open · wants to merge 5 commits into base: main
Conversation

@southfreebird commented Feb 25, 2025

This PR was created by the Nebius team.

The main focus of this PR is to fix guided generation when it is used together with speculative decoding. We found that when using the xGrammar backend with speculative decoding, vLLM crashes here. This PR addresses the issue by using xGrammar's rollback mechanism.
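For context, here is a minimal, self-contained sketch of the rollback idea. This is not the PR's exact code; it assumes xgrammar's GrammarMatcher API (max_rollback_tokens, accept_token, rollback), and the variable names mirror the diff discussed below but are illustrative only.

import xgrammar as xgr

class SpecAwareGrammarState:
    """Keeps one grammar matcher in sync with speculative decoding (sketch)."""

    def __init__(self, compiled_grammar, num_lookahead_slots: int):
        # Allow rolling back up to one speculation window of tokens.
        self.matcher = xgr.GrammarMatcher(
            compiled_grammar, max_rollback_tokens=num_lookahead_slots)
        self.num_processed_tokens = 0

    def sync_and_accept(self, input_ids: list[int]) -> None:
        # If the matcher has accepted more tokens than actually survived
        # verification (some draft tokens were rejected), roll back the
        # surplus before accepting the newest token.
        if self.num_processed_tokens >= len(input_ids):
            diff = self.num_processed_tokens - len(input_ids) + 1
            self.matcher.rollback(diff)
            self.num_processed_tokens -= diff
        if input_ids:
            # Accept the newest verified token before masking the next step.
            assert self.matcher.accept_token(input_ids[-1])
            self.num_processed_tokens += 1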


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of tests to quickly catch errors. You can run other CI tests on top of it by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@hmellor (Member) commented Feb 25, 2025

Thanks for the contribution!

Please can you make sure that your commits are signed off (instructions here).

Also, some of the entrypoint tests are failing with:

[2025-02-25T12:06:05Z]     def _vocab_size(self) -> int:
[2025-02-25T12:06:05Z]         """Get the vocab size of the model and make sure it's consistent between
[2025-02-25T12:06:05Z]         draft and target workers.
[2025-02-25T12:06:05Z]         """
[2025-02-25T12:06:05Z]         vocab_sizes = [
[2025-02-25T12:06:05Z]             worker.vocab_size
[2025-02-25T12:06:05Z]             for worker in [self.proposer_worker, self.scorer_worker]
[2025-02-25T12:06:05Z]         ]
[2025-02-25T12:06:05Z] >       assert all(vocab_sizes[0] == vocab_size for vocab_size in vocab_sizes)
[2025-02-25T12:06:05Z] E       AssertionError

Which appears to be relevant.

@mgoin (Member) commented Feb 25, 2025

Oh, this is interesting; I did not consider the need for rollback until this case. Thanks for your work. I think it is crucial to add a test using guided decoding and speculation together, since AFAIK we haven't used these together before.

@mgoin (Member) commented Feb 25, 2025

@russellb @aarnphm

@aarnphm (Contributor) commented Feb 25, 2025

I will look into this today for V1. But thanks for making the PR.

@southfreebird force-pushed the feature/speculative-decoding-and-guided-output-fix branch 2 times, most recently from fb8d0f3 to 083ea16 on February 25, 2025 13:35
@southfreebird (Author)

Ok, thank you for pointing out the issues with the tests.

We tried using speculative decoding and guided decoding together and were surprised that it didn't work. Anyway, I'm happy if this code helps you. At least, it works well for our case.

@@ -12,21 +12,26 @@
from vllm.outputs import RequestOutput
from vllm.sampling_params import GuidedDecodingParams, SamplingParams

MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
Contributor

Can you keep this model? Is there a specific reason for using a larger model for CI?

Author

Sorry, my bad.
I thought the 7B model was used initially. Reverting to the 1.5B.

Comment on lines 19 to 34
@pytest.fixture(scope="module", params=["autoregressive", "speculative"])
def llm(request):

def get_llm_kwargs(mode: str):
if mode == "regular":
return {}
return {
# the model with fixed vocabulary size
"speculative_model": "tugstugi/Qwen2.5-Coder-0.5B-QwQ-draft",
"num_speculative_tokens": 3,
}
if mode == "autoregressive":
llm_kwargs = {}
elif mode == "speculative":
llm_kwargs = {
# the model with fixed vocabulary size
"speculative_model": "tugstugi/Qwen2.5-Coder-0.5B-QwQ-draft",
"num_speculative_tokens": 3,
}
else:
raise ValueError(f"Unsupported LLM mode: {mode}")

return llm_kwargs
Contributor

Let's try to reduce the amount of change that is not relevant to the PR.

I don't see the diff here. Given that it is a pytest fixture, there is no need to raise an exception here.

If the test uses the wrong params, it won't even run the test.

@@ -306,18 +318,28 @@ def __call__(self, input_ids: list[int],
self._ensure_ctx()

if len(self.matchers) == 0:
max_rollback_tokens = (self.config.num_lookahead_slots +
Contributor

qq: why do we need to add 1 to num_lookahead_slots here?

Author

I double-checked this place; it's safe to remove this +1.
Initially, it was a workaround for the bonus token.

Contributor

Sounds good, thanks for the explanation.

Comment on lines +22 to +30
def get_llm_kwargs(mode: str):
    if mode == "autoregressive":
        return {}
    return {
        # the model with fixed vocabulary size
        "speculative_model": "Qwen/Qwen2.5-0.5B-Instruct",
        "num_speculative_tokens": 3,
    }
Member

change this to be an elif mode == "speculative": and raise an exception in the else case saying unsupported mode

Author

That's the opposite of the suggestion in this comment: #13823 (comment)
:)

Comment on lines -19 to +35
@pytest.fixture(scope="module")
def llm():
@pytest.fixture(scope="module", params=["autoregressive", "speculative"])
def llm(request):

def get_llm_kwargs(mode: str):
if mode == "autoregressive":
return {}
return {
# the model with fixed vocabulary size
"speculative_model": "Qwen/Qwen2.5-0.5B-Instruct",
"num_speculative_tokens": 3,
}

test_llm_kwargs = get_llm_kwargs(request.param)
# pytest caches the fixture so we use weakref.proxy to
# enable garbage collection
llm = LLM(model=MODEL_NAME, max_model_len=1024)
llm = LLM(model=MODEL_NAME, max_model_len=1024, **test_llm_kwargs)
Member

I'm not sure I understand how this works - will this fixture now run all of the tests in this file for each entry in params?

Author

It first runs all tests with "autoregressive", then loads the "speculative" model and runs all the tests in the file with it.

Contributor

Actually, for a pytest fixture this would run both cases, so there's no need for the exception:

def test_use_llms(llm):
    ...

Then there would be two tests: test_use_llms[llm_autoregressive] and test_use_llms[llm_speculative].

https://docs.pytest.org/en/6.2.x/fixture.html#parametrizing-fixtures
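
For reference, a minimal standalone example of what this parametrization does (plain pytest, not vLLM's actual fixture):

import pytest

@pytest.fixture(scope="module", params=["autoregressive", "speculative"])
def llm_mode(request):
    # pytest builds this fixture once per param for the whole module,
    # so every test that uses it runs once per mode.
    return request.param

def test_use_llms(llm_mode):
    # Collected as test_use_llms[autoregressive] and test_use_llms[speculative];
    # an unknown mode can never reach the test, so no explicit error branch is needed.
    assert llm_mode in ("autoregressive", "speculative")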

@southfreebird force-pushed the feature/speculative-decoding-and-guided-output-fix branch from a299146 to 845a47f on February 26, 2025 21:38
@southfreebird (Author)

Hi @mgoin @aarnphm
I rebased the PR and resolved the conflicts.

Do I need to do anything else on my end? Are you waiting for me to make any changes to the code?
I'm not sure if I need to fix this: #13823 (comment)

@southfreebird force-pushed the feature/speculative-decoding-and-guided-output-fix branch from 25a877b to dd71c5f on February 26, 2025 23:56
@aarnphm (Contributor) commented Feb 27, 2025

Hmm, it seems like the test failure is not related?

@aarnphm (Contributor) commented Feb 27, 2025

Can you rename the PR title accordingly?

to [V0][Fix] structured decoding compatibility with speculative decoding

@aarnphm (Contributor) left a comment

Seems like the test failure is not related, but the tests for this pass, so LGTM.

@southfreebird changed the title from "Fix xgrammar decoding for speculative decoding" to "[V0][Fix] structured decoding compatibility with speculative decoding" on Feb 27, 2025
@aarnphm (Contributor) commented Feb 28, 2025

@benchislett you might be interested in this PR

Signed-off-by: southfreebird <[email protected]>
Signed-off-by: southfreebird <[email protected]>
Signed-off-by: southfreebird <[email protected]>
@southfreebird force-pushed the feature/speculative-decoding-and-guided-output-fix branch from dd71c5f to 6cb4eb3 on February 28, 2025 10:02
@benchislett (Contributor) left a comment

I have concerns about this PR in its current state. It seems incompatible with a number of core features (ngram speculative decoding, MQA scoring/enforce_eager), and the correctness of the rollback logic seems fragile.

@@ -71,6 +71,12 @@ def _update_sampling_metadata(self, sampling_metadata, num_seqs,
assert seq_group.prompt_logprob_indices == [] # No prompt
assert seq_group.sample_indices == [i] # Simple

for seq_id in seq_group.seq_ids:
Contributor

I assume the purpose of this code is to pythonize the token ids so that they are available to the guided decoding logits processor. This seems like an expensive serialization. Have you profiled any performance impact of this code?

As a safeguard, could this code be executed only when guided decoding is enabled for one or more requests in the batch? Could the copy even be limited to only apply to requests with guided decoding enabled?
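
A rough sketch of the safeguard being suggested (the helper and attribute names here are hypothetical, not vLLM's actual sampling-metadata structures):

def needs_guided_decoding(seq_group) -> bool:
    # Hypothetical check: does this sequence group use guided decoding?
    params = getattr(seq_group, "sampling_params", None)
    return params is not None and getattr(params, "guided_decoding", None) is not None

def maybe_copy_token_ids(seq_group, copy_token_ids) -> None:
    # Skip the expensive pythonization entirely unless a guided-decoding
    # logits processor will actually read the token ids.
    if not needs_guided_decoding(seq_group):
        return
    for seq_id in seq_group.seq_ids:
        copy_token_ids(seq_id)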

@pytest.fixture(scope="module")
def llm():
@pytest.fixture(scope="module", params=["autoregressive", "speculative"])
def llm(request):
Contributor

I received a single test failure when running this test file on your branch. Below is the full log. It seems to indicate a guidance failure.

_________________________________________________________________________________________________________________________________________ test_json_with_any_whitespace_disabled[speculative] _________________________________________________________________________________________________________________________________________

llm = <weakproxy at 0x7f57833a1850 to LLM at 0x7f57919e88f0>

    @pytest.mark.skip_global_cleanup
    def test_json_with_any_whitespace_disabled(llm):
    
        class ResponseSchema(BaseModel):
            clarifying_question: str
            cost_per_serving: str
            calories: str
            type_dish_ids: str
            type_meal_ids: str
            product_ids: list[str]
            exclude_product_ids: list[str]
            allergen_ids: list[str]
            total_cooking_time: str
            kitchen_ids: str
            holiday_ids: str
    
        # Note: Without this setting, the response is sometimes full of `\n`
        # for some models. This option prevents that.
        guided_decoding_backend = 'xgrammar:disable-any-whitespace'
    
        schema = ResponseSchema.model_json_schema()
        guided_params = GuidedDecodingParams(json=schema,
                                             backend=\
                                               guided_decoding_backend)
        sampling_params = SamplingParams(max_tokens=2000,
                                         frequency_penalty=0,
                                         presence_penalty=-1.1,
                                         repetition_penalty=1.3,
                                         guided_decoding=guided_params)
    
        prompt = ("<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You"
                  "are a helpful assistant.<|im_end|>\n<|im_start|>user\nI want a "
                  "quick launch fast with $10.<|im_end|>\n<|im_start|>assistant\n")
        outputs = llm.generate(prompts=prompt,
                               sampling_params=sampling_params,
                               use_tqdm=True)
    
        assert outputs is not None
    
        for output in outputs:
            assert output is not None
            assert isinstance(output, RequestOutput)
    
            generated_text = output.outputs[0].text
            assert generated_text is not None
            assert "\n" not in generated_text
    
            # Parse to verify it is valid JSON
>           parsed_json = json.loads(generated_text)

tests/entrypoints/llm/test_guided_generate.py:387: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
/usr/lib/python3.12/json/__init__.py:346: in loads
    return _default_decoder.decode(s)
/usr/lib/python3.12/json/decoder.py:337: in decode
    obj, end = self.raw_decode(s, idx=_w(s, 0).end())
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = <json.decoder.JSONDecoder object at 0x7f5923552570>, s = '{"clarifying_question": "", "cost_per_serving": ">$5", "calories": "}} Can you provide more context or clarify your q...onTemArray<MultipartUploadRequest<TemData,Versions.MLVersionInfo,urlMap>", "kitchen_ids": "$nil", "holiday_ids": "{}"}', idx = 0

    def raw_decode(self, s, idx=0):
        """Decode a JSON document from ``s`` (a ``str`` beginning with
        a JSON document) and return a 2-tuple of the Python
        representation and the index in ``s`` where the document ended.
    
        This can be used to decode a JSON document from a string that may
        have extraneous data at the end.
    
        """
        try:
>           obj, end = self.scan_once(s, idx)
E           json.decoder.JSONDecodeError: Invalid control character at: line 1 column 580 (char 579)

/usr/lib/python3.12/json/decoder.py:353: JSONDecodeError
-------------------------------------------------------------------------------------------------------------------------------------------------------- Captured stderr call ---------------------------------------------------------------------------------------------------------------------------------------------------------
Processed prompts: 100%|██████████| 1/1 [00:04<00:00,  4.55s/it, est. speed input: 8.79 toks/s, output: 136.43 toks/s]
========================================================================================================================================================== warnings summary ===========================================================================================================================================================
tests/entrypoints/llm/test_guided_generate.py::test_guided_grammar[autoregressive-outlines]
tests/entrypoints/llm/test_guided_generate.py::test_guided_grammar[speculative-outlines]
  /home/benchislett/Repos/vllm/.venv/lib/python3.12/site-packages/outlines/fsm/guide.py:110: UserWarning: Outlines' public *community-contributed* CFG structured generation is experimental. Please review https://dottxt-ai.github.io/outlines/latest/reference/generation/cfg#disclaimer
    warnings.warn(

tests/entrypoints/llm/test_guided_generate.py::test_validation_against_both_guided_decoding_options[autoregressive]
tests/entrypoints/llm/test_guided_generate.py::test_validation_against_both_guided_decoding_options[speculative]
  /home/benchislett/Repos/vllm/vllm/entrypoints/llm.py:462: DeprecationWarning: guided_options_request is deprecated, use SamplingParams.guided_decoding instead
    self._validate_and_add_requests(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
======================================================================================================================================================= short test summary info =======================================================================================================================================================
FAILED tests/entrypoints/llm/test_guided_generate.py::test_json_with_any_whitespace_disabled[speculative] - json.decoder.JSONDecodeError: Invalid control character at: line 1 column 580 (char 579)

if self.num_processed_tokens[i] >= len(input_ids):
    diff = self.num_processed_tokens[i] - len(input_ids) + 1
    self.num_processed_tokens[i] -= diff
    self.matchers[i].rollback(diff)
Contributor

Could you please explain your justification for the correctness of this code? Even after some studying I cannot convince myself that it will always work in the common case (batch expansion with a well-behaved guided draft model)

My understanding is that rolling back a number of states until the length of the matcher's accepted sequence matches the length of input_ids will ensure that the "sampled_token" is always being accepted at the right position in the state machine. However, it is not clear that this avoids cases where the wrong prefix sequence is taken. The correctness of this code seems to rely on a number of external factors such as the order of the sequences following batch expansion. For example, if the shortest sequence in the expanded batch is not processed first, then at least one of the draft tokens will never be rolled-back from the matcher state and will corrupt the subsequent guidance if it is rejected.

Contributor

This code seems to be incompatible with the MQA scorer, which is used when enforce-eager is True. I'm not sure if a small tweak could make this compatible or if a larger rework is required.

To test this, I added enforce_eager=True to the test_guided_generate.py LLM fixture. Over 50% of the tests in that file became failures (I believe that corresponds to nearly all of the speculative tests). Below are the output logs from the failed tests:

FAILED tests/entrypoints/llm/test_guided_generate.py::test_guided_definition_json_completion[autoregressive-xgrammar] - json.decoder.JSONDecodeError: Expecting ',' delimiter: line 1497 column 1 (char 4786)
FAILED tests/entrypoints/llm/test_guided_generate.py::test_guided_json_object[autoregressive-xgrammar] - AssertionError: assert False
FAILED tests/entrypoints/llm/test_guided_generate.py::test_guided_regex[speculative-outlines] - AssertionError: assert None is not None
FAILED tests/entrypoints/llm/test_guided_generate.py::test_guided_regex[speculative-lm-format-enforcer] - AssertionError: assert None is not None
FAILED tests/entrypoints/llm/test_guided_generate.py::test_guided_regex[speculative-xgrammar] - AssertionError: assert None is not None
FAILED tests/entrypoints/llm/test_guided_generate.py::test_guided_json_completion[speculative-outlines] - json.decoder.JSONDecodeError: Expecting ',' delimiter: line 1 column 19 (char 18)
FAILED tests/entrypoints/llm/test_guided_generate.py::test_guided_json_completion[speculative-lm-format-enforcer] - json.decoder.JSONDecodeError: Unterminated string starting at: line 2 column 1 (char 3)
FAILED tests/entrypoints/llm/test_guided_generate.py::test_guided_json_completion[speculative-xgrammar] - json.decoder.JSONDecodeError: Expecting ',' delimiter: line 1 column 17 (char 16)
FAILED tests/entrypoints/llm/test_guided_generate.py::test_guided_complex_json_completion[speculative-outlines] - json.decoder.JSONDecodeError: Unterminated string starting at: line 1 column 2 (char 1)
FAILED tests/entrypoints/llm/test_guided_generate.py::test_guided_complex_json_completion[speculative-lm-format-enforcer] - json.decoder.JSONDecodeError: Unterminated string starting at: line 1 column 2 (char 1)
FAILED tests/entrypoints/llm/test_guided_generate.py::test_guided_complex_json_completion[speculative-xgrammar] - json.decoder.JSONDecodeError: Unterminated string starting at: line 1 column 2 (char 1)
FAILED tests/entrypoints/llm/test_guided_generate.py::test_guided_definition_json_completion[speculative-outlines] - json.decoder.JSONDecodeError: Unterminated string starting at: line 1 column 3 (char 2)
FAILED tests/entrypoints/llm/test_guided_generate.py::test_guided_definition_json_completion[speculative-lm-format-enforcer] - json.decoder.JSONDecodeError: Extra data: line 2 column 1 (char 4)
FAILED tests/entrypoints/llm/test_guided_generate.py::test_guided_definition_json_completion[speculative-xgrammar] - AssertionError
FAILED tests/entrypoints/llm/test_guided_generate.py::test_guided_enum_json_completion[speculative-outlines] - RuntimeError: [11:38:43] /workspace/cpp/grammar_matcher.cc:705: Check failed: (num_tokens <= static_cast<int>(token_length_history.size())) is false: Intended to rollback 1 tokens, but only the last 0 steps of history are saved
FAILED tests/entrypoints/llm/test_guided_generate.py::test_guided_enum_json_completion[speculative-lm-format-enforcer] - AssertionError
FAILED tests/entrypoints/llm/test_guided_generate.py::test_guided_enum_json_completion[speculative-xgrammar] - RuntimeError: [11:38:43] /workspace/cpp/grammar_matcher.cc:705: Check failed: (num_tokens <= static_cast<int>(token_length_history.size())) is false: Intended to rollback 1 tokens, but only the last 0 steps of history are saved
FAILED tests/entrypoints/llm/test_guided_generate.py::test_guided_choice_completion[speculative-outlines] - AssertionError
FAILED tests/entrypoints/llm/test_guided_generate.py::test_guided_choice_completion[speculative-lm-format-enforcer] - RuntimeError: [11:38:44] /workspace/cpp/grammar_matcher.cc:705: Check failed: (num_tokens <= static_cast<int>(token_length_history.size())) is false: Intended to rollback 1 tokens, but only the last 0 steps of history are saved
FAILED tests/entrypoints/llm/test_guided_generate.py::test_guided_choice_completion[speculative-xgrammar] - AssertionError
FAILED tests/entrypoints/llm/test_guided_generate.py::test_guided_grammar[speculative-outlines] - RuntimeError: [11:38:48] /workspace/cpp/grammar_matcher.cc:705: Check failed: (num_tokens <= static_cast<int>(token_length_history.size())) is false: Intended to rollback 1 tokens, but only the last 0 steps of history are saved
FAILED tests/entrypoints/llm/test_guided_generate.py::test_guided_grammar[speculative-lm-format-enforcer] - AssertionError
FAILED tests/entrypoints/llm/test_guided_generate.py::test_guided_grammar[speculative-xgrammar] - RuntimeError: [11:38:49] /workspace/cpp/grammar_matcher.cc:705: Check failed: (num_tokens <= static_cast<int>(token_length_history.size())) is false: Intended to rollback 1 tokens, but only the last 0 steps of history are saved
FAILED tests/entrypoints/llm/test_guided_generate.py::test_guided_options_request_deprecation_warning[speculative] - AssertionError
FAILED tests/entrypoints/llm/test_guided_generate.py::test_guided_json_object[speculative-outlines] - RuntimeError: [11:38:49] /workspace/cpp/grammar_matcher.cc:705: Check failed: (num_tokens <= static_cast<int>(token_length_history.size())) is false: Intended to rollback 1 tokens, but only the last 0 steps of history are saved
FAILED tests/entrypoints/llm/test_guided_generate.py::test_guided_json_object[speculative-lm-format-enforcer] - AssertionError
FAILED tests/entrypoints/llm/test_guided_generate.py::test_guided_json_object[speculative-xgrammar] - RuntimeError: [11:38:49] /workspace/cpp/grammar_matcher.cc:705: Check failed: (num_tokens <= static_cast<int>(token_length_history.size())) is false: Intended to rollback 1 tokens, but only the last 0 steps of history are saved
FAILED tests/entrypoints/llm/test_guided_generate.py::test_json_with_any_whitespace_disabled[speculative] - AssertionError

Contributor

I believe this code is also incompatible with ngram speculative decoding, as the proposed tokens in that case do not always follow guidance and therefore need to be manually verified before being passed into accept_token. This will also need some additional special handling to be processed correctly. For additional details, see this comment on the guided decoding PR for V1.
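
A hedged sketch of the kind of special handling that implies (it assumes accept_token returns a bool indicating whether the token satisfies the grammar; the function name is illustrative):

def accept_draft_tokens(matcher, draft_tokens: list[int]) -> int:
    """Feed ngram-proposed tokens to the matcher, stopping at the first
    grammar violation. Returns how many tokens were accepted, so the caller
    knows how many can later be rolled back if the scorer rejects them."""
    accepted = 0
    for tok in draft_tokens:
        if not matcher.accept_token(tok):
            break
        accepted += 1
    return accepted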
