Skip to content

Commit

Permalink
black
Browse files Browse the repository at this point in the history
  • Loading branch information
ChristofferEdlund committed Oct 23, 2023
1 parent 5a5101f commit a8d5805
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions darwin/torch/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,9 @@ def __call__(self, image, annotation: dict = None) -> tuple:

albu_data = self._pre_process(np_image, annotation_dict)
transformed_data = self.transform(**albu_data)
image, transformed_annotation = self._post_process(transformed_data, annotation_dict)
image, transformed_annotation = self._post_process(
transformed_data, annotation_dict
)

if annotation is None:
return image
Expand All @@ -363,7 +365,9 @@ def _pre_process(self, image: np.ndarray, annotation: dict) -> dict:
albumentation_dict["labels"] = labels.tolist()

masks = annotation.get("masks")
if masks is not None and masks.numel() > 0: # using numel() to check if tensor is non-empty
if (
masks is not None and masks.numel() > 0
): # using numel() to check if tensor is non-empty
print("WE GOT MASKS")
albumentation_dict["masks"] = masks.numpy()

Expand Down Expand Up @@ -395,9 +399,13 @@ def _post_process(self, albumentation_output: dict, annotation: dict) -> tuple:

if "area" in annotation:
if "masks" in output_annotation and output_annotation["masks"].numel() > 0:
output_annotation["area"] = torch.sum(output_annotation["masks"], dim=[1, 2])
output_annotation["area"] = torch.sum(
output_annotation["masks"], dim=[1, 2]
)
elif "boxes" in output_annotation and len(output_annotation["boxes"]) > 0:
output_annotation["area"] = output_annotation["boxes"][:, 2] * output_annotation["boxes"][:, 3]
output_annotation["area"] = (
output_annotation["boxes"][:, 2] * output_annotation["boxes"][:, 3]
)
else:
output_annotation["area"] = torch.tensor([])

Expand Down

0 comments on commit a8d5805

Please sign in to comment.