Skip to content

Commit

Permalink
Fixed bug when importing annotations with properties to non-default t…
Browse files Browse the repository at this point in the history
…eam (#813)
  • Loading branch information
JBWilkie authored Apr 12, 2024
1 parent 3c4263c commit e0a8ccb
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 24 deletions.
6 changes: 3 additions & 3 deletions darwin/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1023,7 +1023,7 @@ def api_v2(self) -> BackendV2:
def get_team_properties(
self, team_slug: Optional[str] = None, include_property_values: bool = True
) -> List[FullProperty]:
darwin_config = DarwinConfig.from_old(self.config)
darwin_config = DarwinConfig.from_old(self.config, team_slug)
future_client = ClientCore(darwin_config)

if not include_property_values:
Expand All @@ -1040,7 +1040,7 @@ def get_team_properties(
def create_property(
self, team_slug: Optional[str], params: Union[FullProperty, JSONDict]
) -> FullProperty:
darwin_config = DarwinConfig.from_old(self.config)
darwin_config = DarwinConfig.from_old(self.config, team_slug)
future_client = ClientCore(darwin_config)

return create_property_future(
Expand All @@ -1052,7 +1052,7 @@ def create_property(
def update_property(
self, team_slug: Optional[str], params: Union[FullProperty, JSONDict]
) -> FullProperty:
darwin_config = DarwinConfig.from_old(self.config)
darwin_config = DarwinConfig.from_old(self.config, team_slug)
future_client = ClientCore(darwin_config)

return update_property_future(
Expand Down
6 changes: 3 additions & 3 deletions darwin/future/core/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def from_api_key_with_defaults(api_key: str) -> DarwinConfig:
)

@staticmethod
def from_old(old_config: OldConfig) -> DarwinConfig:
def from_old(old_config: OldConfig, team_slug: str) -> DarwinConfig:
teams = old_config.get("teams")
if not teams:
raise ValueError("No teams found in the old config")
Expand All @@ -136,12 +136,12 @@ def from_old(old_config: OldConfig) -> DarwinConfig:
default_team = list(teams.keys())[0]

return DarwinConfig(
api_key=teams[default_team]["api_key"],
api_key=teams[team_slug]["api_key"],
api_endpoint=old_config.get("global/api_endpoint"),
base_url=old_config.get("global/base_url"),
default_team=default_team,
teams=teams,
datasets_dir=teams[default_team]["datasets_dir"],
datasets_dir=teams[team_slug]["datasets_dir"],
)

model_config = ConfigDict(validate_assignment=True)
Expand Down
5 changes: 3 additions & 2 deletions darwin/future/tests/core/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,9 @@ def test_config_from_old_error(
base_config: DarwinConfig, darwin_config_path: Path
) -> None:
old_config = OldConfig(darwin_config_path)
team_slug = "test-team"
with pytest.raises(ValueError) as excinfo:
base_config.from_old(old_config)
base_config.from_old(old_config, team_slug)
(msg,) = excinfo.value.args
assert msg == "No teams found in the old config"

Expand All @@ -165,7 +166,7 @@ def test_config_from_old(
old_config.put(["global", "base_url"], "http://localhost")
old_config.put(["teams", team_slug, "api_key"], "mock_api_key")
old_config.put(["teams", team_slug, "datasets_dir"], str(darwin_datasets_path))
darwin_config = base_config.from_old(old_config)
darwin_config = base_config.from_old(old_config, team_slug)

assert darwin_config.api_key == "mock_api_key"
assert darwin_config.base_url == "http://localhost/"
Expand Down
40 changes: 25 additions & 15 deletions darwin/importer/importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,14 +277,14 @@ def _resolve_annotation_classes(
return local_classes_not_in_dataset, local_classes_not_in_team


def _get_team_properties_annotation_lookup(client):
def _get_team_properties_annotation_lookup(client, team_slug):
# get team properties -> List[FullProperty]
team_properties = client.get_team_properties()
team_properties = client.get_team_properties(team_slug)

# (property-name, annotation_class_id): FullProperty object
team_properties_annotation_lookup: Dict[Tuple[str, Optional[int]], FullProperty] = (
{}
)
team_properties_annotation_lookup: Dict[
Tuple[str, Optional[int]], FullProperty
] = {}
for prop in team_properties:
team_properties_annotation_lookup[(prop.name, prop.annotation_class_id)] = prop

Expand Down Expand Up @@ -322,6 +322,7 @@ def _import_properties(
client: "Client",
annotations: List[dt.Annotation],
annotation_class_ids_map: Dict[Tuple[str, str], str],
team_slug: str,
) -> Dict[str, Dict[str, Dict[str, Set[str]]]]:
"""
Creates/Updates missing/mismatched properties from annotation & metadata.json file to team-properties.
Expand All @@ -333,6 +334,7 @@ def _import_properties(
client (Client): Darwin Client object
annotations (List[dt.Annotation]): List of annotations
annotation_class_ids_map (Dict[Tuple[str, str], str]): Dict of annotation class names/types to annotation class ids
team_slug (str): Team slug
Raises:
ValueError: raise error if annotation class not present in metadata and in team-properties
Expand All @@ -352,7 +354,9 @@ def _import_properties(
metadata_property_classes = parse_property_classes(metadata)

# get team properties
team_properties_annotation_lookup = _get_team_properties_annotation_lookup(client)
team_properties_annotation_lookup = _get_team_properties_annotation_lookup(
client, team_slug
)

# (annotation-cls-name, annotation-cls-name): PropertyClass object
metadata_classes_lookup: Set[Tuple[str, str]] = set()
Expand Down Expand Up @@ -397,7 +401,10 @@ def _import_properties(
if (annotation_name, a_prop.name) not in metadata_cls_prop_lookup:

# check if they are present in team properties
if (a_prop.name, annotation_class_id) in team_properties_annotation_lookup:
if (
a_prop.name,
annotation_class_id,
) in team_properties_annotation_lookup:
# get team property
t_prop: FullProperty = team_properties_annotation_lookup[
(a_prop.name, annotation_class_id)
Expand All @@ -411,7 +418,7 @@ def _import_properties(
] = set()
continue

# get team property value
# get team property value
t_prop_val = None
for prop_val in t_prop.property_values or []:
if prop_val.value == a_prop.value:
Expand Down Expand Up @@ -597,7 +604,9 @@ def _import_properties(
updated_properties.append(prop)

# get latest team properties
team_properties_annotation_lookup = _get_team_properties_annotation_lookup(client)
team_properties_annotation_lookup = _get_team_properties_annotation_lookup(
client, team_slug
)

# loop over metadata_cls_id_prop_lookup, and update additional metadata property values
for (annotation_class_id, prop_name), m_prop in metadata_cls_id_prop_lookup.items():
Expand Down Expand Up @@ -921,9 +930,9 @@ def import_annotations( # noqa: C901

# Need to re parse the files since we didn't save the annotations in memory
for local_path in set(local_file.path for local_file in local_files): # noqa: C401
imported_files: Union[List[dt.AnnotationFile], dt.AnnotationFile, None] = (
importer(local_path)
)
imported_files: Union[
List[dt.AnnotationFile], dt.AnnotationFile, None
] = importer(local_path)
if imported_files is None:
parsed_files = []
elif not isinstance(imported_files, List):
Expand Down Expand Up @@ -1323,9 +1332,9 @@ def _import_annotations(
# Insert the default slot name if not available in the import source
annotation = _handle_slot_names(annotation, dataset.version, default_slot_name)

annotation_class_ids_map[(annotation_class.name, annotation_type)] = (
annotation_class_id
)
annotation_class_ids_map[
(annotation_class.name, annotation_type)
] = annotation_class_id
serial_obj = {
"annotation_class_id": annotation_class_id,
"data": data,
Expand All @@ -1345,6 +1354,7 @@ def _import_annotations(
client,
annotations, # type: ignore
annotation_class_ids_map,
dataset.team,
)
_update_payload_with_properties(serialized_annotations, annotation_id_property_map)

Expand Down
3 changes: 2 additions & 1 deletion tests/darwin/client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,8 @@ def test_get_team_properties(self, darwin_client: Client) -> None:
},
status=200,
)
assert len(darwin_client.get_team_properties()) == 1
team_slug = "v7-darwin-json-v2"
assert len(darwin_client.get_team_properties(team_slug)) == 1


@pytest.mark.usefixtures("file_read_write_test")
Expand Down
1 change: 1 addition & 0 deletions tests/darwin/importer/importer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,7 @@ def test__import_annotations() -> None:
mock_dataset = Mock(RemoteDataset)

mock_dataset.version = 2
mock_dataset.team = "test_team"
mock_hr.return_value = [
{"email": "[email protected]", "role": "reviewer"},
{"email": "[email protected]", "role": "reviewer"},
Expand Down

0 comments on commit e0a8ccb

Please sign in to comment.