Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MultiWOZ 2.4 DST evaluation with leave-one-out cross-validation support #18

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
32 changes: 27 additions & 5 deletions mwzeval/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,32 +16,54 @@


class Evaluator:

def __init__(self, bleu : bool, success : bool, richness : bool, dst : bool = False):
_MWZ_VERSION = '22'

def __init__(self, bleu: bool, success: bool, richness: bool, dst: bool = False, enable_normalization: bool = True):
"""Initialize the evaluator.

Args:
bleu (bool): Whether to include BLEU metric.
success (bool): Whether to include Inform & Success rates metrics.
richness (bool): Whether to include lexical richness metric.
dst (bool, optional): Whether to include DST metrics. Defaults to False.
enable_normalization (bool, optional): Whether to use slot name and value normalization. Defaults to True.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@WeixuanZ you should state what normalisation is applied here (e.g. "same normalisation as per the 22 version") or something along these lines.

"""
self.bleu = bleu
self.success = success
self.richness = richness
self.dst = dst

self._enable_normalization = enable_normalization

if bleu:
self.reference_dialogs = load_references()
self.reference_dialogs = load_references(enable_normalization=self._enable_normalization)

if success:
self.database = MultiWOZVenueDatabase()
self.goals = load_goals()
self.booked_domains = load_booked_domains()

if dst:
self.gold_states = load_gold_states()
self.gold_states = load_gold_states(mwz_version=self._MWZ_VERSION, enable_normalization=self._enable_normalization)

def evaluate(self, input_data):
normalize_data(input_data)
if self._enable_normalization:
normalize_data(input_data)
return {
"bleu" : get_bleu(input_data, self.reference_dialogs) if self.bleu else None,
"success" : get_success(input_data, self.database, self.goals, self.booked_domains) if self.success else None,
"richness" : get_richness(input_data) if self.richness else None,
"dst" : get_dst(input_data, self.gold_states) if self.dst else None,
}


class Multiwoz24Evaluator(Evaluator):
    """Evaluator variant tagged with MultiWOZ version '24'.

    Only the DST metric can be requested through this class; asking for
    BLEU, success or richness is rejected up front.
    """

    # Version tag overriding the base-class value.
    _MWZ_VERSION = '24'

    def __init__(self, bleu: bool, success: bool, richness: bool, dst: bool = False, enable_normalization: bool = True):
        """Initialize the MultiWOZ 2.4 evaluator.

        Args:
            bleu (bool): Must be falsy; BLEU is not implemented for 2.4.
            success (bool): Must be falsy; Inform & Success are not implemented for 2.4.
            richness (bool): Must be falsy; lexical richness is not implemented for 2.4.
            dst (bool, optional): Whether to include DST metrics. Defaults to False.
            enable_normalization (bool, optional): Forwarded unchanged to the base class. Defaults to True.

        Raises:
            NotImplementedError: If any of ``bleu``, ``success`` or ``richness`` is truthy.
        """
        unsupported_requested = any((bleu, success, richness))
        if unsupported_requested:
            raise NotImplementedError("bleu, success or richness metrics are not yet implemented for MultiWOZ 2.4.")

        super().__init__(
            bleu=bleu,
            success=success,
            richness=richness,
            dst=dst,
            enable_normalization=enable_normalization,
        )


def get_bleu(input_data, reference_dialogs):
Expand Down
2 changes: 2 additions & 0 deletions mwzeval/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ def type_to_canonical(type_string):
return "nightclub"
elif type_string == "guest house":
return "guesthouse"
elif type_string == "concert hall":
return "concerthall"
return type_string

def name_to_canonical(name, domain=None):
Expand Down
143 changes: 131 additions & 12 deletions mwzeval/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import os
import json
import urllib.request
from typing import Literal
import zipfile
import io
from collections import defaultdict

from mwzeval.normalization import normalize_data

Expand Down Expand Up @@ -79,7 +83,7 @@ def load_booked_domains():
return json.load(f)


