From 8fa1df0b3718c003227b0a8d4dd20ac2a6fab2eb Mon Sep 17 00:00:00 2001
From: Thomas Wolf
Date: Wed, 7 Feb 2024 17:51:00 +0000
Subject: [PATCH] fixes

---
 src/lighteval/main_nanotron.py         |  6 +++++-
 src/lighteval/models/nanotron_model.py | 10 +++++-----
 2 files changed, 10 insertions(+), 6 deletions(-)

diff --git a/src/lighteval/main_nanotron.py b/src/lighteval/main_nanotron.py
index b256a608e..95b897c2e 100644
--- a/src/lighteval/main_nanotron.py
+++ b/src/lighteval/main_nanotron.py
@@ -54,7 +54,11 @@ def main(
         raise ValueError("The checkpoint path should point to a YAML file")

     nanotron_config: config_cls = get_config_from_file(
-        local_config_path, config_class=config_cls, model_config_class=model_config_cls
+        local_config_path,
+        config_class=config_cls,
+        model_config_class=model_config_cls,
+        skip_unused_config_keys=True,
+        skip_null_keys=True,
     )

     if lighteval_config_path:
diff --git a/src/lighteval/models/nanotron_model.py b/src/lighteval/models/nanotron_model.py
index 88ccfef4f..51e682eb6 100644
--- a/src/lighteval/models/nanotron_model.py
+++ b/src/lighteval/models/nanotron_model.py
@@ -1111,7 +1111,7 @@ def greedy_until(
         # automatic (variable) batch size detection for vectorization
         # pull longest context sample from request
         for request in requests:
-            request.stop_sequence = request.stop_sequence + [self.tokenizer.eos_token]
+            request.stop_sequence = request.stop_sequence + (self.tokenizer.eos_token,)
             request.tokenized_context = self.tok_encode(request.context)

         dataset = GenerativeTaskDatasetNanotron(requests=requests, dataset_splits=dataset_splits)
@@ -1134,8 +1134,8 @@ def greedy_until(
             dataset.split_start = subset_start
             dataset.split_end = min(subset_start + subset_length, total_length)

-            context_enc = dataset[0].tokenized_context
-            max_gen = max(item.generation_size for item in dataset)
+            context_enc = dataset[0][1].tokenized_context
+            max_gen = max(item[1].generation_size for item in dataset)
             max_input_length = min(len(context_enc) + max_gen, self.max_length)
             batch_size = self._get_batch_size(
                 override_bs=override_bs, max_input_length=max_input_length, starting_batch_size=starting_batch_size
@@ -1172,7 +1172,7 @@ def greedy_until(
             )
             iteration_start_time = time.time()
             example_index, batch_data = zip(*all_batch)
-            context = [c.context for c in batch_data]
+            context = [c.tokenized_context for c in batch_data]
             # we take the longest asked generation in the batch
             # Multiple request may have different max generation length
             max_tokens = max(d.generation_size for d in batch_data)  # d[1][1]
@@ -1244,7 +1244,7 @@ def greedy_until(
             ):
                 # Ensure the generated responses do not contain the stop sequences.
                 decoded_response = self.tokenizer.decode(generation, skip_special_tokens=False)
-                stop_terms = dataset[example_index][1][1][0]
+                stop_terms = dataset[example_index][1].stop_sequence
                 for stop_term in stop_terms:
                     decoded_response = decoded_response.split(stop_term)[0]
                 # partial caching