Skip to content

Commit

Permalink
Merge pull request #173 from NPLinker/add_magic_method_add
Browse files Browse the repository at this point in the history
Add magic method `__add__`
  • Loading branch information
CunliangGeng authored Oct 31, 2023
2 parents 8b8723e + 62291dd commit 6c994b1
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 3 deletions.
22 changes: 19 additions & 3 deletions src/nplinker/strain_collection.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from __future__ import annotations
import json
from os import PathLike
from typing import Iterator
Expand Down Expand Up @@ -37,6 +38,16 @@ def __eq__(self, other) -> bool:
and self._strain_dict_name == other._strain_dict_name)
return NotImplemented

def __add__(self, other) -> StrainCollection:
if isinstance(other, StrainCollection):
sc = StrainCollection()
for strain in self._strains:
sc.add(strain)
for strain in other._strains:
sc.add(strain)
return sc
return NotImplemented

def __contains__(self, item: Strain) -> bool:
"""Check if the strain collection contains the given Strain object.
"""
Expand All @@ -58,7 +69,10 @@ def add(self, strain: Strain) -> None:
if strain in self._strains:
# only one strain object per id
strain_ref = self._strain_dict_name[strain.id][0]
new_aliases = [alias for alias in strain.aliases if alias not in strain_ref.aliases]
new_aliases = [
alias for alias in strain.aliases
if alias not in strain_ref.aliases
]
for alias in new_aliases:
strain_ref.add_alias(alias)
if alias not in self._strain_dict_name:
Expand Down Expand Up @@ -92,14 +106,16 @@ def remove(self, strain: Strain):
for name in strain_ref.names:
if name in self._strain_dict_name:
new_strain_list = [
s for s in self._strain_dict_name[name] if s.id != strain.id
s for s in self._strain_dict_name[name]
if s.id != strain.id
]
if not new_strain_list:
del self._strain_dict_name[name]
else:
self._strain_dict_name[name] = new_strain_list
else:
raise ValueError(f"Strain {strain} not found in strain collection.")
raise ValueError(
f"Strain {strain} not found in strain collection.")

def filter(self, strain_set: set[Strain]):
"""
Expand Down
28 changes: 28 additions & 0 deletions tests/test_strain_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,34 @@ def test_eq(collection: StrainCollection, strain: Strain):
assert collection == other


def test_magic_add(collection: StrainCollection, strain: Strain):
other = StrainCollection()
# same id, same alias
other.add(strain)
# same id, different alias
strain1 = Strain("strain_1")
strain1.add_alias("strain_1_b")
other.add(strain1)
# different id, same alias
strain2 = Strain("strain_2")
strain2.add_alias("strain_2_a")
other.add(strain2)

assert collection + other == other + collection

actual = collection + other
assert len(actual) == 2
assert strain in actual
assert strain1 in actual
assert strain2 in actual
assert len(actual._strain_dict_name) == 5
assert actual._strain_dict_name["strain_1"] == [strain]
assert actual._strain_dict_name["strain_1_a"] == [strain]
assert actual._strain_dict_name["strain_1_b"] == [strain]
assert actual._strain_dict_name["strain_2"] == [strain2]
assert actual._strain_dict_name["strain_2_a"] == [strain2]


def test_contains(collection: StrainCollection, strain: Strain):
assert strain in collection
strain2 = Strain("strain_2")
Expand Down

0 comments on commit 6c994b1

Please sign in to comment.