From ad7dd41d76a97e3df59ccdf92bf571edbe5ce69e Mon Sep 17 00:00:00 2001
From: Pablo Gonzalez
Date: Mon, 27 Jan 2025 17:37:06 -0500
Subject: [PATCH] Mixtral fix: match reference with standalone script

---
 language/mixtral-8x7b/SUT.py     | 87 ++++++--------------------------
 language/mixtral-8x7b/dataset.py | 29 +++++++++--
 2 files changed, 41 insertions(+), 75 deletions(-)

diff --git a/language/mixtral-8x7b/SUT.py b/language/mixtral-8x7b/SUT.py
index e12a4318f..600d9e9d2 100644
--- a/language/mixtral-8x7b/SUT.py
+++ b/language/mixtral-8x7b/SUT.py
@@ -30,11 +30,12 @@ log = logging.getLogger("Mixtral-8x7B-Instruct-v0.1")
 
 gen_kwargs = {
-    "early_stopping": True,
-    "max_new_tokens": 1024,
+    # "min_new_tokens": 1,
     "min_new_tokens": 2,
-    "num_beams": 1,
+    "max_new_tokens": 1024,
     "do_sample": False,
+    "temperature": None,
+    "top_p": None,
 }
@@ -238,80 +239,30 @@ def process_queries(self):
             input_masks_tensor = []
             input_len = []
             input_dataset = []
+            batch_texts = []
+            datasets = []
             for q in qitem:
-                input_ids_tensor.append(
-                    pad(
-                        self.data_object.input_ids[q.index],
-                        (
-                            max_seq_len -
-                            self.data_object.input_lens[q.index],
-                            0,
-                            0,
-                            0,
-                        ),
-                        value=self.tokenizer.pad_token_id,
-                    )
-                )
-                input_masks_tensor.append(
-                    pad(
-                        self.data_object.attention_masks[q.index],
-                        (
-                            max_seq_len -
-                            self.data_object.input_lens[q.index],
-                            0,
-                            0,
-                            0,
-                        ),
-                        value=0,
-                    )
-                )
+                batch_texts.append(self.data_object.input_texts[q.index])
                 input_len.append(self.data_object.input_lens[q.index])
-
                 # In case we predict code generation, we can specify an
                 # additional stop sequence
                 input_dataset.append(
                     self.data_object.dataset_names[q.index])
 
-            input_ids_tensor = torch.cat(input_ids_tensor)
-            input_masks_tensor = torch.cat(input_masks_tensor)
-            assert input_ids_tensor.shape == input_masks_tensor.shape
-            assert input_ids_tensor.shape[0] <= self.batch_size
+            batch_ids = self.tokenizer.batch_encode_plus(batch_texts, return_tensors="pt", padding=True)
+            batch_ids = batch_ids.to(self.device)
 
             tik2 = time.time()
-            logits_processor = LogitsProcessorList(
-                [StopAfterSequence(
-                    self.tokenizer.eos_token_id, device=self.device)]
-            )
-            for i in range(len(input_ids_tensor)):
-                ids, masks, dataset = (
-                    input_ids_tensor[i: i + 1],
-                    input_masks_tensor[i: i + 1],
-                    input_dataset[i],
-                )
-                pred_output_tokens = []
-                if dataset == "MBXP":
-                    out = self.model.generate(
-                        input_ids=ids,
-                        attention_mask=masks,
-                        pad_token_id=self.tokenizer.pad_token_id,
-                        logits_processor=logits_processor,
-                        **gen_kwargs,
-                    )
-                else:
-                    out = self.model.generate(
-                        input_ids=ids,
-                        attention_mask=masks,
-                        pad_token_id=self.tokenizer.pad_token_id,
-                        **gen_kwargs,
-                    )
-                pred_output_tokens.append(out)
-                pred_output_tokens = torch.cat(pred_output_tokens)
+            _, length = batch_ids.input_ids.shape
+            out = self.model.generate(**batch_ids, num_return_sequences=1, **gen_kwargs)
+            pred_output_tokens = out
 
             tik3 = time.time()
 
             processed_output = self.data_object.postProcess(
                 pred_output_tokens,
-                input_seq_lens=input_len,
+                length=length,
                 query_id_list=query_ids,
+                dataset_list=input_dataset,
             )
 
             for i in range(len(qitem)):
@@ -342,10 +293,7 @@ def process_queries(self):
 
     def load_model(self):
         self.model = AutoModelForCausalLM.from_pretrained(
-            self.model_path,
-            device_map="auto",
-            low_cpu_mem_usage=True,
-            torch_dtype=self.amp_dtype,
+            self.model_path, device_map="auto", trust_remote_code=True
         )
         print("Loaded model")
@@ -362,10 +310,7 @@ def load_model(self):
             pass
 
         self.tokenizer = AutoTokenizer.from_pretrained(
-            self.model_path,
-            model_max_length=1024,
-            padding_side="left",
-            use_fast=False,
+            self.model_path, padding_side="left", trust_remote_code=True
         )
         self.tokenizer.pad_token = self.tokenizer.eos_token

diff --git a/language/mixtral-8x7b/dataset.py b/language/mixtral-8x7b/dataset.py
index c8268d8d1..0c757ba7d 100644
--- a/language/mixtral-8x7b/dataset.py
+++ b/language/mixtral-8x7b/dataset.py
@@ -67,6 +67,7 @@ def load_processed_dataset(self):
         processed_data = pd.read_pickle(self.dataset_path)
 
         input_tokens = processed_data["tok_input"]
+        self.input_texts = processed_data["input"].to_list()
 
         self.input_ids = []
         self.input_lens = []
@@ -85,12 +86,31 @@
             self.dataset_names.append(dataset)
         print("Finished loading dataset.")
 
+    def remove_trailing_twos(self, lst, eos = 2):
+        count = 0
+        for num in reversed(lst):
+            if num == eos or num == 0:
+                count += 1
+            else:
+                break
+        return lst[:-count] if count > 0 else lst
+
+    def mbxp_stop(self, lst, stop_tokens = [13, 13940, 28832, 13]):
+        for i in range(len(lst) - len(stop_tokens) + 1):
+            if (lst[i:i+len(stop_tokens)] == stop_tokens).all():
+                return lst[:i+len(stop_tokens)]
+        return lst
+
     def postProcess(
         self,
         out_tokens,
-        input_seq_lens=None,
+        length=None,
         query_id_list=None,
         sample_index_list=None,
+        dataset_list=None,
     ):
         """Postprocesses output prediction"""
@@ -106,13 +126,14 @@ def postProcess(
         """
         # Everything is padded to max_len (1024), so prune the input and parse
         # to numpy
-        output_seq = out_tokens[:, 1024:].cpu().numpy()
+        output_seq = out_tokens[:, length:].cpu().numpy()
         aux_seq = []
         assert len(query_id_list) == output_seq.shape[0]
         for i in range(len(output_seq)):
             aux = output_seq[i]
-            while len(output_seq[i]) <= 1:
-                aux = np.append(aux, self.tokenizer.eos_token_id)
+            aux = self.remove_trailing_twos(aux)
+            if (dataset_list[i] == "MBXP"):
+                aux = self.mbxp_stop(aux)
             aux_seq.append(aux)
         output_seq = np.stack(aux_seq)
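
For review convenience, below is a minimal standalone sketch of the batched
encode -> generate -> trim flow this patch switches to. It is illustrative
only: the checkpoint path and prompt are placeholders (any HF causal LM
works), and remove_trailing_twos mirrors the dataset.py helper rather than
importing it from the harness.

    from transformers import AutoModelForCausalLM, AutoTokenizer

    MODEL_PATH = "mistralai/Mixtral-8x7B-Instruct-v0.1"  # placeholder checkpoint
    PROMPTS = ["[INST] Write a haiku about MLPerf. [/INST]"]  # placeholder input

    def remove_trailing_twos(seq, eos=2):
        # Mirror of the dataset.py helper: trim trailing eos (2) / pad (0) ids.
        count = 0
        for tok in reversed(seq):
            if tok == eos or tok == 0:
                count += 1
            else:
                break
        return seq[:-count] if count > 0 else seq

    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, padding_side="left")
    tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map="auto")

    # Left padding makes every prompt end at the same column, so the prompt
    # length is simply the padded width of the batch.
    batch = tokenizer.batch_encode_plus(
        PROMPTS, return_tensors="pt", padding=True).to(model.device)
    _, length = batch.input_ids.shape

    out = model.generate(**batch, num_return_sequences=1, min_new_tokens=2,
                         max_new_tokens=1024, do_sample=False,
                         temperature=None, top_p=None)

    # Slice off the prompt, then trim trailing eos/pad ids as postProcess does.
    for row in out[:, length:].cpu().numpy():
        print(tokenizer.decode(remove_trailing_twos(row)))

Because the batch is left-padded, out[:, length:] isolates exactly the newly
generated tokens for every row; that is what lets postProcess replace the
hard-coded 1024 offset with the actual padded prompt length.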