diff --git a/tf_agents/train/utils/train_utils_test.py b/tf_agents/train/utils/train_utils_test.py index 46e662009..b648eac2c 100644 --- a/tf_agents/train/utils/train_utils_test.py +++ b/tf_agents/train/utils/train_utils_test.py @@ -32,6 +32,8 @@ ('_mirrored', lambda: tf.distribute.MirroredStrategy(devices=_CPUS)), ) +_SUMMARY_MOCK = mock.MagicMock(autospec=True) + class TrainUtilsTest(parameterized.TestCase, test_utils.TestCase): @@ -74,8 +76,8 @@ def test_after_train_step_fn_with_fresh_data_only(self, create_strategy_fn): # Call the after train function and check the expectations. with mock.patch.object( - tf.summary, 'scalar', autospec=True - ) as mock_scalar_summary: + tf.summary, 'scalar', new=_SUMMARY_MOCK + ), mock.patch.object(tf.compat.v2.summary, 'scalar', new=_SUMMARY_MOCK): # Call the `after_train_function` on the test input. Assumed the # observation train steps are stored in the field `priority` of the # the sample info of Reverb. @@ -84,7 +86,7 @@ def test_after_train_step_fn_with_fresh_data_only(self, create_strategy_fn): strategy.run(after_train_step_fn, args=((None, info), None)) # Check if the expected calls happened on the scalar summary. - mock_scalar_summary.assert_has_calls( + _SUMMARY_MOCK.assert_has_calls( expected_scalar_summary_calls, any_order=False ) @@ -123,8 +125,8 @@ def test_after_train_step_fn_with_stale_data(self, create_strategy_fn): # Call the after train function and check the expectations. with mock.patch.object( - tf.summary, 'scalar', autospec=True - ) as mock_scalar_summary: + tf.summary, 'scalar', new=_SUMMARY_MOCK + ), mock.patch.object(tf.compat.v2.summary, 'scalar', new=_SUMMARY_MOCK): # Call the `after_train_function` on the test input. Assumed the # observation train steps are stored in the field `priority` of the # the sample info of Reverb. @@ -133,7 +135,7 @@ def test_after_train_step_fn_with_stale_data(self, create_strategy_fn): strategy.run(after_train_step_fn, args=((None, info), None)) # Check if the expected calls happened on the scalar summary. - mock_scalar_summary.assert_has_calls( + _SUMMARY_MOCK.assert_has_calls( expected_scalar_summary_calls, any_order=False )