From e0a8ccb284253cd448ec7f6fa281afc55de05f2f Mon Sep 17 00:00:00 2001 From: John Wilkie <124276291+JBWilkie@users.noreply.github.com> Date: Fri, 12 Apr 2024 10:25:06 +0100 Subject: [PATCH] Fixed bug when importing annotations with properties to non-default team (#813) --- darwin/client.py | 6 ++-- darwin/future/core/client.py | 6 ++-- darwin/future/tests/core/test_client.py | 5 ++-- darwin/importer/importer.py | 40 +++++++++++++++---------- tests/darwin/client_test.py | 3 +- tests/darwin/importer/importer_test.py | 1 + 6 files changed, 37 insertions(+), 24 deletions(-) diff --git a/darwin/client.py b/darwin/client.py index e311917d1..f83fd2309 100644 --- a/darwin/client.py +++ b/darwin/client.py @@ -1023,7 +1023,7 @@ def api_v2(self) -> BackendV2: def get_team_properties( self, team_slug: Optional[str] = None, include_property_values: bool = True ) -> List[FullProperty]: - darwin_config = DarwinConfig.from_old(self.config) + darwin_config = DarwinConfig.from_old(self.config, team_slug) future_client = ClientCore(darwin_config) if not include_property_values: @@ -1040,7 +1040,7 @@ def get_team_properties( def create_property( self, team_slug: Optional[str], params: Union[FullProperty, JSONDict] ) -> FullProperty: - darwin_config = DarwinConfig.from_old(self.config) + darwin_config = DarwinConfig.from_old(self.config, team_slug) future_client = ClientCore(darwin_config) return create_property_future( @@ -1052,7 +1052,7 @@ def create_property( def update_property( self, team_slug: Optional[str], params: Union[FullProperty, JSONDict] ) -> FullProperty: - darwin_config = DarwinConfig.from_old(self.config) + darwin_config = DarwinConfig.from_old(self.config, team_slug) future_client = ClientCore(darwin_config) return update_property_future( diff --git a/darwin/future/core/client.py b/darwin/future/core/client.py index 7da8ae1f4..847de0678 100644 --- a/darwin/future/core/client.py +++ b/darwin/future/core/client.py @@ -126,7 +126,7 @@ def from_api_key_with_defaults(api_key: str) -> DarwinConfig: ) @staticmethod - def from_old(old_config: OldConfig) -> DarwinConfig: + def from_old(old_config: OldConfig, team_slug: str) -> DarwinConfig: teams = old_config.get("teams") if not teams: raise ValueError("No teams found in the old config") @@ -136,12 +136,12 @@ def from_old(old_config: OldConfig) -> DarwinConfig: default_team = list(teams.keys())[0] return DarwinConfig( - api_key=teams[default_team]["api_key"], + api_key=teams[team_slug]["api_key"], api_endpoint=old_config.get("global/api_endpoint"), base_url=old_config.get("global/base_url"), default_team=default_team, teams=teams, - datasets_dir=teams[default_team]["datasets_dir"], + datasets_dir=teams[team_slug]["datasets_dir"], ) model_config = ConfigDict(validate_assignment=True) diff --git a/darwin/future/tests/core/test_client.py b/darwin/future/tests/core/test_client.py index ffd524231..258c2d653 100644 --- a/darwin/future/tests/core/test_client.py +++ b/darwin/future/tests/core/test_client.py @@ -149,8 +149,9 @@ def test_config_from_old_error( base_config: DarwinConfig, darwin_config_path: Path ) -> None: old_config = OldConfig(darwin_config_path) + team_slug = "test-team" with pytest.raises(ValueError) as excinfo: - base_config.from_old(old_config) + base_config.from_old(old_config, team_slug) (msg,) = excinfo.value.args assert msg == "No teams found in the old config" @@ -165,7 +166,7 @@ def test_config_from_old( old_config.put(["global", "base_url"], "http://localhost") old_config.put(["teams", team_slug, "api_key"], "mock_api_key") old_config.put(["teams", team_slug, "datasets_dir"], str(darwin_datasets_path)) - darwin_config = base_config.from_old(old_config) + darwin_config = base_config.from_old(old_config, team_slug) assert darwin_config.api_key == "mock_api_key" assert darwin_config.base_url == "http://localhost/" diff --git a/darwin/importer/importer.py b/darwin/importer/importer.py index d6966a8d7..984e89219 100644 --- a/darwin/importer/importer.py +++ b/darwin/importer/importer.py @@ -277,14 +277,14 @@ def _resolve_annotation_classes( return local_classes_not_in_dataset, local_classes_not_in_team -def _get_team_properties_annotation_lookup(client): +def _get_team_properties_annotation_lookup(client, team_slug): # get team properties -> List[FullProperty] - team_properties = client.get_team_properties() + team_properties = client.get_team_properties(team_slug) # (property-name, annotation_class_id): FullProperty object - team_properties_annotation_lookup: Dict[Tuple[str, Optional[int]], FullProperty] = ( - {} - ) + team_properties_annotation_lookup: Dict[ + Tuple[str, Optional[int]], FullProperty + ] = {} for prop in team_properties: team_properties_annotation_lookup[(prop.name, prop.annotation_class_id)] = prop @@ -322,6 +322,7 @@ def _import_properties( client: "Client", annotations: List[dt.Annotation], annotation_class_ids_map: Dict[Tuple[str, str], str], + team_slug: str, ) -> Dict[str, Dict[str, Dict[str, Set[str]]]]: """ Creates/Updates missing/mismatched properties from annotation & metadata.json file to team-properties. @@ -333,6 +334,7 @@ def _import_properties( client (Client): Darwin Client object annotations (List[dt.Annotation]): List of annotations annotation_class_ids_map (Dict[Tuple[str, str], str]): Dict of annotation class names/types to annotation class ids + team_slug (str): Team slug Raises: ValueError: raise error if annotation class not present in metadata and in team-properties @@ -352,7 +354,9 @@ def _import_properties( metadata_property_classes = parse_property_classes(metadata) # get team properties - team_properties_annotation_lookup = _get_team_properties_annotation_lookup(client) + team_properties_annotation_lookup = _get_team_properties_annotation_lookup( + client, team_slug + ) # (annotation-cls-name, annotation-cls-name): PropertyClass object metadata_classes_lookup: Set[Tuple[str, str]] = set() @@ -397,7 +401,10 @@ def _import_properties( if (annotation_name, a_prop.name) not in metadata_cls_prop_lookup: # check if they are present in team properties - if (a_prop.name, annotation_class_id) in team_properties_annotation_lookup: + if ( + a_prop.name, + annotation_class_id, + ) in team_properties_annotation_lookup: # get team property t_prop: FullProperty = team_properties_annotation_lookup[ (a_prop.name, annotation_class_id) @@ -411,7 +418,7 @@ def _import_properties( ] = set() continue - # get team property value + # get team property value t_prop_val = None for prop_val in t_prop.property_values or []: if prop_val.value == a_prop.value: @@ -597,7 +604,9 @@ def _import_properties( updated_properties.append(prop) # get latest team properties - team_properties_annotation_lookup = _get_team_properties_annotation_lookup(client) + team_properties_annotation_lookup = _get_team_properties_annotation_lookup( + client, team_slug + ) # loop over metadata_cls_id_prop_lookup, and update additional metadata property values for (annotation_class_id, prop_name), m_prop in metadata_cls_id_prop_lookup.items(): @@ -921,9 +930,9 @@ def import_annotations( # noqa: C901 # Need to re parse the files since we didn't save the annotations in memory for local_path in set(local_file.path for local_file in local_files): # noqa: C401 - imported_files: Union[List[dt.AnnotationFile], dt.AnnotationFile, None] = ( - importer(local_path) - ) + imported_files: Union[ + List[dt.AnnotationFile], dt.AnnotationFile, None + ] = importer(local_path) if imported_files is None: parsed_files = [] elif not isinstance(imported_files, List): @@ -1323,9 +1332,9 @@ def _import_annotations( # Insert the default slot name if not available in the import source annotation = _handle_slot_names(annotation, dataset.version, default_slot_name) - annotation_class_ids_map[(annotation_class.name, annotation_type)] = ( - annotation_class_id - ) + annotation_class_ids_map[ + (annotation_class.name, annotation_type) + ] = annotation_class_id serial_obj = { "annotation_class_id": annotation_class_id, "data": data, @@ -1345,6 +1354,7 @@ def _import_annotations( client, annotations, # type: ignore annotation_class_ids_map, + dataset.team, ) _update_payload_with_properties(serialized_annotations, annotation_id_property_map) diff --git a/tests/darwin/client_test.py b/tests/darwin/client_test.py index 1f4323155..1652c2f58 100644 --- a/tests/darwin/client_test.py +++ b/tests/darwin/client_test.py @@ -383,7 +383,8 @@ def test_get_team_properties(self, darwin_client: Client) -> None: }, status=200, ) - assert len(darwin_client.get_team_properties()) == 1 + team_slug = "v7-darwin-json-v2" + assert len(darwin_client.get_team_properties(team_slug)) == 1 @pytest.mark.usefixtures("file_read_write_test") diff --git a/tests/darwin/importer/importer_test.py b/tests/darwin/importer/importer_test.py index 8ef615e1d..15c87696f 100644 --- a/tests/darwin/importer/importer_test.py +++ b/tests/darwin/importer/importer_test.py @@ -501,6 +501,7 @@ def test__import_annotations() -> None: mock_dataset = Mock(RemoteDataset) mock_dataset.version = 2 + mock_dataset.team = "test_team" mock_hr.return_value = [ {"email": "reviewer1@example.com", "role": "reviewer"}, {"email": "reviewer2@example.com", "role": "reviewer"},