Skip to content

Commit

Permalink
Optimize infinity generation
Browse files Browse the repository at this point in the history
  • Loading branch information
leng-yue committed Dec 22, 2023
1 parent 708ddfe commit e7e7a8c
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions tools/llama/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def decode_one_token(

# Disable <s> and </s> tokens for codebooks
if model.config.num_codebooks != 0:
- logits.codebook_logits[:, :, :, :2] = -float("Inf")
+ logits.codebook_logits[:, :, :, :1] = -float("Inf")

for i in range(model.config.num_codebooks):
codebooks.append(
Expand Down Expand Up @@ -194,7 +194,7 @@ def decode_n_tokens(
)

# TODO: use tokenizer's eos
- if (cur_token[0, 0, -1] == eos_token_id).any():
+ if cur_token[0, 0, -1] == eos_token_id or (cur_token[0, 1:, -1] == 1).any():
break

return previous_tokens[:, : i + 1]
Expand Down

0 comments on commit e7e7a8c

Please sign in to comment.