diff --git a/darwin/torch/transforms.py b/darwin/torch/transforms.py index 8285638ff..c9b39eb80 100644 --- a/darwin/torch/transforms.py +++ b/darwin/torch/transforms.py @@ -335,12 +335,14 @@ def from_dict(cls, alb_dict: dict) -> "AlbumentationsTransform": def __call__(self, image, annotation: dict = None) -> tuple: np_image = np.array(image) if annotation is None: - annotation = {} - albu_data = self._pre_process(np_image, annotation) + annotation_dict = {} + else: + annotation_dict = annotation + albu_data = self._pre_process(np_image, annotation_dict) transformed_data = self.transform(**albu_data) - image, transformed_annotation = self._post_process(transformed_data, annotation) + image, transformed_annotation = self._post_process(transformed_data, annotation_dict) - if len(annotation) < 1: + if annotation is None: return image return image, transformed_annotation