Fix flaky tests by relaxing the abs diff tolerance.
explanare committed Oct 8, 2024
1 parent e19c97b commit 7446e37
Showing 2 changed files with 8 additions and 5 deletions.
@@ -63,10 +63,11 @@ def test_clean_run_positive(self):
         base = {"input_ids": torch.randint(0, 10, (10, 5)).to(self.device)}
         golden_out = self.gpt2(**base).logits
         our_output = intervenable(base, output_original_output=True)[0][0]
-        self.assertTrue(torch.allclose(golden_out, our_output))
+        self.assertTrue(torch.allclose(golden_out, our_output, rtol=1e-05, atol=1e-06))
         # make sure the toolkit also works
         self.assertTrue(
-            torch.allclose(GPT2_RUN(self.gpt2, base["input_ids"], {}, {}), golden_out)
+            torch.allclose(GPT2_RUN(self.gpt2, base["input_ids"], {}, {}), golden_out,
+                           rtol=1e-05, atol=1e-06)
         )

     def _test_subspace_partition_in_forward(self, intervention_type):
@@ -134,7 +135,8 @@ def _test_subspace_partition_in_forward(self, intervention_type):
         # make sure the toolkit also works
         self.assertTrue(
             torch.allclose(
-                with_partition_our_output[0], without_partition_our_output[0]
+                with_partition_our_output[0], without_partition_our_output[0],
+                rtol=1e-05, atol=1e-06
             )
         )

tests/integration_tests/InterventionWithGPT2TestCase.py (5 changes: 3 additions & 2 deletions)
@@ -132,8 +132,9 @@ def _test_with_head_position_intervention(
                 )
             },
         )
-
-        self.assertTrue(torch.allclose(out_output[0], golden_out))
+        # Relax the atol to 1e-6 to accommodate different Transformers versions.
+        # The max of the absolute diff is usually between 1e-8 and 1e-7.
+        self.assertTrue(torch.allclose(out_output[0], golden_out, rtol=1e-05, atol=1e-06))


     def test_with_multiple_heads_positions_vanilla_intervention_positive(self):
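For context, a minimal sketch (not part of the commit) of why relaxing atol to 1e-06 absorbs the drift the comment describes: torch.allclose treats two tensors as equal when |a - b| <= atol + rtol * |b| elementwise, so for near-zero values the relative term contributes almost nothing and atol alone must cover the numerical noise. The tensor values below are invented for illustration.

import torch

# Hypothetical tensors mimicking the drift described in the commit;
# they are not taken from the actual tests.
golden = torch.zeros(3)
observed = golden + torch.tensor([2e-8, 5e-8, 9e-8])  # drift in the 1e-8..1e-7 range

# torch.allclose passes when |a - b| <= atol + rtol * |b| elementwise.
# With the default atol=1e-08 the relative term is negligible for tiny values,
# so drift on the order of 1e-7 fails the check:
print(torch.allclose(golden, observed))  # False
# The relaxed tolerance used in this commit absorbs it:
print(torch.allclose(golden, observed, rtol=1e-05, atol=1e-06))  # True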
