Skip to content

Commit

Permalink
No public description
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 721137353
  • Loading branch information
tensorflower-gardener committed Jan 29, 2025
1 parent c9cac9e commit 3d5e05f
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 16 deletions.
23 changes: 12 additions & 11 deletions official/common/distribute_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,23 +43,23 @@ def test_invalid_args(self):

def test_one_device_strategy_cpu(self):
ds = distribute_utils.get_distribution_strategy('one_device', num_gpus=0)
self.assertEquals(ds.num_replicas_in_sync, 1)
self.assertEquals(len(ds.extended.worker_devices), 1)
self.assertEqual(ds.num_replicas_in_sync, 1)
self.assertEqual(len(ds.extended.worker_devices), 1)
self.assertIn('CPU', ds.extended.worker_devices[0])

def test_one_device_strategy_gpu(self):
ds = distribute_utils.get_distribution_strategy('one_device', num_gpus=1)
self.assertEquals(ds.num_replicas_in_sync, 1)
self.assertEquals(len(ds.extended.worker_devices), 1)
self.assertEqual(ds.num_replicas_in_sync, 1)
self.assertEqual(len(ds.extended.worker_devices), 1)
self.assertIn('GPU', ds.extended.worker_devices[0])

def test_mirrored_strategy(self):
# CPU only.
_ = distribute_utils.get_distribution_strategy(num_gpus=0)
# 5 GPUs.
ds = distribute_utils.get_distribution_strategy(num_gpus=5)
self.assertEquals(ds.num_replicas_in_sync, 5)
self.assertEquals(len(ds.extended.worker_devices), 5)
self.assertEqual(ds.num_replicas_in_sync, 5)
self.assertEqual(len(ds.extended.worker_devices), 5)
for device in ds.extended.worker_devices:
self.assertIn('GPU', device)

Expand Down Expand Up @@ -105,12 +105,13 @@ def test_tpu_strategy(self):
ds, tf.distribute.TPUStrategy)

def test_invalid_strategy(self):
with self.assertRaisesRegexp(
ValueError,
'distribution_strategy must be a string but got: False. If'):
with self.assertRaisesRegex(
ValueError, 'distribution_strategy must be a string but got: False. If'
):
distribute_utils.get_distribution_strategy(False)
with self.assertRaisesRegexp(
ValueError, 'distribution_strategy must be a string but got: 1'):
with self.assertRaisesRegex(
ValueError, 'distribution_strategy must be a string but got: 1'
):
distribute_utils.get_distribution_strategy(1)

def test_get_strategy_scope(self):
Expand Down
6 changes: 3 additions & 3 deletions official/recommendation/data_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def _test_end_to_end(self, constructor_type):
train_examples[l].add((u_raw, i_raw))
counts[(u_raw, i_raw)] += 1

self.assertRegexpMatches(md5.hexdigest(), END_TO_END_TRAIN_MD5)
self.assertRegex(md5.hexdigest(), END_TO_END_TRAIN_MD5)

num_positives_seen = len(train_examples[True])
self.assertEqual(producer._train_pos_users.shape[0], num_positives_seen)
Expand Down Expand Up @@ -254,7 +254,7 @@ def _test_end_to_end(self, constructor_type):
# from the negatives.
assert (u_raw, i_raw) not in self.seen_pairs

self.assertRegexpMatches(md5.hexdigest(), END_TO_END_EVAL_MD5)
self.assertRegex(md5.hexdigest(), END_TO_END_EVAL_MD5)

def _test_fresh_randomness(self, constructor_type):
train_epochs = 5
Expand Down Expand Up @@ -300,7 +300,7 @@ def _test_fresh_randomness(self, constructor_type):
else:
negative_counts[(u, i)] += 1

self.assertRegexpMatches(md5.hexdigest(), FRESH_RANDOMNESS_MD5)
self.assertRegex(md5.hexdigest(), FRESH_RANDOMNESS_MD5)

# The positive examples should appear exactly once each epoch
self.assertAllEqual(
Expand Down
4 changes: 2 additions & 2 deletions official/utils/misc/model_helpers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def test_generate_synethetic_data(self):
for n in range(5):
inp, lab = sess.run((input_element, label_element))
self.assertAllClose(inp, [123., 123., 123., 123., 123.])
self.assertEquals(lab, 456)
self.assertEqual(lab, 456)

def test_generate_only_input_data(self):
d = model_helpers.generate_synthetic_data(
Expand Down Expand Up @@ -111,7 +111,7 @@ def test_generate_nested_data(self):
element = tf.compat.v1.data.make_one_shot_iterator(d).get_next()
self.assertIn('a', element)
self.assertIn('b', element)
self.assertEquals(len(element['b']), 2)
self.assertEqual(len(element['b']), 2)
self.assertIn('c', element['b'])
self.assertIn('d', element['b'])
self.assertNotIn('c', element)
Expand Down

0 comments on commit 3d5e05f

Please sign in to comment.