def load_references(systems=['mwz22']): #, 'damd', 'uniconv', 'hdsa', 'lava', 'augpt']):
def load_references(systems=['mwz22'], enable_normalization: bool = True): #, 'damd', 'uniconv', 'hdsa', 'lava', 'augpt']):
references = {}
for system in systems:
if system == 'mwz22':
Expand All @@ -88,31 +92,36 @@ def load_references(systems=['mwz22']): #, 'damd', 'uniconv', 'hdsa', 'lava', 'a
with open(os.path.join(dir_path, "data", "references", f"{system}.json")) as f:
references[system] = json.load(f)
if 'mwz22' in systems:
references['mwz22'] = load_multiwoz22_reference()
references['mwz22'] = load_multiwoz22_reference(enable_normalization=enable_normalization)
return references


def load_multiwoz22_reference():
def load_multiwoz22_reference(enable_normalization: bool = True):
dir_path = os.path.dirname(os.path.realpath(__file__))
data_path = os.path.join(dir_path, "data", "references", "mwz22.json")
data_path = os.path.join(dir_path, "data", "references", "mwz22.json" if enable_normalization else "mwz22_not_normalized.json")
if os.path.exists(data_path):
with open(data_path) as f:
return json.load(f)
references, _ = load_multiwoz22()
references, _ = load_multiwoz22(enable_normalization=enable_normalization)
return references


def load_gold_states():
def load_gold_states(mwz_version: Literal['22', '24'] = '22', enable_normalization: bool = True):
dir_path = os.path.dirname(os.path.realpath(__file__))
data_path = os.path.join(dir_path, "data", "gold_states.json")
data_path = os.path.join(dir_path, "data", f"gold_states{mwz_version}.json" if enable_normalization else f"gold_states{mwz_version}_not_normalized.json")
if os.path.exists(data_path):
with open(data_path) as f:
return json.load(f)
_, states = load_multiwoz22()
if mwz_version == "22":
_, states = load_multiwoz22(enable_normalization=enable_normalization)
elif mwz_version == "24":
_, states = load_multiwoz24(enable_normalization=enable_normalization)
else:
raise ValueError("Unsupported MultiWOZ version.")
return states


def load_multiwoz22():
def load_multiwoz22(enable_normalization: bool = True):

def delexicalize_utterance(utterance, span_info):
span_info.sort(key=(lambda x: x[-2])) # sort spans by start index
Expand Down Expand Up @@ -174,16 +183,17 @@ def parse_state(turn):
})
mwz22_data[dialog["dialogue_id"].split('.')[0].lower()] = parsed_turns

normalize_data(mwz22_data)
if enable_normalization:
normalize_data(mwz22_data)

references, states = {}, {}
for dialog in mwz22_data:
references[dialog] = [x["response"] for x in mwz22_data[dialog]]
states[dialog] = [x["state"] for x in mwz22_data[dialog]]

dir_path = os.path.dirname(os.path.realpath(__file__))
reference_path = os.path.join(dir_path, "data", "references", "mwz22.json")
state_path = os.path.join(dir_path, "data", "gold_states.json")
reference_path = os.path.join(dir_path, "data", "references", "mwz22.json" if enable_normalization else "mwz22_not_normalized.json")
state_path = os.path.join(dir_path, "data", "gold_states22.json" if enable_normalization else "gold_states22_not_normalized.json")

with open(reference_path, 'w+') as f:
json.dump(references, f, indent=2)
Expand All @@ -192,3 +202,112 @@ def parse_state(turn):
json.dump(states, f, indent=2)

return references, states


