[Sampler] Fix stop strings offset for speculative decoding (#1719)
Fixes the behavior broken in #1676.
With speculative decoding, the stop string must be matched against the whole substring generated in the current step (step_substring):

![image](https://github.com/user-attachments/assets/99fb9e64-37c9-4704-a90e-82e8a74baaaa)
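
To make the offset change concrete, here is a minimal standalone sketch (toy values and simplified names, not the repository's code) of the window arithmetic: when a step appends several draft tokens, the decode window must back up past all of them before widening by the maximum stop-string length, otherwise a stop string that completes at an early draft token is never decoded.

#include <cstddef>
#include <iostream>
#include <vector>

int main() {
    std::vector<int> generated_tokens(10, 0); // 10 token ids generated so far
    size_t max_stop_string_tokens = 2;        // plays the role of stop_strings.first
    size_t draft_generated_tokens = 3;        // draft tokens appended in this step

    // Pre-fix behavior: decode only the last max_stop_string_tokens tokens.
    size_t old_offset = generated_tokens.size() - max_stop_string_tokens; // 8 -> window [8, 10)

    // Post-fix behavior: back up past the draft tokens first, then widen by
    // the stop-string window, so a match ending at any draft token is visible.
    size_t new_offset = generated_tokens.size() - draft_generated_tokens; // 7
    if (new_offset < max_stop_string_tokens) {
        return 0; // too few committed tokens to hold a stop string; no match possible
    }
    new_offset -= max_stop_string_tokens;                                 // 5 -> window [5, 10)

    std::cout << "old window: [" << old_offset << ", " << generated_tokens.size() << ")\n"
              << "new window: [" << new_offset << ", " << generated_tokens.size() << ")\n";
    return 0;
}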
iefode authored Feb 12, 2025
1 parent 5e21a4a commit c5e484c
Showing 2 changed files with 26 additions and 15 deletions.
13 changes: 10 additions & 3 deletions src/cpp/src/sampler.cpp
@@ -107,10 +107,16 @@ struct MatchStopStringResult {
 MatchStopStringResult match_stop_string(Tokenizer& tokenizer,
                                         const TokenIds& generated_tokens,
                                         const std::pair<size_t, std::set<std::string>>& stop_strings,
-                                        bool is_include_to_output) {
+                                        bool is_include_to_output,
+                                        size_t draft_generated_tokens = 0) {
     MatchStopStringResult result;
     if (generated_tokens.size() >= stop_strings.first) {
-        size_t offset = generated_tokens.size() - stop_strings.first;
+        // draft_generated_tokens is to handle case with >= 1 generated tokens per step
+        size_t offset = generated_tokens.size() - draft_generated_tokens;
+        if (offset < stop_strings.first) {
+            return result;
+        }
+        offset -= stop_strings.first;
         TokenIds buffer(generated_tokens.begin() + offset, generated_tokens.end());
         std::string decoded_buffer = tokenizer.decode(buffer);
         for (const auto& stop_string : stop_strings.second) {
@@ -567,7 +573,8 @@ std::vector<int64_t> Sampler::_try_finish_generation(SequenceGroup::Ptr & sequence_group) {

         if (!sampling_params.stop_strings.empty()) {
             auto& stop_strings = m_stop_strings.at(sequence_group->get_request_id());
-            auto match_result = match_stop_string(m_tokenizer, running_sequence->get_generated_ids(), stop_strings, sampling_params.include_stop_str_in_output);
+            auto match_result = match_stop_string(m_tokenizer, running_sequence->get_generated_ids(), stop_strings,
+                                                  sampling_params.include_stop_str_in_output, sequence_group->get_num_tokens_to_validate());
             if (match_result.is_matched) {
                 running_sequence->remove_last_tokens(match_result.to_remove);
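
A decode-level illustration of the same point, again as a hedged sketch with toy strings standing in for the Tokenizer: the stop string "END" completes at the first of three draft tokens, so the old two-token window never sees it, while the widened window does.

#include <iostream>
#include <string>
#include <vector>

int main() {
    // Pretend each element is the decoded text of one generated token.
    std::vector<std::string> pieces = {"Hello", " world", " EN", "D", " extra", " junk"};
    size_t draft_generated_tokens = 3; // "D", " extra", " junk" were drafted this step
    size_t stop_window_tokens = 2;     // "END" spans at most two tokens here

    auto decode_from = [&](size_t offset) {
        std::string text;
        for (size_t i = offset; i < pieces.size(); ++i) text += pieces[i];
        return text;
    };

    std::string old_buffer = decode_from(pieces.size() - stop_window_tokens);                          // " extra junk"
    std::string new_buffer = decode_from(pieces.size() - draft_generated_tokens - stop_window_tokens); // " world END extra junk"

    std::cout << "old buffer matches: " << (old_buffer.find("END") != std::string::npos) << '\n'; // 0
    std::cout << "new buffer matches: " << (new_buffer.find("END") != std::string::npos) << '\n'; // 1
    return 0;
}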
28 changes: 16 additions & 12 deletions tests/python_tests/test_continuous_batching.py
@@ -376,30 +376,32 @@ def test_pipelines_generate_with_streaming(tmp_path, pipeline_type):
     model_id : str = "facebook/opt-125m"
     opt_model, hf_tokenizer = get_hugging_face_models(model_id)

-    models_path : Path = tmp_path / "t_streaming" / model_id
+    models_path : Path = tmp_path / model_id
     convert_models(opt_model, hf_tokenizer, models_path)

     generation_config = GenerationConfig()
-    pipe, input, gen_config = get_data_by_pipeline_type(models_path, pipeline_type, generation_config)
+    pipe, input, generation_config = get_data_by_pipeline_type(models_path, pipeline_type, generation_config)

+    it_cnt = 0
     def py_streamer(py_str: str):
+        nonlocal it_cnt
+        it_cnt += 1
         return False

-    try:
-        _ = pipe.generate(input, generation_config=generation_config, streamer=py_streamer)
-    except Exception:
-        assert True
+    _ = pipe.generate(input, generation_config=generation_config, streamer=py_streamer)

     del pipe
     rmtree(models_path)

+    assert it_cnt > 0
+
 @pytest.mark.parametrize("pipeline_type", ["continuous_batching", "speculative_decoding", "prompt_lookup_decoding", "llm_pipeline"])
 @pytest.mark.precommit
 def test_pipelines_generate_with_streaming_empty_output(tmp_path, pipeline_type):
     model_id : str = "facebook/opt-125m"
     opt_model, hf_tokenizer = get_hugging_face_models(model_id)

-    models_path : Path = tmp_path / "t_streaming" / model_id
+    models_path : Path = tmp_path / model_id
     convert_models(opt_model, hf_tokenizer, models_path)

     generation_config = GenerationConfig()
@@ -408,13 +410,15 @@ def test_pipelines_generate_with_streaming_empty_output(tmp_path, pipeline_type):

     pipe, input, generation_config = get_data_by_pipeline_type(models_path, pipeline_type, generation_config)

+    it_cnt = 0
     def py_streamer(py_str: str):
-        raise Exception("Streamer was called")
+        nonlocal it_cnt
+        it_cnt += 1
+        return False

-    try:
-        _ = pipe.generate(input, generation_config=generation_config, streamer=py_streamer)
-    except Exception:
-        assert False
+    _ = pipe.generate(input, generation_config=generation_config, streamer=py_streamer)

     del pipe
     rmtree(models_path)

+    assert it_cnt == 0
