Skip to content

Commit

Permalink
improve default stop of dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
huyiwen committed May 24, 2024
1 parent 7e6ea5c commit 50bbbd5
Show file tree
Hide file tree
Showing 6 changed files with 19 additions and 9 deletions.
5 changes: 4 additions & 1 deletion utilization/dataset/agieval_cot.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,10 @@ class Agieval_cot(GenerationDataset):
supported_cot = ["base"]

def init_arguments(self):
self.extra_model_args = dict(stop=["\n"]) if self.cot is None else dict()
if self.cot is None:
# stop at the first newline only when chain-of-thought is disabled, since chain-of-thought responses might span multiple lines
self.extra_model_args["stop"] = ["\n"]

text = ""
text += "gen" if self.subset_name in AGIEVAL_NO_LETTER_CHOICE_TASKS else "mcq"
text += "_zh" if self.subset_name in AGIEVAL_ZH_PROMPT_TASKS else "_en"
Expand Down
4 changes: 3 additions & 1 deletion utilization/dataset/bbh.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,9 @@ class Bbh(GenerationDataset):

def init_arguments(self):
self.bbh_instruction = BBH_PROMPTS[self.subset_name]
self.extra_model_args = dict(stop=["\n"]) if self.cot is None else dict()
if self.cot is None:
# stop at the first newline only when chain-of-thought is disabled, since chain-of-thought responses might span multiple lines
self.extra_model_args["stop"] = ["\n"]

def format_instance(self, instance):
target = instance["answer"]
Expand Down
3 changes: 2 additions & 1 deletion utilization/dataset/gaokao.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ class Gaokao(GenerationDataset):

def init_arguments(self):
self.gaokao_instruction = GAOKAO_PROMPTS[self.subset_name]
self.extra_model_args = dict(temperature=0.3, max_tokens=4096)
self.extra_model_args["temperature"] = 0.3
self.extra_model_args["max_tokens"] = 4096
# According to https://github.com/OpenLMLab/GAOKAO-Bench/blob/main/Models/openai_gpt4.py
# We use temperature=0.3 and max_tokens=4096

Expand Down
9 changes: 5 additions & 4 deletions utilization/dataset/gsm8k.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@ class Gsm8k(GenerationDataset):
_extract_numbers = re.compile(r"[-+]?\d*\.\d+|\d+")

def init_arguments(self):
if self.model_type == 'base':
self.extra_model_args['stop'] = ['\n']
if self.cot is None:
# stop at the first newline only when chain-of-thought is disabled, since chain-of-thought responses might span multiple lines
self.extra_model_args["stop"] = ["\n"]

def load_raw_dataset(self, dataset_path, subset_name, evaluation_set, example_set):
super().load_raw_dataset(dataset_path, subset_name, evaluation_set, example_set)
Expand All @@ -39,7 +40,7 @@ def load_raw_dataset(self, dataset_path, subset_name, evaluation_set, example_se
self.example_data = LEAST_TO_MOST_EXAMPLARS
elif self.cot == 'pal':
self.example_data = PAL_EXAMPLARS
self.instruction = "Let's use python to solve math problems. Here are some examples how to do it."
self.instruction = "Let's use python to solve math problems. Here are some examples how to do it.\n\nQuestion: {{question.replace('\n', ' ')}}\nAnswer:"

def post_processing(self, predictions):
new_predictions = []
Expand Down Expand Up @@ -74,7 +75,7 @@ def post_processing(self, predictions):
def format_instance(self, instance):

# remove decimal separators
instance["answer"] = ' ' + self._decimal_separator.sub(r"\1\2", instance["answer"])
instance["answer"] = ' ' + self._decimal_separator.sub("", instance["answer"])

# few-shot examples might not contain "####"
if "####" in instance["answer"]:
Expand Down
4 changes: 3 additions & 1 deletion utilization/dataset/icl_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def global_entropy_ordering_strategy(indices, labels, example_dataset, call_mode
return list(best_perm)


def ape(example_dataset, eval_dataset, call_model, api_key):
def ape(example_dataset, eval_dataset, call_model):
"""
generate instructions using APE
Expand All @@ -87,6 +87,8 @@ def ape(example_dataset, eval_dataset, call_model, api_key):
List[str]: results of likelihood evaluation
List[float]: scores based on log probability
"""
import openai
api_key = openai.api_key

class ModelArguments:

Expand Down
3 changes: 2 additions & 1 deletion utilization/dataset/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ class Math(GenerationDataset):

def init_arguments(self):
if self.model_type == 'base':
self.extra_model_args['stop'] = ['\n\n']
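# when evaluating a base model, responses might span multiple lines, so stop at a blank line
# NOTE(review): if "stop" is not already set, .get("stop", []) returns a temporary list and the append below is lost; confirm "stop" always has a default here, otherwise setdefault would be needed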
self.extra_model_args.get("stop", []).append("\n\n")

@staticmethod
def normalize_final_answer(final_answer: str) -> str:
Expand Down

0 comments on commit 50bbbd5

Please sign in to comment.