Skip to content

Commit

Permalink
Fix mock of tf.summary.scalar so that the mock also covers tf.compat.…
Browse files Browse the repository at this point in the history
…v2.summary.scalar.

PiperOrigin-RevId: 580562164
Change-Id: I948ee6bc9177b23aacfec7d02663e446417a809f
  • Loading branch information
fionalang authored and copybara-github committed Nov 8, 2023
1 parent 7b67210 commit d9256db
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions tf_agents/train/utils/train_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
('_mirrored', lambda: tf.distribute.MirroredStrategy(devices=_CPUS)),
)

_SUMMARY_MOCK = mock.MagicMock(autospec=True)


class TrainUtilsTest(parameterized.TestCase, test_utils.TestCase):

Expand Down Expand Up @@ -74,8 +76,8 @@ def test_after_train_step_fn_with_fresh_data_only(self, create_strategy_fn):

# Call the after train function and check the expectations.
with mock.patch.object(
tf.summary, 'scalar', autospec=True
) as mock_scalar_summary:
tf.summary, 'scalar', new=_SUMMARY_MOCK
), mock.patch.object(tf.compat.v2.summary, 'scalar', new=_SUMMARY_MOCK):
# Call the `after_train_function` on the test input. Assumed the
# observation train steps are stored in the field `priority` of the
# the sample info of Reverb.
Expand All @@ -84,7 +86,7 @@ def test_after_train_step_fn_with_fresh_data_only(self, create_strategy_fn):
strategy.run(after_train_step_fn, args=((None, info), None))

# Check if the expected calls happened on the scalar summary.
mock_scalar_summary.assert_has_calls(
_SUMMARY_MOCK.assert_has_calls(
expected_scalar_summary_calls, any_order=False
)

Expand Down Expand Up @@ -123,8 +125,8 @@ def test_after_train_step_fn_with_stale_data(self, create_strategy_fn):

# Call the after train function and check the expectations.
with mock.patch.object(
tf.summary, 'scalar', autospec=True
) as mock_scalar_summary:
tf.summary, 'scalar', new=_SUMMARY_MOCK
), mock.patch.object(tf.compat.v2.summary, 'scalar', new=_SUMMARY_MOCK):
# Call the `after_train_function` on the test input. Assumed the
# observation train steps are stored in the field `priority` of the
# the sample info of Reverb.
Expand All @@ -133,7 +135,7 @@ def test_after_train_step_fn_with_stale_data(self, create_strategy_fn):
strategy.run(after_train_step_fn, args=((None, info), None))

# Check if the expected calls happened on the scalar summary.
mock_scalar_summary.assert_has_calls(
_SUMMARY_MOCK.assert_has_calls(
expected_scalar_summary_calls, any_order=False
)

Expand Down

0 comments on commit d9256db

Please sign in to comment.