Commit

fixes
thomwolf committed Feb 7, 2024
1 parent 1c98e44 commit 8fa1df0
Showing 2 changed files with 10 additions and 6 deletions.
6 changes: 5 additions & 1 deletion src/lighteval/main_nanotron.py
@@ -54,7 +54,11 @@ def main(
         raise ValueError("The checkpoint path should point to a YAML file")
 
     nanotron_config: config_cls = get_config_from_file(
-        local_config_path, config_class=config_cls, model_config_class=model_config_cls
+        local_config_path,
+        config_class=config_cls,
+        model_config_class=model_config_cls,
+        skip_unused_config_keys=True,
+        skip_null_keys=True,
     )
 
     if lighteval_config_path:
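For context, a minimal sketch of how the updated call might be exercised on its own. The import path, concrete config classes, and YAML path below are assumptions for illustration; only the two skip_* flags come from the hunk above.

    from nanotron.config import Config, ModelArgs, get_config_from_file  # assumed import path

    # Hypothetical standalone use of the call shown above: the skip_* flags ask the
    # loader to tolerate keys that are unused by the config class or null in the YAML.
    nanotron_config = get_config_from_file(
        "checkpoint/config.yaml",      # placeholder path to the checkpoint's YAML config
        config_class=Config,           # assumed top-level config class
        model_config_class=ModelArgs,  # assumed model sub-config class
        skip_unused_config_keys=True,
        skip_null_keys=True,
    )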
10 changes: 5 additions & 5 deletions src/lighteval/models/nanotron_model.py
@@ -1111,7 +1111,7 @@ def greedy_until(
         # automatic (variable) batch size detection for vectorization
         # pull longest context sample from request
         for request in requests:
-            request.stop_sequence = request.stop_sequence + [self.tokenizer.eos_token]
+            request.stop_sequence = request.stop_sequence + (self.tokenizer.eos_token,)
             request.tokenized_context = self.tok_encode(request.context)
 
         dataset = GenerativeTaskDatasetNanotron(requests=requests, dataset_splits=dataset_splits)
@@ -1134,8 +1134,8 @@ def greedy_until(
             dataset.split_start = subset_start
             dataset.split_end = min(subset_start + subset_length, total_length)
 
-            context_enc = dataset[0].tokenized_context
-            max_gen = max(item.generation_size for item in dataset)
+            context_enc = dataset[0][1].tokenized_context
+            max_gen = max(item[1].generation_size for item in dataset)
             max_input_length = min(len(context_enc) + max_gen, self.max_length)
             batch_size = self._get_batch_size(
                 override_bs=override_bs, max_input_length=max_input_length, starting_batch_size=starting_batch_size
@@ -1172,7 +1172,7 @@ def greedy_until(
             )
             iteration_start_time = time.time()
             example_index, batch_data = zip(*all_batch)
-            context = [c.context for c in batch_data]
+            context = [c.tokenized_context for c in batch_data]
             # we take the longest asked generation in the batch
             # Multiple request may have different max generation length
             max_tokens = max(d.generation_size for d in batch_data)  # d[1][1]
@@ -1244,7 +1244,7 @@ def greedy_until(
             ):
                 # Ensure the generated responses do not contain the stop sequences.
                 decoded_response = self.tokenizer.decode(generation, skip_special_tokens=False)
-                stop_terms = dataset[example_index][1][1][0]
+                stop_terms = dataset[example_index][1].stop_sequence
                 for stop_term in stop_terms:
                     decoded_response = decoded_response.split(stop_term)[0]
                 # partial caching
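The indexing changes above suggest that GenerativeTaskDatasetNanotron yields (original_index, request) pairs, so item[1] is the request object whose tokenized_context, generation_size, and stop_sequence are read. A minimal sketch of that assumed shape; the implementation details here are hypothetical.

    class GenerativeTaskDatasetNanotron:
        """Sketch only: items are (original_index, request) pairs."""

        def __init__(self, requests, dataset_splits):
            # keep each request's original position so results can be re-ordered later
            self.sorted_data = list(enumerate(requests))
            self.split_start, self.split_end = 0, len(self.sorted_data)

        def __getitem__(self, index):
            # index is relative to the current split window
            return self.sorted_data[self.split_start + index]  # -> (original_index, request)

        def __len__(self):
            return self.split_end - self.split_start

    # Hence dataset[0][1].tokenized_context reads the first request's tokenized context,
    # and dataset[example_index][1].stop_sequence recovers that request's stop sequences.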
