Skip to content

Commit

Permalink
fix(generation): remove base transformers warpers
Browse files Browse the repository at this point in the history
The logits warpers used for sampling are now included in the list of
processors obtained when calling GenerationMixin._get_logits_processor.
We need to remove them explicitly because we use instead a fused logits
warper (which is 10x faster).
  • Loading branch information
dacorvo committed Jan 6, 2025
1 parent 88c55b3 commit d4084a3
Showing 1 changed file with 12 additions and 0 deletions.
12 changes: 12 additions & 0 deletions optimum/neuron/generation/token_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
GenerationMixin,
LogitsProcessorList,
StoppingCriteriaList,
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
)
from transformers.generation.utils import GenerationMode

Expand Down Expand Up @@ -154,6 +157,15 @@ def create(

logits_warper = None
if generation_mode == GenerationMode.SAMPLE:
# Remove transformers TopK, TopP and Temperature processors
logits_processor = LogitsProcessorList(
[
p
for p in logits_processor
if not isinstance(p, (TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper))
]
)
# We use a fused logits warper instead
logits_warper = FusedLogitsWarper.from_config(generation_config)

return cls(
Expand Down

0 comments on commit d4084a3

Please sign in to comment.