Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MultiWOZ 2.4 DST evaluation with leave-one-out cross-validation support #18

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
18 changes: 9 additions & 9 deletions mwzeval/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,11 +293,11 @@ def get_dst(input_data, reference_states, include_loocv_metrics=False, fuzzy_rat
"""
DOMAINS = {"hotel", "train", "restaurant", "attraction", "taxi"}

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can make some API improvements. Instead of include_loocv_metrics which most users won't understand we can have left_out_domain=None in which case we can return:

  • the joint goal accuracy wrt to all domains - a turn is marked as correct if all states from all domains are predicted correctly. This would be called test-jga.
  • we could also report the JGA wrt to each individual domain under the key [domain_name]_jga. A turn is marked correct if the states from a given domain are all correct. Errors in predicting states in other domains are ignored.
  • we could also report the joint accuracy of all 4 domain combinations, to facilitate comparisons with leave-one-out setting. Here we just ignore the predictions from the left out domain & the turns where only the left out domain appears. We would name the fields except_{domain}_jga.

left_out_domain should be a string that the user can set to one of the 5 domains in DOMAINS and we should assert the input is correct at the very beginning. The keys reported should be:

  • test-jga where the joint accuracy with respect to all domains is computed. This number should be directly comparable with the setting when left_out_domain=None
  • [domain_name]_jga - as before, this is the joint accuracy of each individual domain. The numbers should be comparable with the equivalents when left_out_domain=None.
  • except_{left_out_domain}_jga - this is joint accuracy with respect to all the domains seen in training. If we also report it when left_out_domain=None, then the user sees if the left out domain "helped" improve performance in the other domains or not.

I think this is largely what the current output evaluation returns but we should very carefully and clearly document this to make sure the implementation is correct. We should document the above very clearly in the docstring so that reviewers of the PR who are multiwoz experts can validate our approach in full knowledge of our logic.

def block_domains(input_states: dict, reference_states: dict, blocked_domains: set[str]) -> dict:
def block_domains(input_states: dict, reference_states: dict, included_domains: set[str]) -> dict:
"""Return new input and reference state dictionaries with the specified domains removed.

Turns with slots from only the blocked domains are removed entirely, otherwise, the slots from the blocked
domains are removed from the turn.
Turns with no slots from the included domains are removed entirely, otherwise, only the slots from the included
domains are included in the turn (i.e. the blocked slots will be dropped).
"""
new_input_states = defaultdict(list)
new_ref_states = defaultdict(list)
Expand All @@ -306,10 +306,10 @@ def block_domains(input_states: dict, reference_states: dict, blocked_domains: s
# drop the blocked slots from the reference state
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

loop through the references first - we should predict for every turn in the test set and so input_states should have all the dialogues. If something went wrong during parsing or prediction and the user has missed predictions for some dialogues and/or turns, the code should fail. As currently implemented, there will be a silent bug.

new_turn_ref = {}
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add an assertion to check that turn and turn_ref lists have the same length

for domain, slot_values in turn_ref.items():
if domain not in blocked_domains:
if domain in included_domains:
new_turn_ref[domain] = slot_values

# if the reference state does not contain any unblocked slot,
# if the reference state does not contain any slot from the included domain,
# drop the turn entirely from both input and reference states
if len(new_turn_ref) == 0:
continue
Expand All @@ -318,9 +318,9 @@ def block_domains(input_states: dict, reference_states: dict, blocked_domains: s
# drop the blocked slots from the input state
new_turn = {}
for domain, slot_values in turn.items():
if domain not in blocked_domains:
if domain in included_domains:
new_turn[domain] = slot_values
# inlcude input state even if it does not contain any unblocked slot,
# inlcude input state even if it does not contain any slot from the included domain,
# which happens when the model wrongly omits slots
new_input_states[dial_id].append(new_turn)
return new_input_states, new_ref_states
Expand Down Expand Up @@ -403,13 +403,13 @@ def compute_dst_metrics(input_states, reference_states):
for left_out_domain in DOMAINS:
metrics.update({
f"only_{left_out_domain}": compute_dst_metrics(
*block_domains(input_states, reference_states, DOMAINS - {left_out_domain})
*block_domains(input_states, reference_states, {left_out_domain})
)
})
for blocked_domain in DOMAINS:
metrics.update({
f"except_{blocked_domain}": compute_dst_metrics(
*block_domains(input_states, reference_states, set([blocked_domain]))
*block_domains(input_states, reference_states, DOMAINS - {blocked_domain})
)
})

Expand Down
8 changes: 6 additions & 2 deletions mwzeval/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,13 +205,17 @@ def parse_state(turn):


def load_multiwoz24(enable_normalization: bool = True):
def is_filled(slot_value: str) -> bool:
def is_filled(slot_value: str, consider_none_as_filled: bool = False) -> bool:
"""Whether a slot value is filled.

Unfilled slots should be dropped, as in MultiWOZ 2.2.

Args:
slot_value: the value to check
consider_none_as_filled: whether slots with value "none" should be considered as filled
"""
slot_value = slot_value.lower()
return slot_value and slot_value != "not mentioned" and slot_value != "none"
return slot_value and slot_value != "not mentioned" and (slot_value != "none" or consider_none_as_filled)

def get_first_value(values: str) -> str:
"""Get the first value if the values string contains multiple."""
Expand Down