def load_multiwoz24(enable_normalization: bool = True):
    """Download MultiWOZ 2.4, extract per-turn gold dialogue states and cache them.

    The dataset archive is fetched from GitHub, the state of every
    odd-indexed entry of each dialogue log is parsed, optionally passed
    through ``normalize_data``, and the result is written as JSON next to
    this module under ``data/``.

    Args:
        enable_normalization: When True, ``normalize_data`` is applied to the
            parsed data and the cache file is ``gold_states24.json``;
            otherwise the data is left as parsed and written to
            ``gold_states24_not_normalized.json``.

    Returns:
        A ``(references, states)`` tuple. ``references`` is always an empty
        dict (reference responses are not extracted for MultiWOZ 2.4);
        ``states`` maps each dialogue id (text before the first '.',
        lower-cased) to the list of its per-turn states, each shaped
        ``{$domain: {$slot_name: $value, ...}, ...}``.
    """

    def is_filled(slot_value: str) -> bool:
        """Whether a slot value is filled.

        Unfilled slots should be dropped, as in MultiWOZ 2.2.
        """
        # Always return a real bool (an empty string is simply "not filled").
        return slot_value.lower() not in ("", "not mentioned", "none")

    def get_first_value(values: str) -> str:
        """Get the first value if the values string contains multiple."""
        # Separators are tried in priority order; only the first one present
        # in the string is used for splitting.
        for separator in ("|", ">", "<"):
            if separator in values:
                return values.split(separator)[0]
        return values

    def parse_state(turn: dict, prepend_book: bool = False) -> dict[str, dict[str, str]]:
        """Get the slot values of a given turn.

        This function is adapted from
        google-research/schema_guided_dst/multiwoz/create_data_from_multiwoz.py

        If a slot has multiple values (which are separated by '|', '<' or '>'), only the first one is taken.
        This is consistent with the approach taken for MultiWOZ 2.2 evaluation.

        Args:
            turn: Dictionary of a turn of the MultiWOZ 2.4 dataset
            prepend_book: Whether to prepend the string 'book' to slot names for booking slots.
                MultiWOZ 2.2 has the 'book' prefix.

        Returns:
            {$domain: {$slot_name: $value, ...}, ...}
        """
        dialog_states = defaultdict(dict)
        for domain_name, values in turn['metadata'].items():
            domain_dial_state = {}

            for k, v in values["book"].items():
                # Note: "booked" is not really a state, just booking confirmation
                if k == 'booked':
                    continue
                if isinstance(v, list):
                    for item_dict in v:
                        for slot_name, slot_val in item_dict.items():
                            key = f"book{slot_name}" if prepend_book else slot_name
                            domain_dial_state[key] = slot_val
                elif isinstance(v, str) and v:
                    key = f"book{k}" if prepend_book else k
                    domain_dial_state[key] = v

            # "semi" slots are merged last, so they win on a name clash.
            domain_dial_state.update(values["semi"])

            domain_dial_state = {
                slot_name: get_first_value(value)  # use the first value
                for slot_name, value in domain_dial_state.items()
                if is_filled(value)
            }
            # Domains without any filled slot are left out entirely.
            if domain_dial_state:
                dialog_states[domain_name] = domain_dial_state

        return dialog_states

    with urllib.request.urlopen(
        "https://github.com/smartyfh/MultiWOZ2.4/blob/main/data/MULTIWOZ2.4.zip?raw=true"
    ) as url:
        print("Downloading MultiWOZ_2.4")
        # The archive is held in memory only; close it once data.json is read.
        with zipfile.ZipFile(io.BytesIO(url.read())) as unzipped:
            data = json.loads(unzipped.read('MULTIWOZ2.4/data.json'))

    mwz24_data = {}
    for dialogue_id, dialogue in data.items():
        parsed_turns = []
        for i, turn in enumerate(dialogue["log"]):
            # Only odd-indexed log entries are used
            # (presumably the system turns, which carry 'metadata' — TODO confirm).
            if i % 2 == 0:
                continue
            parsed_turns.append({"response": "", "state": parse_state(turn)})
        mwz24_data[dialogue_id.split(".")[0].lower()] = parsed_turns

    if enable_normalization:
        normalize_data(mwz24_data)

    # NOTE: responses are not extracted for MultiWOZ 2.4, so no reference file
    # is written and ``references`` is returned empty.
    references = {}
    states = {
        dialog: [parsed_turn["state"] for parsed_turn in parsed_turns]
        for dialog, parsed_turns in mwz24_data.items()
    }

    dir_path = os.path.dirname(os.path.realpath(__file__))
    state_path = os.path.join(dir_path, "data", "gold_states24.json" if enable_normalization else "gold_states24_not_normalized.json")

    with open(state_path, 'w+') as f:
        json.dump(states, f, indent=2)

    return references, states