Skip to content

Commit

Permalink
run
Browse files Browse the repository at this point in the history
  • Loading branch information
ydshieh committed Jan 30, 2025
1 parent da2b134 commit bf377ab
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions tests/models/timm_backbone/test_modeling_timm_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,8 @@ def test_config(self):
@is_flaky(
description="`TimmBackbone` has no `_init_weights`. Timm's way of weight init. seems to give larger magnitude in the intermediate values during `forward`."
)
def test_batching_equivalence(self):
super().test_batching_equivalence()
def test_batching_equivalence(self, atol=1e-4, rtol=1e-4):
super().test_batching_equivalence(atol=atol, rtol=rtol)

def test_timm_transformer_backbone_equivalence(self):
timm_checkpoint = "resnet18"
Expand Down
4 changes: 2 additions & 2 deletions tests/test_modeling_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -768,7 +768,7 @@ def check_determinism(first, second):
else:
check_determinism(first, second)

def test_batching_equivalence(self):
def test_batching_equivalence(self, atol=1e-5, rtol=1e-5):
"""
Tests that the model supports batching and that the output is the nearly the same for the same input in
different batch sizes.
Expand Down Expand Up @@ -812,7 +812,7 @@ def recursive_check(batched_object, single_row_object, model_name, key):
torch.isinf(single_row_object).any(), f"Single row output has `inf` in {model_name} for key={key}"
)
try:
torch.testing.assert_close(batched_row, single_row_object, atol=1e-5, rtol=1e-5)
torch.testing.assert_close(batched_row, single_row_object, atol=atol, rtol=rtol)
except AssertionError as e:
msg = f"Batched and Single row outputs are not equal in {model_name} for key={key}.\n\n"
msg += str(e)
Expand Down

0 comments on commit bf377ab

Please sign in to comment.