diff --git a/tests/models/timm_backbone/test_modeling_timm_backbone.py b/tests/models/timm_backbone/test_modeling_timm_backbone.py index 296a38c176ef..9701e9d68290 100644 --- a/tests/models/timm_backbone/test_modeling_timm_backbone.py +++ b/tests/models/timm_backbone/test_modeling_timm_backbone.py @@ -118,8 +118,8 @@ def test_config(self): @is_flaky( description="`TimmBackbone` has no `_init_weights`. Timm's way of weight init. seems to give larger magnitude in the intermediate values during `forward`." ) - def test_batching_equivalence(self): - super().test_batching_equivalence() + def test_batching_equivalence(self, atol=1e-4, rtol=1e-4): + super().test_batching_equivalence(atol=atol, rtol=rtol) def test_timm_transformer_backbone_equivalence(self): timm_checkpoint = "resnet18" diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 1d52e8281449..bd56df3abc0d 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -768,7 +768,7 @@ def check_determinism(first, second): else: check_determinism(first, second) - def test_batching_equivalence(self): + def test_batching_equivalence(self, atol=1e-5, rtol=1e-5): """ Tests that the model supports batching and that the output is the nearly the same for the same input in different batch sizes. @@ -812,7 +812,7 @@ def recursive_check(batched_object, single_row_object, model_name, key): torch.isinf(single_row_object).any(), f"Single row output has `inf` in {model_name} for key={key}" ) try: - torch.testing.assert_close(batched_row, single_row_object, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(batched_row, single_row_object, atol=atol, rtol=rtol) except AssertionError as e: msg = f"Batched and Single row outputs are not equal in {model_name} for key={key}.\n\n" msg += str(e)