Skip to content

Commit

Permalink
Changing assertMappingEqual to support arbitrary equality function. A…
Browse files Browse the repository at this point in the history
…lso adding assertDictAlmostEqual that uses assertAlmostEqual as equality function for float values.

PiperOrigin-RevId: 730804727
  • Loading branch information
al-bus authored and copybara-github committed Feb 28, 2025
1 parent c98852f commit 1297a00
Show file tree
Hide file tree
Showing 2 changed files with 224 additions and 22 deletions.
118 changes: 100 additions & 18 deletions absl/testing/absltest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1684,15 +1684,89 @@ def assertDictEqual(self, a, b, msg=None):
"""
self.assertMappingEqual(a, b, msg, mapping_type=dict)

def assertDictAlmostEqual(
self,
a,
b,
places=None,
msg=None,
delta=None,
):
"""Raises AssertionError if a and b are not equal or almost equal dicts.
This is like assertDictEqual, except for float values which are compared
using assertAlmostEqual. Almost equality is determined for float values by:
- have numeric difference less than the given delta,
or
- equal if rounded to the given number of decimal places after the decimal
point (default 7).
Args:
a: A dict, the expected value.
b: A dict, the actual value.
places: The number of decimal places to compare for floats.
msg: An optional str, the associated message.
delta: The OK difference between compared values for floats.
Raises:
AssertionError: if the dictionaries are not equal or almost equal.
ValueError: if both places and delta are specified.
"""

# Almost equality with preset places and delta.
def almost_equal_compare(
a_value: Any, b_value: Any
) -> tuple[bool, Optional[BaseException]]:
if isinstance(a_value, float) or isinstance(b_value, float):
try:
# assertAlmostEqual should be called with at most one of `places`
# and `delta`. However, it's okay for assertMappingEqual to pass
# both because we want the latter to fail if the former does.
# pytype: disable=wrong-keyword-args
self.assertAlmostEqual(
a_value,
b_value,
places=places,
delta=delta,
)
# pytype: enable=wrong-keyword-args
except self.failureException as err:
return False, err
return True, None

if delta is not None and places is not None:
raise ValueError('specify delta or places not both\n')

self.assertMappingEqual(
a,
b,
msg,
mapping_type=dict,
check_values_equality=almost_equal_compare,
)

def assertMappingEqual(
self, a, b, msg=None, mapping_type=collections.abc.Mapping
self,
a,
b,
msg=None,
mapping_type=collections.abc.Mapping,
check_values_equality=lambda x, y: (x == y, None),
):
"""Raises AssertionError if a and b are not equal mappings.
"""Raises AssertionError if a and b differ in keys or values.
Key sets must be exectly the same, the corresponding values should satisfy
the provided equality function.
Args:
a: A mapping, the expected value.
b: A mapping, the actual value.
msg: An optional str, the associated message.
mapping_type: The expected type of the mappings.
check_values_equality: A function that takes two values and returns a
tuple of (bool, BaseException), where the bool is True if the values are
equal and the BaseException is an optional exception occured during the
equality check.
Raises:
AssertionError: if the dictionaries are not equal.
Expand All @@ -1710,22 +1784,41 @@ def assertMappingEqual(
f' {type(b).__name__}',
msg,
)
if a == b:
return

def Sorted(list_of_items):
try:
return sorted(list_of_items) # In 3.3, unordered are possible.
except TypeError:
return list_of_items

if a == b:
return
a_items = Sorted(list(a.items()))
b_items = Sorted(list(b.items()))

unexpected = []
missing = []
different = []

# The standard library default output confounds lexical difference with
# value difference; treat them separately.
for a_key, a_value in a_items:
if a_key not in b:
missing.append((a_key, a_value))
continue
b_value = b[a_key]
is_equal, err = check_values_equality(a_value, b_value)
if not is_equal:
different.append((a_key, a_value, b_value, err))

for b_key, b_value in b_items:
if b_key not in a:
unexpected.append((b_key, b_value))

# If all difference buckets are empty, then mappings are considered equal.
if not unexpected and not different and not missing:
return

safe_repr = unittest.util.safe_repr # pytype: disable=module-attr

def Repr(dikt):
Expand All @@ -1737,18 +1830,6 @@ def Repr(dikt):

message = [f'{Repr(a)} != {Repr(b)}{"("+msg+")" if msg else ""}']

# The standard library default output confounds lexical difference with
# value difference; treat them separately.
for a_key, a_value in a_items:
if a_key not in b:
missing.append((a_key, a_value))
elif a_value != b[a_key]:
different.append((a_key, a_value, b[a_key]))

for b_key, b_value in b_items:
if b_key not in a:
unexpected.append((b_key, b_value))

if unexpected:
message.append(
'Unexpected, but present entries:\n'
Expand All @@ -1759,8 +1840,9 @@ def Repr(dikt):
message.append(
'repr() of differing entries:\n'
+ ''.join(
f'{safe_repr(k)}: {safe_repr(a_value)} != {safe_repr(b_value)}\n'
for k, a_value, b_value in different
f'{safe_repr(k)}: '
f'{err if err else f"{safe_repr(a_value)} != {safe_repr(b_value)}"}\n'
for k, a_value, b_value, err in different
)
)

Expand Down
128 changes: 124 additions & 4 deletions absl/testing/tests/absltest_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,11 +422,13 @@ def test_assert_dict_equal_requires_dict(self):
)
def test_assert_dict_equal(self, use_mapping: bool):

def assert_dict_equal(a, b, msg=None):
def assert_dict_equal(a, b, msg=None, places=None, delta=None):
if use_mapping:
self.assertMappingEqual(a, b, msg=msg)
else:
elif places is None and delta is None:
self.assertDictEqual(a, b, msg=msg)
else:
self.assertDictAlmostEqual(a, b, msg=msg, places=places, delta=delta)

assert_dict_equal({}, {})

Expand Down Expand Up @@ -482,7 +484,8 @@ def assert_dict_equal(a, b, msg=None):
try:
assert_dict_equal(expected, seen)
except AssertionError as e:
self.assertMultiLineEqual("""\
self.assertMultiLineEqual(
"""\
{'a': 1, 'b': 2, 'c': 3} != {'a': 2, 'c': 3, 'd': 4}
Unexpected, but present entries:
'd': 4
Expand All @@ -492,7 +495,9 @@ def assert_dict_equal(a, b, msg=None):
Missing entries:
'b': 2
""", str(e))
""",
str(e),
)
else:
self.fail('Expecting AssertionError')

Expand Down Expand Up @@ -618,6 +623,121 @@ def test_assert_set_equal(self):
set2 = {(4, 5)}
self.assertRaises(AssertionError, self.assertSetEqual, set1, set2)

@parameterized.named_parameters(
dict(testcase_name='empty', a={}, b={}),
dict(testcase_name='equal_float', a={'a': 1.01}, b={'a': 1.01}),
dict(testcase_name='int_and_float', a={'a': 0}, b={'a': 0.000_000_01}),
dict(testcase_name='float_and_int', a={'a': 0.000_000_01}, b={'a': 0}),
dict(
testcase_name='mixed_elements',
a={'a': 'A', 'b': 1, 'c': 0.999_999_99},
b={'a': 'A', 'b': 1, 'c': 1},
),
dict(
testcase_name='float_artifacts',
a={'a': 0.15000000000000002},
b={'a': 0.15},
),
dict(
testcase_name='multiple_floats',
a={'a': 1.0, 'b': 2.0},
b={'a': 1.000_000_01, 'b': 1.999_999_99},
),
)
def test_assert_dict_almost_equal(self, a, b):
self.assertDictAlmostEqual(a, b)

@parameterized.named_parameters(
dict(
testcase_name='default_places_is_7',
a={'a': 1.0},
b={'a': 1.000_000_01},
places=None,
delta=None,
),
dict(
testcase_name='places',
a={'a': 1.011},
b={'a': 1.009},
places=2,
delta=None,
),
dict(
testcase_name='delta',
a={'a': 1.00},
b={'a': 1.09},
places=None,
delta=0.1,
),
)
def test_assert_dict_almost_equal_with_tolerance(self, a, b, places, delta):
self.assertDictAlmostEqual(a, b, places=places, delta=delta)

@parameterized.named_parameters(
dict(
testcase_name='default_places_is_7',
a={'a': 1.0},
b={'a': 1.000_000_1},
places=None,
delta=None,
),
dict(
testcase_name='places',
a={'a': 1.001},
b={'a': 1.002},
places=3,
delta=None,
),
dict(
testcase_name='delta',
a={'a': 1.01},
b={'a': 1.02},
places=None,
delta=0.01,
),
)
def test_assert_dict_almost_equal_fails_with_tolerance(
self, a, b, places, delta
):
with self.assertRaises(self.failureException):
self.assertDictAlmostEqual(a, b, places=places, delta=delta)

def test_assert_dict_almost_equal_assertion_message(self):
with self.assertRaises(AssertionError) as e:
self.assertDictAlmostEqual({'a': 0.6}, {'a': 1.0}, delta=0.1)
self.assertMultiLineEqual(
"""\
{'a': 0.6} != {'a': 1.0}
repr() of differing entries:
'a': 0.6 != 1.0 within 0.1 delta (0.4 difference)
""",
str(e.exception),
)

def test_assert_dict_almost_equal_fails_with_custom_message(self):
with self.assertRaises(AssertionError) as e:
self.assertDictAlmostEqual(
{'a': 0.6}, {'a': 1.0}, delta=0.1, msg='custom message'
)
self.assertMultiLineEqual(
"""\
{'a': 0.6} != {'a': 1.0}(custom message)
repr() of differing entries:
'a': 0.6 != 1.0 within 0.1 delta (0.4 difference)
""",
str(e.exception),
)

def test_assert_dict_almost_equal_fails_with_both_places_and_delta(self):
with self.assertRaises(ValueError) as e:
self.assertDictAlmostEqual({'a': 1.0}, {'a': 1.0}, places=2, delta=0.01)
self.assertMultiLineEqual(
"""\
specify delta or places not both
""",
str(e.exception),
)

def test_assert_sequence_almost_equal(self):
actual = (1.1, 1.2, 1.4)

Expand Down

0 comments on commit 1297a00

Please sign in to comment.