fix beam search in seq2seq (#1111)
* fix beam search in seq2seq

* add tests
eaidova authored Jan 15, 2025
1 parent 248aabd commit a11c6c8
Showing 2 changed files with 35 additions and 6 deletions.
7 changes: 3 additions & 4 deletions optimum/intel/openvino/modeling_seq2seq.py
@@ -422,7 +422,7 @@ def get_encoder(self):
         return self.encoder
 
     def _reorder_cache(self, past, beam_idx) -> Tuple[Tuple[torch.FloatTensor]]:
-        self.decoder._reorder_cache(past, beam_idx)
+        return self.decoder._reorder_cache(past, beam_idx)
 
     def reshape(self, batch_size: int, sequence_length: int):
         """
@@ -627,6 +627,7 @@ def forward(
         if self.stateful and past_key_values is None:
             self.request.reset_state()
             self._past_length = 0
+            self.next_beam_idx = np.arange(input_ids.shape[0], dtype=int)
 
         if past_key_values is not None and not self.stateful:
             # Flatten the past_key_values
@@ -661,7 +662,6 @@
             inputs["beam_idx"] = (
                 self.next_beam_idx if self.next_beam_idx is not None else np.arange(batch_size, dtype=np.int32)
             )
-
         # Run inference
         self.request.start_async(inputs, share_inputs=True)
         self.request.wait()
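
For stateful OpenVINO models the KV cache lives inside the InferRequest, so beam reordering is done by the runtime via a `beam_idx` model input rather than by permuting Python-side tensors. The hunk above resets that mapping to the identity whenever a new sequence starts; otherwise the permutation chosen during the previous `generate()` call would be applied to the first step of the next one. A toy sketch of the bookkeeping (names and shapes illustrative, not the library API):

```python
import numpy as np

class BeamIdxTracker:
    """Toy model of the next_beam_idx bookkeeping in a stateful decoder."""

    def __init__(self):
        self.next_beam_idx = None

    def start_new_sequence(self, batch_size):
        # without this reset, a stale permutation from the previous request
        # leaks into the first step of the next generation
        self.next_beam_idx = np.arange(batch_size, dtype=int)

    def step_inputs(self, batch_size):
        idx = self.next_beam_idx if self.next_beam_idx is not None else np.arange(batch_size, dtype=np.int32)
        return {"beam_idx": idx}

tracker = BeamIdxTracker()
tracker.start_new_sequence(batch_size=2)
print(tracker.step_inputs(2))  # {'beam_idx': array([0, 1])}
```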
@@ -1016,7 +1016,6 @@ class _OVModelForWhisper(OVModelForSpeechSeq2Seq, WhisperForConditionalGeneration):
     auto_model_class = WhisperForConditionalGeneration
 
     # force the use of the WhisperForConditionalGeneration generate and prepare_inputs_for_generation methods
-    prepare_inputs_for_generation = WhisperForConditionalGeneration.prepare_inputs_for_generation
     generate = WhisperForConditionalGeneration.generate
 
     @classmethod
@classmethod
@@ -1083,7 +1082,7 @@ def prepare_inputs_for_generation(

         past_length = 0
         if past_key_values is not None:
-            self.decoder._get_past_length(past_key_values)
+            past_length = self.decoder._get_past_length(past_key_values)
 
         # Some generation methods already pass only the last input ID
         if decoder_input_ids.shape[1] > past_length:
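
This hunk fixes the same class of bug in a different spot: `_get_past_length` was called but its result discarded, so `past_length` stayed 0 and `decoder_input_ids` was never trimmed to just the tokens the decoder has not yet processed. A sketch of the trimming pattern the assignment restores (tensor values illustrative):

```python
import torch

decoder_input_ids = torch.tensor([[5, 6, 7, 8]])
past_length = 3  # before the fix this stayed 0 because the result was discarded
if decoder_input_ids.shape[1] > past_length:
    decoder_input_ids = decoder_input_ids[:, past_length:]
print(decoder_input_ids)  # tensor([[8]]) - only the newest token is re-decoded
```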
34 changes: 32 additions & 2 deletions tests/openvino/test_modeling.py
@@ -1658,6 +1658,21 @@ def test_compare_to_transformers(self, model_arch):
             transformers_outputs = transformers_model(**tokens, **decoder_inputs)
         # Compare tensor outputs
         self.assertTrue(torch.allclose(ov_outputs.logits, transformers_outputs.logits, atol=1e-4))
+        gen_config = GenerationConfig(
+            max_new_tokens=10,
+            min_new_tokens=10,
+            num_beams=2,
+            do_sample=False,
+            eos_token_id=None,
+        )
+
+        set_seed(SEED)
+        generated_tokens = transformers_model.generate(**tokens, generation_config=gen_config)
+        set_seed(SEED)
+        ov_generated_tokens = ov_model.generate(**tokens, generation_config=gen_config)
+
+        self.assertTrue(torch.equal(generated_tokens, ov_generated_tokens))
+
         del transformers_model
         del ov_model
 
@@ -2355,12 +2370,12 @@ def test_compare_to_transformers(self, model_arch):

         processor = get_preprocessor(model_id)
         data = self._generate_random_audio_data()
-        features = processor.feature_extractor(data, return_tensors="pt")
+        pt_features = processor.feature_extractor(data, return_tensors="pt")
         decoder_start_token_id = transformers_model.config.decoder_start_token_id
         decoder_inputs = {"decoder_input_ids": torch.ones((1, 1), dtype=torch.long) * decoder_start_token_id}
 
         with torch.no_grad():
-            transformers_outputs = transformers_model(**features, **decoder_inputs)
+            transformers_outputs = transformers_model(**pt_features, **decoder_inputs)
 
         for input_type in ["pt", "np"]:
             features = processor.feature_extractor(data, return_tensors=input_type)
@@ -2373,6 +2388,21 @@ def test_compare_to_transformers(self, model_arch):
             # Compare tensor outputs
             self.assertTrue(torch.allclose(torch.Tensor(ov_outputs.logits), transformers_outputs.logits, atol=1e-3))
 
+        gen_config = GenerationConfig(
+            max_new_tokens=10,
+            min_new_tokens=10,
+            num_beams=2,
+            do_sample=False,
+            eos_token_id=None,
+        )
+
+        set_seed(SEED)
+        generated_tokens = transformers_model.generate(**pt_features, generation_config=gen_config)
+        set_seed(SEED)
+        ov_generated_tokens = ov_model.generate(**pt_features, generation_config=gen_config)
+
+        self.assertTrue(torch.equal(generated_tokens, ov_generated_tokens))
+
         del transformers_model
         del ov_model
         gc.collect()
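
Both new tests pin beam search (`num_beams=2`) to deterministic settings: `do_sample=False`, fixed seeds, and `eos_token_id=None` with matching `min_new_tokens`/`max_new_tokens` so both runs emit exactly 10 tokens, then require token-for-token equality between the PyTorch and OpenVINO models. A sketch of how the fixed path looks from user code, assuming a hypothetical tiny test checkpoint:

```python
from transformers import AutoTokenizer
from optimum.intel import OVModelForSeq2SeqLM

model_id = "hf-internal-testing/tiny-random-t5"  # hypothetical tiny test model
tokenizer = AutoTokenizer.from_pretrained(model_id)
# export=True converts the PyTorch checkpoint to OpenVINO IR on the fly
model = OVModelForSeq2SeqLM.from_pretrained(model_id, export=True)

inputs = tokenizer("translate English to German: hello", return_tensors="pt")
# beam search now returns the same tokens as the PyTorch reference model
out = model.generate(**inputs, num_beams=2, do_sample=False, max_new_tokens=10)
print(tokenizer.batch_decode(out, skip_special_tokens=True))
```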
