From 7446e37980a17e45d5958fae031d2d4addea75a8 Mon Sep 17 00:00:00 2001
From: Jing Huang
Date: Tue, 8 Oct 2024 15:06:40 -0700
Subject: [PATCH] Fix flaky tests by relaxing the abs diff tolerance.

---
 .../ComplexInterventionWithGPT2TestCase.py              | 8 +++++---
 tests/integration_tests/InterventionWithGPT2TestCase.py | 5 +++--
 2 files changed, 8 insertions(+), 5 deletions(-)

diff --git a/tests/integration_tests/ComplexInterventionWithGPT2TestCase.py b/tests/integration_tests/ComplexInterventionWithGPT2TestCase.py
index 34067f4c..7900ce6c 100644
--- a/tests/integration_tests/ComplexInterventionWithGPT2TestCase.py
+++ b/tests/integration_tests/ComplexInterventionWithGPT2TestCase.py
@@ -63,10 +63,11 @@ def test_clean_run_positive(self):
         base = {"input_ids": torch.randint(0, 10, (10, 5)).to(self.device)}
         golden_out = self.gpt2(**base).logits
         our_output = intervenable(base, output_original_output=True)[0][0]
-        self.assertTrue(torch.allclose(golden_out, our_output))
+        self.assertTrue(torch.allclose(golden_out, our_output, rtol=1e-05, atol=1e-06))
         # make sure the toolkit also works
         self.assertTrue(
-            torch.allclose(GPT2_RUN(self.gpt2, base["input_ids"], {}, {}), golden_out)
+            torch.allclose(GPT2_RUN(self.gpt2, base["input_ids"], {}, {}), golden_out,
+                           rtol=1e-05, atol=1e-06)
         )
 
     def _test_subspace_partition_in_forward(self, intervention_type):
@@ -134,7 +135,8 @@ def _test_subspace_partition_in_forward(self, intervention_type):
         # make sure the toolkit also works
         self.assertTrue(
             torch.allclose(
-                with_partition_our_output[0], without_partition_our_output[0]
+                with_partition_our_output[0], without_partition_our_output[0],
+                rtol=1e-05, atol=1e-06
             )
         )
 
diff --git a/tests/integration_tests/InterventionWithGPT2TestCase.py b/tests/integration_tests/InterventionWithGPT2TestCase.py
index 723c8286..1bc5fd94 100644
--- a/tests/integration_tests/InterventionWithGPT2TestCase.py
+++ b/tests/integration_tests/InterventionWithGPT2TestCase.py
@@ -132,8 +132,9 @@ def _test_with_head_position_intervention(
                 )
             },
         )
-
-        self.assertTrue(torch.allclose(out_output[0], golden_out))
+        # Relax the atol to 1e-6 to accommodate different Transformers versions.
+        # The max of the absolute diff is usually between 1e-8 and 1e-7.
+        self.assertTrue(torch.allclose(out_output[0], golden_out, rtol=1e-05, atol=1e-06))
 
     def test_with_multiple_heads_positions_vanilla_intervention_positive(self):
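
For context on the tolerances: torch.allclose defaults to rtol=1e-05 and atol=1e-08, and it passes when |input - other| <= atol + rtol * |other|, so near-zero logits get almost no absolute slack. The patch keeps the default rtol and only raises atol to 1e-06, enough to absorb the ~1e-8 to 1e-7 cross-version drift noted in the comment. Below is a minimal sketch of the effect, using small synthetic tensors in place of the tests' actual GPT-2 logits:

import torch

# Synthetic stand-ins for the golden vs. intervened logits in the tests,
# differing by 5e-7 -- on the order of the cross-version numerical drift.
golden_out = torch.tensor([0.0, 0.5, 12.0])
our_output = golden_out + 5e-7

# Default tolerances: the allowed error is atol + rtol * |other|, which for
# the zero entry is just atol = 1e-08, smaller than the 5e-7 drift.
print(torch.allclose(golden_out, our_output))  # False

# The patch's tolerances: atol = 1e-06 absorbs the drift at every magnitude,
# while differences far above numerical noise would still fail.
print(torch.allclose(golden_out, our_output, rtol=1e-05, atol=1e-06))  # True

Note that rtol=1e-05 in the patch restates torch's default, so the only effective change is atol going from 1e-08 to 1e-06; spelling both out simply makes the comparison explicit at each call site.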