made the refactor more like the original
ChristofferEdlund committed Oct 17, 2023
1 parent 57797ca commit a4431f8
Showing 1 changed file with 18 additions and 44 deletions.
62 changes: 18 additions & 44 deletions darwin/dataset/local_dataset.py
@@ -63,20 +63,6 @@ def __init__(
         split_type: str = "random",
         release_name: Optional[str] = None,
     ):
-        assert dataset_path is not None
-        release_path = get_release_path(dataset_path, release_name)
-        annotations_dir = release_path / "annotations"
-        assert annotations_dir.exists()
-        images_dir = dataset_path / "images"
-        assert images_dir.exists()
-
-        if partition not in ["train", "val", "test", None]:
-            raise ValueError("partition should be either 'train', 'val', or 'test'")
-        if split_type not in ["random", "stratified"]:
-            raise ValueError("split_type should be either 'random', 'stratified'")
-        if annotation_type not in ["tag", "polygon", "bounding_box"]:
-            raise ValueError("annotation_type should be either 'tag', 'bounding_box', or 'polygon'")
-
         self.dataset_path = dataset_path
         self.annotation_type = annotation_type
         self.images_path: List[Path] = []
@@ -85,32 +71,17 @@ def __init__(
         self.original_images_path: Optional[List[Path]] = None
         self.original_annotations_path: Optional[List[Path]] = None
 
+        release_path, annotations_dir, images_dir = self._initial_setup(dataset_path, release_name)
+        self._validate_inputs(partition, split_type, annotation_type)
         # Get the list of classes
-        self.classes = get_classes(
-            self.dataset_path, release_name, annotation_type=self.annotation_type, remove_background=True
-        )
-        self.num_classes = len(self.classes)
 
-        stems = build_stems(release_path, annotations_dir, annotation_type, split, partition, split_type)
-
-        # Find all the annotations and their corresponding images
-        for stem in stems:
-            annotation_path = annotations_dir / f"{stem}.json"
-            images = []
-            for ext in SUPPORTED_IMAGE_EXTENSIONS:
-                image_path = images_dir / f"{stem}{ext}"
-                if image_path.exists():
-                    images.append(image_path)
-                    continue
-                image_path = images_dir / f"{stem}{ext.upper()}"
-                if image_path.exists():
-                    images.append(image_path)
-            if len(images) < 1:
-                raise ValueError(f"Annotation ({annotation_path}) does not have a corresponding image")
-            if len(images) > 1:
-                raise ValueError(f"Image ({stem}) is present with multiple extensions. This is forbidden.")
-            self.images_path.append(images[0])
-            self.annotations_path.append(annotation_path)
+        annotation_types = [self.annotation_type]
+        # We fetch bounding_boxes annotations from selected polygons as well
+        if self.annotation_type == "bounding_boxes":
+            annotation_types.append("polygon")
+        self.classes = get_classes(self.dataset_path, release_name, annotation_type=annotation_types, remove_background=True)
+        self.num_classes = len(self.classes)
+        self._setup_annotations_and_images(release_path, annotations_dir, images_dir, annotation_type, split, partition, split_type)
 
         if len(self.images_path) == 0:
             raise ValueError(f"Could not find any {SUPPORTED_IMAGE_EXTENSIONS} file", f" in {images_dir}")
@@ -129,12 +100,15 @@ def _setup_annotations_and_images(self, release_path, annotations_dir, images_di
         stems = build_stems(release_path, annotations_dir, annotation_type, split, partition, split_type)
         for stem in stems:
             annotation_path = annotations_dir / f"{stem}.json"
-            images = [
-                image_path
-                for ext in SUPPORTED_IMAGE_EXTENSIONS
-                for image_path in [images_dir / f"{stem}{ext}", images_dir / f"{stem}{ext.upper()}"]
-                if image_path.exists()
-            ]
+            images = []
+            for ext in SUPPORTED_IMAGE_EXTENSIONS:
+                image_path = images_dir / f"{stem}{ext}"
+                if image_path.exists():
+                    images.append(image_path)
+                    continue
+                image_path = images_dir / f"{stem}{ext.upper()}"
+                if image_path.exists():
+                    images.append(image_path)
             if len(images) < 1:
                 raise ValueError(f"Annotation ({annotation_path}) does not have a corresponding image")
             if len(images) > 1:
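
The restored loop and the removed list comprehension disagree when the same stem exists under both a lowercase and an uppercase extension: the loop takes at most one path per extension (lowercase preferred, because of the continue), while the comprehension collects both variants and the caller then raises the "multiple extensions" ValueError. A minimal, self-contained sketch of that difference; the extension list and filenames below are hypothetical placeholders, not taken from the commit:

# Sketch only: contrasts the two lookup strategies shown in the last hunk.
import tempfile
from pathlib import Path
from typing import List

SUPPORTED_IMAGE_EXTENSIONS = [".png", ".jpeg", ".jpg"]  # placeholder list


def find_images_loop(images_dir: Path, stem: str) -> List[Path]:
    # Restored behaviour: at most one match per extension; the lowercase variant wins.
    images = []
    for ext in SUPPORTED_IMAGE_EXTENSIONS:
        image_path = images_dir / f"{stem}{ext}"
        if image_path.exists():
            images.append(image_path)
            continue
        image_path = images_dir / f"{stem}{ext.upper()}"
        if image_path.exists():
            images.append(image_path)
    return images


def find_images_comprehension(images_dir: Path, stem: str) -> List[Path]:
    # Removed behaviour: lowercase and uppercase variants are collected independently.
    return [
        image_path
        for ext in SUPPORTED_IMAGE_EXTENSIONS
        for image_path in [images_dir / f"{stem}{ext}", images_dir / f"{stem}{ext.upper()}"]
        if image_path.exists()
    ]


with tempfile.TemporaryDirectory() as tmp:
    images_dir = Path(tmp)
    (images_dir / "scan_001.jpg").touch()
    (images_dir / "scan_001.JPG").touch()  # same stem, upper-case extension
    print(len(find_images_loop(images_dir, "scan_001")))           # 1 -> accepted
    print(len(find_images_comprehension(images_dir, "scan_001")))  # 2 -> would trip the "multiple extensions" check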
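
The bodies of _initial_setup and _validate_inputs are not part of this diff. Assuming they simply absorb the checks this commit removes from __init__ in the first hunk, they would look roughly like the sketch below; the signatures, return shape, and import location are assumptions, and the helpers are written as free functions here only so the sketch runs on its own:

# Hypothetical sketch: these bodies are not shown in the diff; the checks are copied
# from the lines this commit removes from __init__.
from pathlib import Path
from typing import Optional, Tuple

from darwin.dataset.utils import get_release_path  # assumed import location


def _initial_setup(dataset_path: Path, release_name: Optional[str]) -> Tuple[Path, Path, Path]:
    assert dataset_path is not None
    release_path = get_release_path(dataset_path, release_name)
    annotations_dir = release_path / "annotations"
    assert annotations_dir.exists()
    images_dir = dataset_path / "images"
    assert images_dir.exists()
    return release_path, annotations_dir, images_dir


def _validate_inputs(partition: Optional[str], split_type: str, annotation_type: str) -> None:
    if partition not in ["train", "val", "test", None]:
        raise ValueError("partition should be either 'train', 'val', or 'test'")
    if split_type not in ["random", "stratified"]:
        raise ValueError("split_type should be either 'random', 'stratified'")
    if annotation_type not in ["tag", "polygon", "bounding_box"]:
        raise ValueError("annotation_type should be either 'tag', 'bounding_box', or 'polygon'")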

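For context, a hedged usage sketch: it assumes the class defined in darwin/dataset/local_dataset.py is darwin-py's LocalDataset and that a dataset has already been pulled to disk; the team/dataset path is a placeholder.

# Hypothetical usage; not part of the commit.
from pathlib import Path

from darwin.dataset.local_dataset import LocalDataset

dataset = LocalDataset(
    dataset_path=Path.home() / ".darwin" / "datasets" / "my-team" / "my-dataset",
    annotation_type="polygon",
)
print(dataset.num_classes, len(dataset.images_path), len(dataset.annotations_path))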