diff --git a/darwin/backend_v2.py b/darwin/backend_v2.py index ea0d2cfdd..9e4e84d7a 100644 --- a/darwin/backend_v2.py +++ b/darwin/backend_v2.py @@ -253,3 +253,20 @@ def register_items(self, payload: Dict[str, Any], team_slug: str) -> None: return self._client._post_raw( f"/v2/teams/{team_slug}/items/register_existing", payload ) + + def _get_remote_annotations( + self, + item_id: str, + team_slug: str, + ) -> List: + """ + Returns the annotations currently present on a remote dataset item. + + Parameters + ---------- + item_id: str + The UUID of the item. + team_slug: str + The team slug. + """ + return self._client._get(f"v2/teams/{team_slug}/items/{item_id}/annotations") diff --git a/darwin/cli.py b/darwin/cli.py index cafb80cbf..003f20802 100644 --- a/darwin/cli.py +++ b/darwin/cli.py @@ -166,6 +166,7 @@ def _run(args: Namespace, parser: ArgumentParser) -> None: args.delete_for_empty, args.import_annotators, args.import_reviewers, + args.overwrite, cpu_limit=args.cpu_limit, ) elif args.action == "convert": diff --git a/darwin/cli_functions.py b/darwin/cli_functions.py index f814217dd..efde7456f 100644 --- a/darwin/cli_functions.py +++ b/darwin/cli_functions.py @@ -851,6 +851,7 @@ def dataset_import( delete_for_empty: bool = False, import_annotators: bool = False, import_reviewers: bool = False, + overwrite: bool = False, use_multi_cpu: bool = False, cpu_limit: Optional[int] = None, ) -> None: @@ -881,6 +882,9 @@ def dataset_import( import_reviewers : bool, default: False If ``True`` it will import the reviewers from the files to the dataset, if . If ``False`` it will not import the reviewers. + overwrite : bool, default: False + If ``True`` it will bypass a warning that the import will overwrite the current annotations if any are present. + If ``False`` this warning will be displayed and the user prompted to confirm before any annotations are overwritten. use_multi_cpu : bool, default: False If ``True`` it will use all multiple CPUs to speed up the import process. 
cpu_limit : Optional[int], default: Core count - 2 @@ -904,6 +908,7 @@ def dataset_import( delete_for_empty, import_annotators, import_reviewers, + overwrite, use_multi_cpu, cpu_limit, ) diff --git a/darwin/importer/importer.py b/darwin/importer/importer.py index 62a4d2e82..463a35791 100644 --- a/darwin/importer/importer.py +++ b/darwin/importer/importer.py @@ -182,7 +182,7 @@ def _build_attribute_lookup(dataset: "RemoteDataset") -> Dict[str, Unknown]: def _get_remote_files( dataset: "RemoteDataset", filenames: List[str], chunk_size: int = 100 -) -> Dict[str, Tuple[int, str]]: +) -> Dict[str, Tuple[str, str]]: """ Fetches remote files from the datasets in chunks; by default 100 filenames at a time. @@ -673,6 +673,7 @@ def import_annotations( # noqa: C901 delete_for_empty: bool = False, import_annotators: bool = False, import_reviewers: bool = False, + overwrite: bool = False, use_multi_cpu: bool = False, # Set to False to give time to resolve MP behaviours cpu_limit: Optional[int] = None, # 0 because it's set later in logic ) -> None: @@ -704,6 +705,9 @@ def import_annotations( # noqa: C901 import_reviewers : bool, default: False If ``True`` it will import the reviewers from the files to the dataset, if . If ``False`` it will not import the reviewers. + overwrite : bool, default: False + If ``True`` it will bypass a warning that the import will overwrite the current annotations if any are present. + If ``False`` this warning will be displayed and the user prompted to confirm before any annotations are overwritten. use_multi_cpu : bool, default: True If ``True`` will use multiple available CPU cores to parse the annotation files. If ``False`` will use only the current Python process, which runs in one core. 
@@ -790,7 +794,7 @@ def import_annotations( # noqa: C901 console.print("Fetching remote file list...", style="info") # This call will only filter by filename; so can return a superset of matched files across different paths # There is logic in this function to then include paths to narrow down to the single correct matching file - remote_files: Dict[str, Tuple[int, str]] = {} + remote_files: Dict[str, Tuple[str, str]] = {} # Try to fetch files in large chunks; in case the filenames are too large and exceed the url size # retry in smaller chunks @@ -901,6 +905,13 @@ def import_annotations( # noqa: C901 style="info", ) + + if not append and not overwrite: + continue_to_overwrite = _overwrite_warning( + dataset.client, dataset, local_files, remote_files, console + ) + if not continue_to_overwrite: + return + # Need to re parse the files since we didn't save the annotations in memory for local_path in set(local_file.path for local_file in local_files): # noqa: C401 imported_files: Union[List[dt.AnnotationFile], dt.AnnotationFile, None] = ( @@ -1355,3 +1366,53 @@ def _console_theme() -> Theme: "info": "bold deep_sky_blue1", } ) + + +def _overwrite_warning( + client: "Client", + dataset: "RemoteDataset", + local_files: List[dt.AnnotationFile], + remote_files: Dict[str, Tuple[str, str]], + console: Console, +) -> bool: + """ + Determines if any dataset items targeted for import already have annotations that will be overwritten. + If they do, a warning is displayed to the user and they are prompted to confirm if they want to proceed with the import. + + Parameters + ---------- + client : Client + The Darwin Client object. + dataset : RemoteDataset + The dataset where the annotations will be imported. + local_files : List[dt.AnnotationFile] + The list of local annotation files that will be imported. + remote_files : Dict[str, Tuple[str, str]] + A dictionary of the remote files in the dataset. + console : Console + The console object. 
+ + Returns + ------- + bool + True if the user wants to proceed with the import, False otherwise. + """ + files_to_overwrite = [] + for local_file in local_files: + remote_annotations = client.api_v2._get_remote_annotations( + local_file.item_id, + dataset.team, + ) + if remote_annotations and local_file.full_path not in files_to_overwrite: + files_to_overwrite.append(local_file.full_path) + if files_to_overwrite: + console.print( + f"The following {len(files_to_overwrite)} dataset items already have annotations that will be overwritten by this import:", + style="warning", + ) + for file in files_to_overwrite: + console.print(f"- {file}", style="warning") + proceed = input("Do you want to proceed with the import? [y/N] ") + if proceed.lower() != "y": + return False + return True diff --git a/darwin/options.py b/darwin/options.py index 0a9b5151f..1411a8385 100644 --- a/darwin/options.py +++ b/darwin/options.py @@ -333,6 +333,11 @@ def __init__(self) -> None: action="store_true", help="Import reviewers metadata from the annotation files, where available", ) + parser_import.add_argument( + "--overwrite", + action="store_true", + help="Bypass warnings about overwriting existing annotations.", + ) # Cpu limit for multiprocessing tasks def cpu_default_types(input: Any) -> Optional[int]: # type: ignore diff --git a/tests/darwin/importer/importer_test.py b/tests/darwin/importer/importer_test.py index 5e2534a30..d86a552c1 100644 --- a/tests/darwin/importer/importer_test.py +++ b/tests/darwin/importer/importer_test.py @@ -1,13 +1,13 @@ import json from pathlib import Path from typing import List, Tuple -from unittest.mock import Mock, _patch, patch +from unittest.mock import MagicMock, Mock, _patch, patch import pytest from rich.theme import Theme from darwin import datatypes as dt -from darwin.importer.importer import _parse_empty_masks +from darwin.importer.importer import _overwrite_warning, _parse_empty_masks def root_path(x: str) -> str: @@ -593,3 +593,91 @@ def 
test_console_theme() -> None: from darwin.importer.importer import _console_theme assert isinstance(_console_theme(), Theme) + + +def test_overwrite_warning_proceeds_with_import(): + annotations: List[dt.AnnotationLike] = [ + dt.Annotation( + dt.AnnotationClass("cat1", "polygon"), + { + "paths": [ + [ + {"x": -1, "y": -1}, + {"x": -1, "y": 1}, + {"x": 1, "y": 1}, + {"x": 1, "y": -1}, + {"x": -1, "y": -1}, + ] + ], + "bounding_box": {"x": -1, "y": -1, "w": 2, "h": 2}, + }, + ) + ] + client = MagicMock() + dataset = MagicMock() + files = [ + dt.AnnotationFile( + path=Path("/"), + filename="file1", + annotation_classes={a.annotation_class for a in annotations}, + annotations=annotations, + remote_path="/", + ), + dt.AnnotationFile( + path=Path("/"), + filename="file2", + annotation_classes={a.annotation_class for a in annotations}, + annotations=annotations, + remote_path="/", + ), + ] + remote_files = {"/file1": ("id1", "path1"), "/file2": ("id2", "path2")} + console = MagicMock() + + with patch("builtins.input", return_value="y"): + result = _overwrite_warning(client, dataset, files, remote_files, console) + assert result is True + + +def test_overwrite_warning_aborts_import(): + annotations: List[dt.AnnotationLike] = [ + dt.Annotation( + dt.AnnotationClass("cat1", "polygon"), + { + "paths": [ + [ + {"x": -1, "y": -1}, + {"x": -1, "y": 1}, + {"x": 1, "y": 1}, + {"x": 1, "y": -1}, + {"x": -1, "y": -1}, + ] + ], + "bounding_box": {"x": -1, "y": -1, "w": 2, "h": 2}, + }, + ) + ] + client = MagicMock() + dataset = MagicMock() + files = [ + dt.AnnotationFile( + path=Path("/"), + filename="file1", + annotation_classes={a.annotation_class for a in annotations}, + annotations=annotations, + remote_path="/", + ), + dt.AnnotationFile( + path=Path("/"), + filename="file2", + annotation_classes={a.annotation_class for a in annotations}, + annotations=annotations, + remote_path="/", + ), + ] + remote_files = {"/file1": ("id1", "path1"), "/file2": ("id2", "path2")} + console = 
MagicMock() + + with patch("builtins.input", return_value="n"): + result = _overwrite_warning(client, dataset, files, remote_files, console) + assert result is False