diff --git a/tools/llama/generate.py b/tools/llama/generate.py index e7b5cce7..c84a69be 100644 --- a/tools/llama/generate.py +++ b/tools/llama/generate.py @@ -163,7 +163,7 @@ def decode_n_tokens( **sampling_kwargs, ): previous_tokens = torch.zeros( - (model.config.num_codebooks + 1, num_new_tokens), + (model.config.num_codebooks + 1, model.config.max_seq_len), dtype=torch.int, device=cur_token.device, )