diff --git a/darwin/dataset/local_dataset.py b/darwin/dataset/local_dataset.py index ffbe6488c..abf47e416 100644 --- a/darwin/dataset/local_dataset.py +++ b/darwin/dataset/local_dataset.py @@ -63,6 +63,20 @@ def __init__( split_type: str = "random", release_name: Optional[str] = None, ): + assert dataset_path is not None + release_path = get_release_path(dataset_path, release_name) + annotations_dir = release_path / "annotations" + assert annotations_dir.exists() + images_dir = dataset_path / "images" + assert images_dir.exists() + + if partition not in ["train", "val", "test", None]: + raise ValueError("partition should be either 'train', 'val', or 'test'") + if split_type not in ["random", "stratified"]: + raise ValueError("split_type should be either 'random', 'stratified'") + if annotation_type not in ["tag", "polygon", "bounding_box"]: + raise ValueError("annotation_type should be either 'tag', 'bounding_box', or 'polygon'") + self.dataset_path = dataset_path self.annotation_type = annotation_type self.images_path: List[Path] = [] @@ -71,38 +85,35 @@ def __init__( self.original_images_path: Optional[List[Path]] = None self.original_annotations_path: Optional[List[Path]] = None - release_path, annotations_dir, images_dir = self._initial_setup( - dataset_path, release_name - ) - self._validate_inputs(partition, split_type, annotation_type) # Get the list of classes - - annotation_types = [self.annotation_type] - # We fetch bounding_boxes annotations from selected polygons as well - if self.annotation_type == "bounding_boxes": - annotation_types.append("polygon") self.classes = get_classes( - self.dataset_path, - release_name, - annotation_type=annotation_types, - remove_background=True, + self.dataset_path, release_name, annotation_type=self.annotation_type, remove_background=True ) self.num_classes = len(self.classes) - self._setup_annotations_and_images( - release_path, - annotations_dir, - images_dir, - annotation_type, - split, - partition, - split_type, - ) + + stems = build_stems(release_path, annotations_dir, annotation_type, split, partition, split_type) + + # Find all the annotations and their corresponding images + for stem in stems: + annotation_path = annotations_dir / f"{stem}.json" + images = [] + for ext in SUPPORTED_IMAGE_EXTENSIONS: + image_path = images_dir / f"{stem}{ext}" + if image_path.exists(): + images.append(image_path) + continue + image_path = images_dir / f"{stem}{ext.upper()}" + if image_path.exists(): + images.append(image_path) + if len(images) < 1: + raise ValueError(f"Annotation ({annotation_path}) does not have a corresponding image") + if len(images) > 1: + raise ValueError(f"Image ({stem}) is present with multiple extensions. This is forbidden.") + self.images_path.append(images[0]) + self.annotations_path.append(annotation_path) if len(self.images_path) == 0: - raise ValueError( - f"Could not find any {SUPPORTED_IMAGE_EXTENSIONS} file", - f" in {images_dir}", - ) + raise ValueError(f"Could not find any {SUPPORTED_IMAGE_EXTENSIONS} file", f" in {images_dir}") assert len(self.images_path) == len(self.annotations_path)