From 8b94f853b7f8605848bb5c1843e55b2e476ce0e8 Mon Sep 17 00:00:00 2001
From: Christoffer
Date: Mon, 23 Oct 2023 13:11:37 +0200
Subject: [PATCH] black

---
 darwin/torch/transforms.py | 24 ++++++++++++++++++------
 1 file changed, 18 insertions(+), 6 deletions(-)

diff --git a/darwin/torch/transforms.py b/darwin/torch/transforms.py
index 458a4454c..8285638ff 100644
--- a/darwin/torch/transforms.py
+++ b/darwin/torch/transforms.py
@@ -143,7 +143,9 @@ class ColorJitter(transforms.ColorJitter):
     def __call__(
         self, image: PILImage.Image, target: Optional[TargetType] = None
     ) -> Union[PILImage.Image, Tuple[PILImage.Image, TargetType]]:
-        transform = self.get_params(self.brightness, self.contrast, self.saturation, self.hue)
+        transform = self.get_params(
+            self.brightness, self.contrast, self.saturation, self.hue
+        )
         image = transform(image)
         if target is None:
             return image
@@ -198,7 +200,9 @@ class ConvertPolygonsToInstanceMasks(object):
     Converts given polygon to an ``InstanceMask``.
     """
 
-    def __call__(self, image: PILImage.Image, target: TargetType) -> Tuple[PILImage.Image, TargetType]:
+    def __call__(
+        self, image: PILImage.Image, target: TargetType
+    ) -> Tuple[PILImage.Image, TargetType]:
         w, h = image.size
 
         image_id = target["image_id"]
@@ -255,7 +259,9 @@ class ConvertPolygonsToSemanticMask(object):
     Converts given polygon to an ``SemanticMask``.
     """
 
-    def __call__(self, image: PILImage.Image, target: TargetType) -> Tuple[PILImage.Image, TargetType]:
+    def __call__(
+        self, image: PILImage.Image, target: TargetType
+    ) -> Tuple[PILImage.Image, TargetType]:
         w, h = image.size
         image_id = target["image_id"]
         image_id = torch.tensor([image_id])
@@ -282,7 +288,9 @@ class ConvertPolygonToMask(object):
     Converts given polygon to a ``Mask``.
     """
 
-    def __call__(self, image: PILImage.Image, annotation: Dict[str, Any]) -> Tuple[PILImage.Image, PILImage.Image]:
+    def __call__(
+        self, image: PILImage.Image, annotation: Dict[str, Any]
+    ) -> Tuple[PILImage.Image, PILImage.Image]:
         w, h = image.size
         segmentations = [obj["segmentation"] for obj in annotation]
         cats = [obj["category_id"] for obj in annotation]
@@ -368,7 +376,9 @@ def _post_process(self, albumentation_output: dict, annotation: dict) -> tuple:
         if bboxes is not None:
             output_annotation["boxes"] = torch.tensor(bboxes)
             if "area" in annotation and "masks" not in albumentation_output:
-                output_annotation["area"] = output_annotation["boxes"][:, 2] * output_annotation["boxes"][:, 3]
+                output_annotation["area"] = (
+                    output_annotation["boxes"][:, 2] * output_annotation["boxes"][:, 3]
+                )
 
         labels = albumentation_output.get("labels")
         if labels is not None:
@@ -381,7 +391,9 @@ def _post_process(self, albumentation_output: dict, annotation: dict) -> tuple:
             else:
                 output_annotation["masks"] = torch.stack(masks)
             if "area" in annotation:
-                output_annotation["area"] = torch.sum(output_annotation["masks"], dim=[1, 2])
+                output_annotation["area"] = torch.sum(
+                    output_annotation["masks"], dim=[1, 2]
+                )
 
         # Copy other metadata from original annotation
         for key, value in annotation.items():