Skip to content

Commit

Permalink
black
Browse files Browse the repository at this point in the history
  • Loading branch information
ChristofferEdlund committed Oct 23, 2023
1 parent bd59c78 commit 9ef6e56
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions tests/darwin/torch/transform_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
bbox_params=BboxParams(format="coco", label_fields=["labels"]),
)


class TestAlbumentationsTransform:
def test_init(self):
transformations = EXAMPLE_BOX_TRANSFORM
Expand Down Expand Up @@ -84,7 +85,9 @@ def test_boxes_out_of_bounds(self):
transformations = EXAMPLE_BOX_TRANSFORM
at = AlbumentationsTransform(transformations)
with pytest.raises(ValueError):
_, annotation = at(SAMPLE_IMAGE, SAMPLE_ANNOTATION_OOB) # Expecting the ValueError due to out of bounds
_, annotation = at(
SAMPLE_IMAGE, SAMPLE_ANNOTATION_OOB
) # Expecting the ValueError due to out of bounds

def test_transform_with_masks(self):
transformations = EXAMPLE_BOX_TRANSFORM
Expand All @@ -105,7 +108,9 @@ def test_area_calculation_without_masks(self):
_, annotation = at(SAMPLE_IMAGE, SAMPLE_ANNOTATION)
area = annotation["boxes"][0, 2] * annotation["boxes"][0, 3]

assert torch.isclose(annotation["area"], area.unsqueeze(0), atol=1e-5) # Using isclose for floating point comparison
assert torch.isclose(
annotation["area"], area.unsqueeze(0), atol=1e-5
) # Using isclose for floating point comparison

def test_iscrowd_unchanged(self):
transformations = EXAMPLE_BOX_TRANSFORM
Expand All @@ -117,14 +122,13 @@ def test_iscrowd_unchanged(self):
def test_image_only(self):
transformations = EXAMPLE_IMAGE_TRANSFORM
at = AlbumentationsTransform(transformations)

image = at(SAMPLE_IMAGE)
assert image is not None
assert not isinstance(image, Tuple)
print(type(image))
assert isinstance(image, np.ndarray)


def test_bbox_with_empty_annotation(self):
transformations = EXAMPLE_BOX_TRANSFORM
at = AlbumentationsTransform(transformations)
Expand All @@ -142,7 +146,7 @@ def test_mask_with_empty_annotation(self):
for key in SAMPLE_EMPTY_ANNOTATION_WITH_MASKS:
assert key in annotation.keys()
assert len(annotation[key]) == 0


if __name__ == "__main__":
pytest.run()

0 comments on commit 9ef6e56

Please sign in to comment.