diff --git a/src/nplinker/strain_collection.py b/src/nplinker/strain_collection.py index f2f83c3c..86993f6d 100644 --- a/src/nplinker/strain_collection.py +++ b/src/nplinker/strain_collection.py @@ -1,3 +1,4 @@ +from __future__ import annotations import json from os import PathLike from typing import Iterator @@ -37,6 +38,16 @@ def __eq__(self, other) -> bool: and self._strain_dict_name == other._strain_dict_name) return NotImplemented + def __add__(self, other) -> StrainCollection: + if isinstance(other, StrainCollection): + sc = StrainCollection() + for strain in self._strains: + sc.add(strain) + for strain in other._strains: + sc.add(strain) + return sc + return NotImplemented + def __contains__(self, item: Strain) -> bool: """Check if the strain collection contains the given Strain object. """ @@ -58,7 +69,10 @@ def add(self, strain: Strain) -> None: if strain in self._strains: # only one strain object per id strain_ref = self._strain_dict_name[strain.id][0] - new_aliases = [alias for alias in strain.aliases if alias not in strain_ref.aliases] + new_aliases = [ + alias for alias in strain.aliases + if alias not in strain_ref.aliases + ] for alias in new_aliases: strain_ref.add_alias(alias) if alias not in self._strain_dict_name: @@ -92,14 +106,16 @@ def remove(self, strain: Strain): for name in strain_ref.names: if name in self._strain_dict_name: new_strain_list = [ - s for s in self._strain_dict_name[name] if s.id != strain.id + s for s in self._strain_dict_name[name] + if s.id != strain.id ] if not new_strain_list: del self._strain_dict_name[name] else: self._strain_dict_name[name] = new_strain_list else: - raise ValueError(f"Strain {strain} not found in strain collection.") + raise ValueError( + f"Strain {strain} not found in strain collection.") def filter(self, strain_set: set[Strain]): """ diff --git a/tests/test_strain_collection.py b/tests/test_strain_collection.py index 01d5134d..0a4ec786 100644 --- a/tests/test_strain_collection.py +++ b/tests/test_strain_collection.py @@ -29,6 +29,34 @@ def test_eq(collection: StrainCollection, strain: Strain): assert collection == other +def test_magic_add(collection: StrainCollection, strain: Strain): + other = StrainCollection() + # same id, same alias + other.add(strain) + # same id, different alias + strain1 = Strain("strain_1") + strain1.add_alias("strain_1_b") + other.add(strain1) + # different id, same alias + strain2 = Strain("strain_2") + strain2.add_alias("strain_2_a") + other.add(strain2) + + assert collection + other == other + collection + + actual = collection + other + assert len(actual) == 2 + assert strain in actual + assert strain1 in actual + assert strain2 in actual + assert len(actual._strain_dict_name) == 5 + assert actual._strain_dict_name["strain_1"] == [strain] + assert actual._strain_dict_name["strain_1_a"] == [strain] + assert actual._strain_dict_name["strain_1_b"] == [strain] + assert actual._strain_dict_name["strain_2"] == [strain2] + assert actual._strain_dict_name["strain_2_a"] == [strain2] + + def test_contains(collection: StrainCollection, strain: Strain): assert strain in collection strain2 = Strain("strain_2")