Skip to content

Commit

Permalink
black
Browse files Browse the repository at this point in the history
  • Loading branch information
ChristofferEdlund committed Oct 23, 2023
1 parent 00ceba9 commit 8b94f85
Showing 1 changed file with 18 additions and 6 deletions.
24 changes: 18 additions & 6 deletions darwin/torch/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,9 @@ class ColorJitter(transforms.ColorJitter):
def __call__(
self, image: PILImage.Image, target: Optional[TargetType] = None
) -> Union[PILImage.Image, Tuple[PILImage.Image, TargetType]]:
transform = self.get_params(self.brightness, self.contrast, self.saturation, self.hue)
transform = self.get_params(
self.brightness, self.contrast, self.saturation, self.hue
)
image = transform(image)
if target is None:
return image
Expand Down Expand Up @@ -198,7 +200,9 @@ class ConvertPolygonsToInstanceMasks(object):
Converts given polygon to an ``InstanceMask``.
"""

def __call__(self, image: PILImage.Image, target: TargetType) -> Tuple[PILImage.Image, TargetType]:
def __call__(
self, image: PILImage.Image, target: TargetType
) -> Tuple[PILImage.Image, TargetType]:
w, h = image.size

image_id = target["image_id"]
Expand Down Expand Up @@ -255,7 +259,9 @@ class ConvertPolygonsToSemanticMask(object):
Converts given polygon to an ``SemanticMask``.
"""

def __call__(self, image: PILImage.Image, target: TargetType) -> Tuple[PILImage.Image, TargetType]:
def __call__(
self, image: PILImage.Image, target: TargetType
) -> Tuple[PILImage.Image, TargetType]:
w, h = image.size
image_id = target["image_id"]
image_id = torch.tensor([image_id])
Expand All @@ -282,7 +288,9 @@ class ConvertPolygonToMask(object):
Converts given polygon to a ``Mask``.
"""

def __call__(self, image: PILImage.Image, annotation: Dict[str, Any]) -> Tuple[PILImage.Image, PILImage.Image]:
def __call__(
self, image: PILImage.Image, annotation: Dict[str, Any]
) -> Tuple[PILImage.Image, PILImage.Image]:
w, h = image.size
segmentations = [obj["segmentation"] for obj in annotation]
cats = [obj["category_id"] for obj in annotation]
Expand Down Expand Up @@ -368,7 +376,9 @@ def _post_process(self, albumentation_output: dict, annotation: dict) -> tuple:
if bboxes is not None:
output_annotation["boxes"] = torch.tensor(bboxes)
if "area" in annotation and "masks" not in albumentation_output:
output_annotation["area"] = output_annotation["boxes"][:, 2] * output_annotation["boxes"][:, 3]
output_annotation["area"] = (
output_annotation["boxes"][:, 2] * output_annotation["boxes"][:, 3]
)

labels = albumentation_output.get("labels")
if labels is not None:
Expand All @@ -381,7 +391,9 @@ def _post_process(self, albumentation_output: dict, annotation: dict) -> tuple:
else:
output_annotation["masks"] = torch.stack(masks)
if "area" in annotation:
output_annotation["area"] = torch.sum(output_annotation["masks"], dim=[1, 2])
output_annotation["area"] = torch.sum(
output_annotation["masks"], dim=[1, 2]
)

# Copy other metadata from original annotation
for key, value in annotation.items():
Expand Down

0 comments on commit 8b94f85

Please sign in to comment.