Skip to content

Commit

Permalink
fix position_id init for qwen2
Browse files Browse the repository at this point in the history
Signed-off-by: jiqing-feng <[email protected]>
  • Loading branch information
jiqing-feng committed Jan 15, 2025
1 parent 00e6bf3 commit acfd0ce
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion optimum/exporters/ipex/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,7 +644,11 @@ def _qwen2_model_forward(
)

if position_ids is None:
position_ids = cache_position.unsqueeze(0)
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
)
position_ids = position_ids.unsqueeze(0).repeat_interleave(input_ids.shape[0], 0)

causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
Expand Down

0 comments on commit acfd0ce

Please sign in to comment.