From 6d34d7b84a1e7d6aebd00e9b91709e3887303b0d Mon Sep 17 00:00:00 2001 From: Weixuan Zhang Date: Wed, 15 Feb 2023 10:15:01 +0000 Subject: [PATCH 01/11] Add MultiWOZ 2.4 DST evaluation --- mwzeval/metrics.py | 12 +++++- mwzeval/utils.py | 105 ++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 114 insertions(+), 3 deletions(-) diff --git a/mwzeval/metrics.py b/mwzeval/metrics.py index 648b5c7..34bcee3 100644 --- a/mwzeval/metrics.py +++ b/mwzeval/metrics.py @@ -16,6 +16,7 @@ class Evaluator: + _MWZ_VERSION = '22' def __init__(self, bleu : bool, success : bool, richness : bool, dst : bool = False): self.bleu = bleu @@ -32,7 +33,7 @@ def __init__(self, bleu : bool, success : bool, richness : bool, dst : bool = Fa self.booked_domains = load_booked_domains() if dst: - self.gold_states = load_gold_states() + self.gold_states = load_gold_states(mwz_version=self._MWZ_VERSION) def evaluate(self, input_data): normalize_data(input_data) @@ -42,6 +43,15 @@ def evaluate(self, input_data): "richness" : get_richness(input_data) if self.richness else None, "dst" : get_dst(input_data, self.gold_states) if self.dst else None, } + + +class Multiwoz24Evaluator(Evaluator): + _MWZ_VERSION = '24' + + def __init__(self, bleu: bool, success: bool, richness: bool, dst: bool = False): + if bleu or success or richness: + raise NotImplementedError("bleu, success or richness metrics are not yet implemented for MultiWOZ 2.4.") + super().__init__(bleu=bleu, success=success, richness=richness, dst=dst) def get_bleu(input_data, reference_dialogs): diff --git a/mwzeval/utils.py b/mwzeval/utils.py index e49dedd..5b6d0b4 100644 --- a/mwzeval/utils.py +++ b/mwzeval/utils.py @@ -1,6 +1,10 @@ import os import json import urllib.request +from typing import Literal +import zipfile +import io +from collections import defaultdict from mwzeval.normalization import normalize_data @@ -102,13 +106,18 @@ def load_multiwoz22_reference(): return references -def load_gold_states(): +def load_gold_states(mwz_version: Literal['22', '24'] = '22'): dir_path = os.path.dirname(os.path.realpath(__file__)) data_path = os.path.join(dir_path, "data", "gold_states.json") if os.path.exists(data_path): with open(data_path) as f: return json.load(f) - _, states = load_multiwoz22() + if mwz_version == "22": + _, states = load_multiwoz22() + elif mwz_version == "24": + _, states = load_multiwoz24() + else: + raise ValueError("Unsupported MultiWOZ version.") return states @@ -192,3 +201,95 @@ def parse_state(turn): json.dump(states, f, indent=2) return references, states + + +def load_multiwoz24(): + def is_filled(slot_value: str) -> bool: + """Whether a slot value is filled.""" + slot_value = slot_value.lower() + return slot_value and slot_value != "not mentioned" and slot_value != "none" + + def parse_state(turn: dict, prepend_book: bool = False) -> dict[dict[str, str]]: + """Get the slot values of a given turn. + + This function is adapted from + google-research/schema_guided_dst/multiwoz/create_data_from_multiwoz.py + + If a slot has multiple values (which are separated by '|'), only the first one is taken. + This is consistant with the approach taken for MultiWOZ 2.2 evaluation. + + Args: + turn: Dictionary of a turn of the MultiWOZ 2.4 dataset + prepend_book: Whether to prepend the string 'book' to slot names for booking slots + + Returns: + {$domain: {$slot_name: $value, ...}, ...} + """ + dialog_states = defaultdict(dict) + for domain_name, values in turn['metadata'].items(): + dialog_states_of_one_domain = {} + + for k, v in values["book"].items(): + # Note: "booked" is not really a state, just booking confirmation + if k == 'booked': + continue + if isinstance(v, list): + for item_dict in v: + new_states = { + (f"book{slot_name}" if prepend_book else slot_name): slot_val + for slot_name, slot_val in item_dict.items() + } + dialog_states_of_one_domain.update(new_states) + if isinstance(v, str) and v: + slot_name = f"book{k}" if prepend_book else k + dialog_states_of_one_domain[slot_name] = v + + new_states = values["semi"] + dialog_states_of_one_domain.update(new_states) + + dialog_states_of_one_domain = { + slot_name: value.split('|')[0] # use the first value + for slot_name, value in dialog_states_of_one_domain.items() + if is_filled(value) + } + if len(dialog_states_of_one_domain) > 0: + dialog_states[domain_name] = dialog_states_of_one_domain + + return dialog_states + + with urllib.request.urlopen( + "https://github.com/smartyfh/MultiWOZ2.4/blob/main/data/MULTIWOZ2.4.zip?raw=true" + ) as url: + print("Downloading MultiWOZ_2.4") + unzipped = zipfile.ZipFile(io.BytesIO(url.read())) + # dialogue_acts = json.loads(unzipped.read('MULTIWOZ2.4/dialogue_acts.json')) + data = json.loads(unzipped.read('MULTIWOZ2.4/data.json')) + + mwz24_data = {} + for dialogue_id, dialogue in data.items(): + parsed_turns = [] + for i, turn in enumerate(dialogue["log"]): + if i % 2 == 0: + continue + state = parse_state(turn) + parsed_turns.append({"response": "", "state": state}) + mwz24_data[dialogue_id.split(".")[0].lower()] = parsed_turns + + normalize_data(mwz24_data) + + references, states = {}, {} + for dialog in mwz24_data: + # references[dialog] = [x["response"] for x in mwz24_data[dialog]] + states[dialog] = [x["state"] for x in mwz24_data[dialog]] + + dir_path = os.path.dirname(os.path.realpath(__file__)) + # reference_path = os.path.join(dir_path, "data", "references", "mwz24.json") + state_path = os.path.join(dir_path, "data", "gold_states.json") + + # with open(reference_path, 'w+') as f: + # json.dump(references, f, indent=2) + + with open(state_path, 'w+') as f: + json.dump(states, f, indent=2) + + return references, states From 9e23b127f5a222c9138faa57d04a621453e29549 Mon Sep 17 00:00:00 2001 From: Weixuan Zhang Date: Wed, 15 Feb 2023 14:19:28 +0000 Subject: [PATCH 02/11] Add to slot value normalization --- mwzeval/normalization.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mwzeval/normalization.py b/mwzeval/normalization.py index f4c6461..538d7be 100644 --- a/mwzeval/normalization.py +++ b/mwzeval/normalization.py @@ -78,6 +78,8 @@ def type_to_canonical(type_string): return "nightclub" elif type_string == "guest house": return "guesthouse" + elif type_string == "concert hall": + return "concerthall" return type_string def name_to_canonical(name, domain=None): From 75af0f108ce42ed70bcb3c8fefeb1ea030760e43 Mon Sep 17 00:00:00 2001 From: Weixuan Zhang Date: Wed, 15 Feb 2023 16:10:43 +0000 Subject: [PATCH 03/11] Fix multi-value extraction --- mwzeval/utils.py | 38 +++++++++++++++++++++++++------------- 1 file changed, 25 insertions(+), 13 deletions(-) diff --git a/mwzeval/utils.py b/mwzeval/utils.py index 5b6d0b4..726259c 100644 --- a/mwzeval/utils.py +++ b/mwzeval/utils.py @@ -209,25 +209,38 @@ def is_filled(slot_value: str) -> bool: slot_value = slot_value.lower() return slot_value and slot_value != "not mentioned" and slot_value != "none" + def get_first_value(values: str) -> str: + """Get the first value if the values string contains multiple.""" + if "|" in values: + values = values.split("|") + elif ">" in values: + values = values.split(">") + elif "<" in values: + values = values.split("<") + else: + values = [values] + return values[0] + def parse_state(turn: dict, prepend_book: bool = False) -> dict[dict[str, str]]: """Get the slot values of a given turn. This function is adapted from google-research/schema_guided_dst/multiwoz/create_data_from_multiwoz.py - If a slot has multiple values (which are separated by '|'), only the first one is taken. + If a slot has multiple values (which are separated by '|', '<' or '>'), only the first one is taken. This is consistant with the approach taken for MultiWOZ 2.2 evaluation. Args: turn: Dictionary of a turn of the MultiWOZ 2.4 dataset - prepend_book: Whether to prepend the string 'book' to slot names for booking slots + prepend_book: Whether to prepend the string 'book' to slot names for booking slots. + MultiWOZ 2.2 has the 'book' prefix. Returns: {$domain: {$slot_name: $value, ...}, ...} """ dialog_states = defaultdict(dict) for domain_name, values in turn['metadata'].items(): - dialog_states_of_one_domain = {} + domain_dial_state = {} for k, v in values["book"].items(): # Note: "booked" is not really a state, just booking confirmation @@ -237,23 +250,22 @@ def parse_state(turn: dict, prepend_book: bool = False) -> dict[dict[str, str]]: for item_dict in v: new_states = { (f"book{slot_name}" if prepend_book else slot_name): slot_val - for slot_name, slot_val in item_dict.items() - } - dialog_states_of_one_domain.update(new_states) + for slot_name, slot_val in item_dict.items() } + domain_dial_state.update(new_states) if isinstance(v, str) and v: slot_name = f"book{k}" if prepend_book else k - dialog_states_of_one_domain[slot_name] = v + domain_dial_state[slot_name] = v new_states = values["semi"] - dialog_states_of_one_domain.update(new_states) + domain_dial_state.update(new_states) - dialog_states_of_one_domain = { - slot_name: value.split('|')[0] # use the first value - for slot_name, value in dialog_states_of_one_domain.items() + domain_dial_state = { + slot_name: get_first_value(value) # use the first value + for slot_name, value in domain_dial_state.items() if is_filled(value) } - if len(dialog_states_of_one_domain) > 0: - dialog_states[domain_name] = dialog_states_of_one_domain + if len(domain_dial_state) > 0: + dialog_states[domain_name] = domain_dial_state return dialog_states From 1abd9f59cf35e670b655f04d25c4c0d88d45f4c4 Mon Sep 17 00:00:00 2001 From: Weixuan Zhang Date: Thu, 23 Feb 2023 11:58:34 +0000 Subject: [PATCH 04/11] Add flag for extended normalization --- mwzeval/metrics.py | 12 ++++++++++-- mwzeval/normalization.py | 17 ++++++++++++----- mwzeval/utils.py | 7 +++++-- 3 files changed, 27 insertions(+), 9 deletions(-) diff --git a/mwzeval/metrics.py b/mwzeval/metrics.py index 34bcee3..f849129 100644 --- a/mwzeval/metrics.py +++ b/mwzeval/metrics.py @@ -35,8 +35,13 @@ def __init__(self, bleu : bool, success : bool, richness : bool, dst : bool = Fa if dst: self.gold_states = load_gold_states(mwz_version=self._MWZ_VERSION) - def evaluate(self, input_data): - normalize_data(input_data) + def evaluate(self, input_data, use_extended_normalization=False): + """Get evaluations. + + The use_extended_normalization flag enables a more thorough normalization. + Setting the flag to False ensures backwards compatibility. + """ + normalize_data(input_data, extended=use_extended_normalization) return { "bleu" : get_bleu(input_data, self.reference_dialogs) if self.bleu else None, "success" : get_success(input_data, self.database, self.goals, self.booked_domains) if self.success else None, @@ -53,6 +58,9 @@ def __init__(self, bleu: bool, success: bool, richness: bool, dst: bool = False) raise NotImplementedError("bleu, success or richness metrics are not yet implemented for MultiWOZ 2.4.") super().__init__(bleu=bleu, success=success, richness=richness, dst=dst) + def evaluate(self, input_data): + return super().evaluate(input_data, use_extended_normalization=True) + def get_bleu(input_data, reference_dialogs): """ Get SacreBLEU score between normalized utterances in input data and a set of normalized references. """ diff --git a/mwzeval/normalization.py b/mwzeval/normalization.py index 538d7be..461091b 100644 --- a/mwzeval/normalization.py +++ b/mwzeval/normalization.py @@ -4,8 +4,12 @@ from sacremoses import MosesTokenizer, MosesDetokenizer -def normalize_data(input_data): - """ In-place normalization of raw dictionary with input data. Normalize slot names, slot values, remove plurals and detokenize utterances. """ +def normalize_data(input_data, extended=False): + """ In-place normalization of raw dictionary with input data. Normalize slot names, slot values, remove plurals and detokenize utterances. + + The extended flag is used for a more thorough normalization compared to MultiWOZ 2.2. + Setting the flag to False ensures backwards compatibility. + """ mt, md = MosesTokenizer(lang='en'), MosesDetokenizer(lang='en') slot_name_re = re.compile(r'\[([\w\s\d]+)\](es|s|-s|-es|)') @@ -25,7 +29,7 @@ def normalize_data(input_data): slot = slot.lower().replace(' ', '') if slot == "arriveby": slot = "arrive" elif slot == "leaveat": slot = "leave" - new_state[slot] = normalize_state_slot_value(slot, value) + new_state[slot] = normalize_state_slot_value(slot, value, extended=extended) turn["state"][domain] = new_state @@ -61,12 +65,15 @@ def normalize_slot_name(slot_name): return reverse_slot_name_mapping[slot_name] -def normalize_state_slot_value(slot_name, value): +def normalize_state_slot_value(slot_name, value, extended=False): """ Normalize slot value: 1) replace too distant venue names with canonical values 2) replace too distant food types with canonical values 3) parse time strings to the HH:MM format 4) resolve inconsistency between the database entries and parking and internet slots + + The extended flag is used for a more thorough normalization compared to MultiWOZ 2.2. + Setting the flag to False ensures backwards compatibility. """ def type_to_canonical(type_string): @@ -78,7 +85,7 @@ def type_to_canonical(type_string): return "nightclub" elif type_string == "guest house": return "guesthouse" - elif type_string == "concert hall": + elif type_string == "concert hall" and extended: return "concerthall" return type_string diff --git a/mwzeval/utils.py b/mwzeval/utils.py index 726259c..256b4ac 100644 --- a/mwzeval/utils.py +++ b/mwzeval/utils.py @@ -205,7 +205,10 @@ def parse_state(turn): def load_multiwoz24(): def is_filled(slot_value: str) -> bool: - """Whether a slot value is filled.""" + """Whether a slot value is filled. + + Unfilled slots should be dropped, as in MultiWOZ 2.2. + """ slot_value = slot_value.lower() return slot_value and slot_value != "not mentioned" and slot_value != "none" @@ -287,7 +290,7 @@ def parse_state(turn: dict, prepend_book: bool = False) -> dict[dict[str, str]]: parsed_turns.append({"response": "", "state": state}) mwz24_data[dialogue_id.split(".")[0].lower()] = parsed_turns - normalize_data(mwz24_data) + normalize_data(mwz24_data, extended=True) references, states = {}, {} for dialog in mwz24_data: From cf3eeb625e7b5253a5f9963f2400d7e01de00bcb Mon Sep 17 00:00:00 2001 From: Weixuan Zhang Date: Fri, 24 Feb 2023 00:14:18 +0000 Subject: [PATCH 05/11] Revert "Add flag for extended normalization" This reverts commit 1abd9f59cf35e670b655f04d25c4c0d88d45f4c4. --- mwzeval/metrics.py | 12 ++---------- mwzeval/normalization.py | 17 +++++------------ mwzeval/utils.py | 7 ++----- 3 files changed, 9 insertions(+), 27 deletions(-) diff --git a/mwzeval/metrics.py b/mwzeval/metrics.py index f849129..34bcee3 100644 --- a/mwzeval/metrics.py +++ b/mwzeval/metrics.py @@ -35,13 +35,8 @@ def __init__(self, bleu : bool, success : bool, richness : bool, dst : bool = Fa if dst: self.gold_states = load_gold_states(mwz_version=self._MWZ_VERSION) - def evaluate(self, input_data, use_extended_normalization=False): - """Get evaluations. - - The use_extended_normalization flag enables a more thorough normalization. - Setting the flag to False ensures backwards compatibility. - """ - normalize_data(input_data, extended=use_extended_normalization) + def evaluate(self, input_data): + normalize_data(input_data) return { "bleu" : get_bleu(input_data, self.reference_dialogs) if self.bleu else None, "success" : get_success(input_data, self.database, self.goals, self.booked_domains) if self.success else None, @@ -58,9 +53,6 @@ def __init__(self, bleu: bool, success: bool, richness: bool, dst: bool = False) raise NotImplementedError("bleu, success or richness metrics are not yet implemented for MultiWOZ 2.4.") super().__init__(bleu=bleu, success=success, richness=richness, dst=dst) - def evaluate(self, input_data): - return super().evaluate(input_data, use_extended_normalization=True) - def get_bleu(input_data, reference_dialogs): """ Get SacreBLEU score between normalized utterances in input data and a set of normalized references. """ diff --git a/mwzeval/normalization.py b/mwzeval/normalization.py index 461091b..538d7be 100644 --- a/mwzeval/normalization.py +++ b/mwzeval/normalization.py @@ -4,12 +4,8 @@ from sacremoses import MosesTokenizer, MosesDetokenizer -def normalize_data(input_data, extended=False): - """ In-place normalization of raw dictionary with input data. Normalize slot names, slot values, remove plurals and detokenize utterances. - - The extended flag is used for a more thorough normalization compared to MultiWOZ 2.2. - Setting the flag to False ensures backwards compatibility. - """ +def normalize_data(input_data): + """ In-place normalization of raw dictionary with input data. Normalize slot names, slot values, remove plurals and detokenize utterances. """ mt, md = MosesTokenizer(lang='en'), MosesDetokenizer(lang='en') slot_name_re = re.compile(r'\[([\w\s\d]+)\](es|s|-s|-es|)') @@ -29,7 +25,7 @@ def normalize_data(input_data, extended=False): slot = slot.lower().replace(' ', '') if slot == "arriveby": slot = "arrive" elif slot == "leaveat": slot = "leave" - new_state[slot] = normalize_state_slot_value(slot, value, extended=extended) + new_state[slot] = normalize_state_slot_value(slot, value) turn["state"][domain] = new_state @@ -65,15 +61,12 @@ def normalize_slot_name(slot_name): return reverse_slot_name_mapping[slot_name] -def normalize_state_slot_value(slot_name, value, extended=False): +def normalize_state_slot_value(slot_name, value): """ Normalize slot value: 1) replace too distant venue names with canonical values 2) replace too distant food types with canonical values 3) parse time strings to the HH:MM format 4) resolve inconsistency between the database entries and parking and internet slots - - The extended flag is used for a more thorough normalization compared to MultiWOZ 2.2. - Setting the flag to False ensures backwards compatibility. """ def type_to_canonical(type_string): @@ -85,7 +78,7 @@ def type_to_canonical(type_string): return "nightclub" elif type_string == "guest house": return "guesthouse" - elif type_string == "concert hall" and extended: + elif type_string == "concert hall": return "concerthall" return type_string diff --git a/mwzeval/utils.py b/mwzeval/utils.py index 256b4ac..726259c 100644 --- a/mwzeval/utils.py +++ b/mwzeval/utils.py @@ -205,10 +205,7 @@ def parse_state(turn): def load_multiwoz24(): def is_filled(slot_value: str) -> bool: - """Whether a slot value is filled. - - Unfilled slots should be dropped, as in MultiWOZ 2.2. - """ + """Whether a slot value is filled.""" slot_value = slot_value.lower() return slot_value and slot_value != "not mentioned" and slot_value != "none" @@ -290,7 +287,7 @@ def parse_state(turn: dict, prepend_book: bool = False) -> dict[dict[str, str]]: parsed_turns.append({"response": "", "state": state}) mwz24_data[dialogue_id.split(".")[0].lower()] = parsed_turns - normalize_data(mwz24_data, extended=True) + normalize_data(mwz24_data) references, states = {}, {} for dialog in mwz24_data: From 55fb7c26a7b6ecc6d62b7a068c1f890bc9e3f2e4 Mon Sep 17 00:00:00 2001 From: Weixuan Zhang Date: Sun, 26 Feb 2023 20:12:42 +0000 Subject: [PATCH 06/11] Add option to disable normalization --- mwzeval/metrics.py | 24 ++++++++++++++++++------ mwzeval/utils.py | 44 +++++++++++++++++++++++++------------------- 2 files changed, 43 insertions(+), 25 deletions(-) diff --git a/mwzeval/metrics.py b/mwzeval/metrics.py index 34bcee3..95e53c8 100644 --- a/mwzeval/metrics.py +++ b/mwzeval/metrics.py @@ -18,14 +18,25 @@ class Evaluator: _MWZ_VERSION = '22' - def __init__(self, bleu : bool, success : bool, richness : bool, dst : bool = False): + def __init__(self, bleu: bool, success: bool, richness: bool, dst: bool = False, enable_normalization: bool = True): + """Initialize the evaluator. + + Args: + bleu (bool): Whether to include BLEU metric. + success (bool): Whether to include Inform & Success rates metrics. + richness (bool): Whether to include lexical richness metric. + dst (bool, optional): Whether to include DST metrics. Defaults to False. + enable_normalization (bool, optional): Whether to use slot name and value normalization. Defaults to True. + """ self.bleu = bleu self.success = success self.richness = richness self.dst = dst + + self._enable_normalization = enable_normalization if bleu: - self.reference_dialogs = load_references() + self.reference_dialogs = load_references(enable_normalization=self._enable_normalization) if success: self.database = MultiWOZVenueDatabase() @@ -33,10 +44,11 @@ def __init__(self, bleu : bool, success : bool, richness : bool, dst : bool = Fa self.booked_domains = load_booked_domains() if dst: - self.gold_states = load_gold_states(mwz_version=self._MWZ_VERSION) + self.gold_states = load_gold_states(mwz_version=self._MWZ_VERSION, enable_normalization=self._enable_normalization) def evaluate(self, input_data): - normalize_data(input_data) + if self._enable_normalization: + normalize_data(input_data) return { "bleu" : get_bleu(input_data, self.reference_dialogs) if self.bleu else None, "success" : get_success(input_data, self.database, self.goals, self.booked_domains) if self.success else None, @@ -48,10 +60,10 @@ def evaluate(self, input_data): class Multiwoz24Evaluator(Evaluator): _MWZ_VERSION = '24' - def __init__(self, bleu: bool, success: bool, richness: bool, dst: bool = False): + def __init__(self, bleu: bool, success: bool, richness: bool, dst: bool = False, enable_normalization: bool = True): if bleu or success or richness: raise NotImplementedError("bleu, success or richness metrics are not yet implemented for MultiWOZ 2.4.") - super().__init__(bleu=bleu, success=success, richness=richness, dst=dst) + super().__init__(bleu=bleu, success=success, richness=richness, dst=dst, enable_normalization=enable_normalization) def get_bleu(input_data, reference_dialogs): diff --git a/mwzeval/utils.py b/mwzeval/utils.py index 726259c..cb12c98 100644 --- a/mwzeval/utils.py +++ b/mwzeval/utils.py @@ -83,7 +83,7 @@ def load_booked_domains(): return json.load(f) -def load_references(systems=['mwz22']): #, 'damd', 'uniconv', 'hdsa', 'lava', 'augpt']): +def load_references(systems=['mwz22'], enable_normalization: bool = True): #, 'damd', 'uniconv', 'hdsa', 'lava', 'augpt']): references = {} for system in systems: if system == 'mwz22': @@ -92,36 +92,36 @@ def load_references(systems=['mwz22']): #, 'damd', 'uniconv', 'hdsa', 'lava', 'a with open(os.path.join(dir_path, "data", "references", f"{system}.json")) as f: references[system] = json.load(f) if 'mwz22' in systems: - references['mwz22'] = load_multiwoz22_reference() + references['mwz22'] = load_multiwoz22_reference(enable_normalization=enable_normalization) return references -def load_multiwoz22_reference(): +def load_multiwoz22_reference(enable_normalization: bool = True): dir_path = os.path.dirname(os.path.realpath(__file__)) - data_path = os.path.join(dir_path, "data", "references", "mwz22.json") + data_path = os.path.join(dir_path, "data", "references", "mwz22.json" if enable_normalization else "mwz22_not_normalized.json") if os.path.exists(data_path): with open(data_path) as f: return json.load(f) - references, _ = load_multiwoz22() + references, _ = load_multiwoz22(enable_normalization=enable_normalization) return references -def load_gold_states(mwz_version: Literal['22', '24'] = '22'): +def load_gold_states(mwz_version: Literal['22', '24'] = '22', enable_normalization: bool = True): dir_path = os.path.dirname(os.path.realpath(__file__)) - data_path = os.path.join(dir_path, "data", "gold_states.json") + data_path = os.path.join(dir_path, "data", f"gold_states{mwz_version}.json" if enable_normalization else f"gold_states{mwz_version}_not_normalized.json") if os.path.exists(data_path): with open(data_path) as f: return json.load(f) if mwz_version == "22": - _, states = load_multiwoz22() + _, states = load_multiwoz22(enable_normalization=enable_normalization) elif mwz_version == "24": - _, states = load_multiwoz24() + _, states = load_multiwoz24(enable_normalization=enable_normalization) else: raise ValueError("Unsupported MultiWOZ version.") return states -def load_multiwoz22(): +def load_multiwoz22(enable_normalization: bool = True): def delexicalize_utterance(utterance, span_info): span_info.sort(key=(lambda x: x[-2])) # sort spans by start index @@ -183,7 +183,8 @@ def parse_state(turn): }) mwz22_data[dialog["dialogue_id"].split('.')[0].lower()] = parsed_turns - normalize_data(mwz22_data) + if enable_normalization: + normalize_data(mwz22_data) references, states = {}, {} for dialog in mwz22_data: @@ -191,8 +192,8 @@ def parse_state(turn): states[dialog] = [x["state"] for x in mwz22_data[dialog]] dir_path = os.path.dirname(os.path.realpath(__file__)) - reference_path = os.path.join(dir_path, "data", "references", "mwz22.json") - state_path = os.path.join(dir_path, "data", "gold_states.json") + reference_path = os.path.join(dir_path, "data", "references", "mwz22.json" if enable_normalization else "mwz22_not_normalized.json") + state_path = os.path.join(dir_path, "data", "gold_states22.json" if enable_normalization else "gold_states22_not_normalized.json") with open(reference_path, 'w+') as f: json.dump(references, f, indent=2) @@ -203,9 +204,12 @@ def parse_state(turn): return references, states -def load_multiwoz24(): +def load_multiwoz24(enable_normalization: bool = True): def is_filled(slot_value: str) -> bool: - """Whether a slot value is filled.""" + """Whether a slot value is filled. + + Unfilled slots should be dropped, as in MultiWOZ 2.2. + """ slot_value = slot_value.lower() return slot_value and slot_value != "not mentioned" and slot_value != "none" @@ -250,7 +254,8 @@ def parse_state(turn: dict, prepend_book: bool = False) -> dict[dict[str, str]]: for item_dict in v: new_states = { (f"book{slot_name}" if prepend_book else slot_name): slot_val - for slot_name, slot_val in item_dict.items() } + for slot_name, slot_val in item_dict.items() + } domain_dial_state.update(new_states) if isinstance(v, str) and v: slot_name = f"book{k}" if prepend_book else k @@ -287,7 +292,8 @@ def parse_state(turn: dict, prepend_book: bool = False) -> dict[dict[str, str]]: parsed_turns.append({"response": "", "state": state}) mwz24_data[dialogue_id.split(".")[0].lower()] = parsed_turns - normalize_data(mwz24_data) + if enable_normalization: + normalize_data(mwz24_data) references, states = {}, {} for dialog in mwz24_data: @@ -295,8 +301,8 @@ def parse_state(turn: dict, prepend_book: bool = False) -> dict[dict[str, str]]: states[dialog] = [x["state"] for x in mwz24_data[dialog]] dir_path = os.path.dirname(os.path.realpath(__file__)) - # reference_path = os.path.join(dir_path, "data", "references", "mwz24.json") - state_path = os.path.join(dir_path, "data", "gold_states.json") + # reference_path = os.path.join(dir_path, "data", "references", "mwz24.json" if enable_normalization else "mwz24_not_normalized.json") + state_path = os.path.join(dir_path, "data", "gold_states24.json" if enable_normalization else "gold_states24_not_normalized.json") # with open(reference_path, 'w+') as f: # json.dump(references, f, indent=2) From fbedcfa930722105043cc51db62daa7e924fc27e Mon Sep 17 00:00:00 2001 From: Weixuan Zhang Date: Fri, 3 Mar 2023 17:40:40 +0000 Subject: [PATCH 07/11] Add DST domain leave-one-out cross-validation support --- mwzeval/metrics.py | 109 ++++++++++++++++++++++++++++++++++++--------- 1 file changed, 88 insertions(+), 21 deletions(-) diff --git a/mwzeval/metrics.py b/mwzeval/metrics.py index 95e53c8..026db74 100644 --- a/mwzeval/metrics.py +++ b/mwzeval/metrics.py @@ -1,7 +1,7 @@ import sys import math -from collections import Counter +from collections import Counter, defaultdict from sacrebleu import corpus_bleu from lexical_diversity import lex_div as ld from fuzzywuzzy import fuzz @@ -46,16 +46,24 @@ def __init__(self, bleu: bool, success: bool, richness: bool, dst: bool = False, if dst: self.gold_states = load_gold_states(mwz_version=self._MWZ_VERSION, enable_normalization=self._enable_normalization) - def evaluate(self, input_data): + def evaluate(self, input_data: dict, include_loocv_metrics: bool = False): + """ + + Args: + input_data (dict): + include_loocv_metrics (bool, optional): Whether to include the leave-one-out cross validation metrics, + currently only supporting DST evaluation. Defaults to False. + """ if self._enable_normalization: normalize_data(input_data) + return { "bleu" : get_bleu(input_data, self.reference_dialogs) if self.bleu else None, "success" : get_success(input_data, self.database, self.goals, self.booked_domains) if self.success else None, "richness" : get_richness(input_data) if self.richness else None, - "dst" : get_dst(input_data, self.gold_states) if self.dst else None, + "dst" : get_dst(input_data, self.gold_states, include_loocv_metrics) if self.dst else None, } - + class Multiwoz24Evaluator(Evaluator): _MWZ_VERSION = '24' @@ -277,8 +285,45 @@ def get_dialog_success(goal, booked_domains, utterances, states, domain_estimate return match, success -def get_dst(input_data, reference_states, fuzzy_ratio=95): - """ Get dialog state tracking results: joint accuracy (exact state match), slot F1, precision and recall """ +def get_dst(input_data, reference_states, include_loocv_metrics=False, fuzzy_ratio=95): + """ Get dialog state tracking results: joint accuracy (exact state match), slot F1, precision and recall + + Note that for each dialogue, the number of turns in the input data should match the reference. + This means when doing leave-one-out cross-valiation, the model should be decoded on the full test set. + """ + DOMAINS = {"hotel", "train", "restaurant", "attraction", "taxi"} + + def block_domains(input_states: dict, reference_states: dict, blocked_domains: set[str]) -> dict: + """Return new input and reference state dictionaries with the specified domains removed. + + Turns with slots from only the blocked domains are removed entirely, otherwise, the slots from the blocked + domains are removed from the turn. + """ + new_input_states = defaultdict(list) + new_ref_states = defaultdict(list) + for dial_id, turns in input_states.items(): + for turn, turn_ref in zip(turns, reference_states[dial_id]): + # drop the blocked slots from the reference state + new_turn_ref = {} + for domain, slot_values in turn_ref.items(): + if domain not in blocked_domains: + new_turn_ref[domain] = slot_values + + # if the reference state does not contain any unblocked slot, + # drop the turn entirely from both input and reference states + if len(new_turn_ref) == 0: + continue + new_ref_states[dial_id].append(new_turn_ref) + + # drop the blocked slots from the input state + new_turn = {} + for domain, slot_values in turn.items(): + if domain not in blocked_domains: + new_turn[domain] = slot_values + # inlcude input state even if it does not contain any unblocked slot, + # which happens when the model wrongly omits slots + new_input_states[dial_id].append(new_turn) + return new_input_states, new_ref_states def flatten(state_dict): constraints = {} @@ -313,18 +358,14 @@ def compare(hyp, ref): fn += 1 return tp, fp, fn - joint_match, slot_acc, slot_f1, slot_p, slot_r = 0, 0, 0, 0, 0 - - if not has_state_predictions(input_data): - sys.stderr.write('error: Missing state predictions!\n') - - else: + def compute_dst_metrics(input_states, reference_states): + joint_match, slot_acc, slot_f1, slot_p, slot_r = 0, 0, 0, 0, 0 total_tp, total_fp, total_fn = 0, 0, 0 num_turns = 0 - for dialog_id in input_data: - for i, turn in enumerate(input_data[dialog_id]): + for dialog_id in input_states: + for i, turn in enumerate(input_states[dialog_id]): ref = flatten(reference_states[dialog_id][i]) - hyp = flatten(turn['state']) + hyp = flatten(turn) if is_matching(hyp, ref): joint_match += 1 @@ -341,9 +382,35 @@ def compare(hyp, ref): slot_f1 = 2 * slot_p * slot_r / (slot_p + slot_r + 1e-10) * 100 joint_match = joint_match / (num_turns + 1e-10) * 100 - return { - 'joint_accuracy' : joint_match, - 'slot_f1' : slot_f1, - 'slot_precision' : slot_p, - 'slot_recall' : slot_r - } + return { + 'joint_accuracy' : joint_match, + 'slot_f1' : slot_f1, + 'slot_precision' : slot_p, + 'slot_recall' : slot_r + } + + if not has_state_predictions(input_data): + sys.stderr.write('error: Missing state predictions!\n') + return + + input_states = defaultdict(list) + for dial_id, turn_infos in input_data.items(): + for turn_info in turn_infos: + input_states[dial_id].append(turn_info["state"]) + metrics = compute_dst_metrics(input_states, reference_states) + + if include_loocv_metrics: + for left_out_domain in DOMAINS: + metrics.update({ + f"only_{left_out_domain}": compute_dst_metrics( + *block_domains(input_states, reference_states, DOMAINS - {left_out_domain}) + ) + }) + for blocked_domain in DOMAINS: + metrics.update({ + f"except_{blocked_domain}": compute_dst_metrics( + *block_domains(input_states, reference_states, set([blocked_domain])) + ) + }) + + return metrics From 4f2d2b0ab0de1cd8499e0d567574724b67abf818 Mon Sep 17 00:00:00 2001 From: Weixuan Zhang Date: Sun, 5 Mar 2023 18:10:14 +0000 Subject: [PATCH 08/11] Fix block_domain --- mwzeval/metrics.py | 18 +++++++++--------- mwzeval/utils.py | 8 ++++++-- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/mwzeval/metrics.py b/mwzeval/metrics.py index 026db74..f8c3b5b 100644 --- a/mwzeval/metrics.py +++ b/mwzeval/metrics.py @@ -293,11 +293,11 @@ def get_dst(input_data, reference_states, include_loocv_metrics=False, fuzzy_rat """ DOMAINS = {"hotel", "train", "restaurant", "attraction", "taxi"} - def block_domains(input_states: dict, reference_states: dict, blocked_domains: set[str]) -> dict: + def block_domains(input_states: dict, reference_states: dict, included_domains: set[str]) -> dict: """Return new input and reference state dictionaries with the specified domains removed. - Turns with slots from only the blocked domains are removed entirely, otherwise, the slots from the blocked - domains are removed from the turn. + Turns with no slots from the included domains are removed entirely, otherwise, only the slots from the included + domains are included in the turn (i.e. the blocked slots will be dropped). """ new_input_states = defaultdict(list) new_ref_states = defaultdict(list) @@ -306,10 +306,10 @@ def block_domains(input_states: dict, reference_states: dict, blocked_domains: s # drop the blocked slots from the reference state new_turn_ref = {} for domain, slot_values in turn_ref.items(): - if domain not in blocked_domains: + if domain in included_domains: new_turn_ref[domain] = slot_values - # if the reference state does not contain any unblocked slot, + # if the reference state does not contain any slot from the included domain, # drop the turn entirely from both input and reference states if len(new_turn_ref) == 0: continue @@ -318,9 +318,9 @@ def block_domains(input_states: dict, reference_states: dict, blocked_domains: s # drop the blocked slots from the input state new_turn = {} for domain, slot_values in turn.items(): - if domain not in blocked_domains: + if domain in included_domains: new_turn[domain] = slot_values - # inlcude input state even if it does not contain any unblocked slot, + # inlcude input state even if it does not contain any slot from the included domain, # which happens when the model wrongly omits slots new_input_states[dial_id].append(new_turn) return new_input_states, new_ref_states @@ -403,13 +403,13 @@ def compute_dst_metrics(input_states, reference_states): for left_out_domain in DOMAINS: metrics.update({ f"only_{left_out_domain}": compute_dst_metrics( - *block_domains(input_states, reference_states, DOMAINS - {left_out_domain}) + *block_domains(input_states, reference_states, {left_out_domain}) ) }) for blocked_domain in DOMAINS: metrics.update({ f"except_{blocked_domain}": compute_dst_metrics( - *block_domains(input_states, reference_states, set([blocked_domain])) + *block_domains(input_states, reference_states, DOMAINS - {blocked_domain}) ) }) diff --git a/mwzeval/utils.py b/mwzeval/utils.py index cb12c98..6ad8925 100644 --- a/mwzeval/utils.py +++ b/mwzeval/utils.py @@ -205,13 +205,17 @@ def parse_state(turn): def load_multiwoz24(enable_normalization: bool = True): - def is_filled(slot_value: str) -> bool: + def is_filled(slot_value: str, consider_none_as_filled: bool = False) -> bool: """Whether a slot value is filled. Unfilled slots should be dropped, as in MultiWOZ 2.2. + + Args: + slot_value: the value to check + consider_none_as_filled: whether slots with value "none" should be considered as filled """ slot_value = slot_value.lower() - return slot_value and slot_value != "not mentioned" and slot_value != "none" + return slot_value and slot_value != "not mentioned" and (slot_value != "none" or consider_none_as_filled) def get_first_value(values: str) -> str: """Get the first value if the values string contains multiple.""" From 6b2ca508be3c54b03a4bf637912d8487850ed0ac Mon Sep 17 00:00:00 2001 From: Weixuan Zhang Date: Wed, 15 Mar 2023 16:53:22 +0000 Subject: [PATCH 09/11] Refactor and throw warnings --- mwzeval/metrics.py | 70 +++++++++++++++++++++++++++++++++++----------- 1 file changed, 53 insertions(+), 17 deletions(-) diff --git a/mwzeval/metrics.py b/mwzeval/metrics.py index f8c3b5b..dc06d98 100644 --- a/mwzeval/metrics.py +++ b/mwzeval/metrics.py @@ -286,23 +286,47 @@ def get_dialog_success(goal, booked_domains, utterances, states, domain_estimate def get_dst(input_data, reference_states, include_loocv_metrics=False, fuzzy_ratio=95): - """ Get dialog state tracking results: joint accuracy (exact state match), slot F1, precision and recall + """Get dialog state tracking results: joint accuracy (exact state match), slot F1, precision and recall. + + The input data should have the following format + { + "xxx0000" : [ + { + "state": { + $domain : { + $slot_name: $slot_value + }, ... + }, + ... + }, ... + ], ... + } Note that for each dialogue, the number of turns in the input data should match the reference. This means when doing leave-one-out cross-valiation, the model should be decoded on the full test set. """ DOMAINS = {"hotel", "train", "restaurant", "attraction", "taxi"} - def block_domains(input_states: dict, reference_states: dict, included_domains: set[str]) -> dict: - """Return new input and reference state dictionaries with the specified domains removed. + def filter_inputs_and_references(input_states: dict, reference_states: dict, included_domains: set[str]) -> dict: + """Filter input and reference states to only include states from included_domains. + + This is useful for evaluating in the leave-one-out setup where the joint goal accuracy should be computed + i) jointly with respect to the "left out" (aka unseen) domain and + ii) jointly with respect to all other domains. - Turns with no slots from the included domains are removed entirely, otherwise, only the slots from the included - domains are included in the turn (i.e. the blocked slots will be dropped). + Turns whose references do not contain any slots from included_domains are dropped. """ new_input_states = defaultdict(list) new_ref_states = defaultdict(list) - for dial_id, turns in input_states.items(): - for turn, turn_ref in zip(turns, reference_states[dial_id]): + for dial_id, turn_hyps in input_states.items(): + turn_refs = reference_states[dial_id] + if len(turn_hyps) != len(turn_refs): + sys.stderr.write( + f"error: {dial_id} has {len(turn_hyps)} hypothesis (input) turns," + f" but the reference contains {len(turn_refs)} turns.\n" + ) + + for turn_hyp, turn_ref in zip(turn_hyps, turn_refs): # drop the blocked slots from the reference state new_turn_ref = {} for domain, slot_values in turn_ref.items(): @@ -316,14 +340,17 @@ def block_domains(input_states: dict, reference_states: dict, included_domains: new_ref_states[dial_id].append(new_turn_ref) # drop the blocked slots from the input state - new_turn = {} - for domain, slot_values in turn.items(): + new_turn_hyp = {} + for domain, slot_values in turn_hyp.items(): if domain in included_domains: - new_turn[domain] = slot_values - # inlcude input state even if it does not contain any slot from the included domain, + new_turn_hyp[domain] = slot_values + # inlcude input state even if it does not contain any unblocked slot, # which happens when the model wrongly omits slots - new_input_states[dial_id].append(new_turn) - return new_input_states, new_ref_states + new_input_states[dial_id].append(new_turn_hyp) + + assert len(new_input_states[dial_id]) == len(new_ref_states[dial_id]) + + return dict(new_input_states), dict(new_ref_states) def flatten(state_dict): constraints = {} @@ -363,8 +390,16 @@ def compute_dst_metrics(input_states, reference_states): total_tp, total_fp, total_fn = 0, 0, 0 num_turns = 0 for dialog_id in input_states: - for i, turn in enumerate(input_states[dialog_id]): - ref = flatten(reference_states[dialog_id][i]) + hyps = input_states[dialog_id] + refs = reference_states[dialog_id] + if len(hyps) != len(refs): + sys.stderr.write( + f"warning: {dialog_id} has {len(hyps)} hypothesis (input) turns," + f" but the reference contains {len(refs)} turns." + " If this is intented, please make sure that turns are dropped from the end of the dialogue.\n" + ) + for i, turn in enumerate(hyps): + ref = flatten(refs[i]) hyp = flatten(turn) if is_matching(hyp, ref): @@ -397,19 +432,20 @@ def compute_dst_metrics(input_states, reference_states): for dial_id, turn_infos in input_data.items(): for turn_info in turn_infos: input_states[dial_id].append(turn_info["state"]) + input_states = dict(input_states) metrics = compute_dst_metrics(input_states, reference_states) if include_loocv_metrics: for left_out_domain in DOMAINS: metrics.update({ f"only_{left_out_domain}": compute_dst_metrics( - *block_domains(input_states, reference_states, {left_out_domain}) + *filter_inputs_and_references(input_states, reference_states, {left_out_domain}) ) }) for blocked_domain in DOMAINS: metrics.update({ f"except_{blocked_domain}": compute_dst_metrics( - *block_domains(input_states, reference_states, DOMAINS - {blocked_domain}) + *filter_inputs_and_references(input_states, reference_states, DOMAINS - {blocked_domain}) ) }) From 07ec3a86c883f638d71d0f4e6a5bfc44cd96378e Mon Sep 17 00:00:00 2001 From: Weixuan Zhang Date: Fri, 17 Mar 2023 23:41:18 +0000 Subject: [PATCH 10/11] Fix bug dropping turns with empty reference states --- mwzeval/metrics.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/mwzeval/metrics.py b/mwzeval/metrics.py index dc06d98..d223412 100644 --- a/mwzeval/metrics.py +++ b/mwzeval/metrics.py @@ -333,10 +333,19 @@ def filter_inputs_and_references(input_states: dict, reference_states: dict, inc if domain in included_domains: new_turn_ref[domain] = slot_values - # if the reference state does not contain any slot from the included domain, - # drop the turn entirely from both input and reference states - if len(new_turn_ref) == 0: + # for a given turn, if its reference state does not contain any + # slot from the included domains, and the turn is not the first + # turn of a series of turns involving the included domain (when + # the dialogue initiates or domain switching occurs, and no state + # has been mentioned, drop the turn entirely from both input and + # reference states and clear states of previous turns from the + # same dialogue + if len(new_turn_ref) == 0 and len(turn_ref) != 0: + if len(new_ref_states[dial_id]) > 0: + new_ref_states[dial_id] = [] + new_input_states[dial_id] = [] continue + new_ref_states[dial_id].append(new_turn_ref) # drop the blocked slots from the input state @@ -349,6 +358,9 @@ def filter_inputs_and_references(input_states: dict, reference_states: dict, inc new_input_states[dial_id].append(new_turn_hyp) assert len(new_input_states[dial_id]) == len(new_ref_states[dial_id]) + if all(map(lambda turn: len(turn) == 0, new_ref_states[dial_id])): + new_input_states[dial_id] = [] + new_ref_states[dial_id] = [] return dict(new_input_states), dict(new_ref_states) From 1e244babbb61458b130fa262e706e37834704c03 Mon Sep 17 00:00:00 2001 From: Weixuan Zhang Date: Mon, 17 Jul 2023 08:20:01 +0100 Subject: [PATCH 11/11] Fix empty ref pruning --- mwzeval/metrics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mwzeval/metrics.py b/mwzeval/metrics.py index d223412..2f5402e 100644 --- a/mwzeval/metrics.py +++ b/mwzeval/metrics.py @@ -341,7 +341,7 @@ def filter_inputs_and_references(input_states: dict, reference_states: dict, inc # reference states and clear states of previous turns from the # same dialogue if len(new_turn_ref) == 0 and len(turn_ref) != 0: - if len(new_ref_states[dial_id]) > 0: + if len(new_ref_states[dial_id]) > 0 and all(map(lambda x: len(x) == 0, new_ref_states[dial_id])): new_ref_states[dial_id] = [] new_input_states[dial_id] = [] continue