From 0815f4865f1c8ae81027056db667ada0aee9c2e8 Mon Sep 17 00:00:00 2001 From: Cunliang Geng Date: Mon, 30 Oct 2023 13:37:20 +0100 Subject: [PATCH 1/3] format `strain_collection.py` --- src/nplinker/strain_collection.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/nplinker/strain_collection.py b/src/nplinker/strain_collection.py index f2f83c3c..a210d431 100644 --- a/src/nplinker/strain_collection.py +++ b/src/nplinker/strain_collection.py @@ -58,7 +58,10 @@ def add(self, strain: Strain) -> None: if strain in self._strains: # only one strain object per id strain_ref = self._strain_dict_name[strain.id][0] - new_aliases = [alias for alias in strain.aliases if alias not in strain_ref.aliases] + new_aliases = [ + alias for alias in strain.aliases + if alias not in strain_ref.aliases + ] for alias in new_aliases: strain_ref.add_alias(alias) if alias not in self._strain_dict_name: @@ -92,14 +95,16 @@ def remove(self, strain: Strain): for name in strain_ref.names: if name in self._strain_dict_name: new_strain_list = [ - s for s in self._strain_dict_name[name] if s.id != strain.id + s for s in self._strain_dict_name[name] + if s.id != strain.id ] if not new_strain_list: del self._strain_dict_name[name] else: self._strain_dict_name[name] = new_strain_list else: - raise ValueError(f"Strain {strain} not found in strain collection.") + raise ValueError( + f"Strain {strain} not found in strain collection.") def filter(self, strain_set: set[Strain]): """ From e53ba6e0268cd44883b60c2ba8649d8867925f8f Mon Sep 17 00:00:00 2001 From: Cunliang Geng Date: Mon, 30 Oct 2023 13:37:01 +0100 Subject: [PATCH 2/3] add `__add__` method to `StrainCollection` class - add the `__add__` magic method - add unit tests for this magic method --- src/nplinker/strain_collection.py | 10 ++++++++++ tests/test_strain_collection.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+) diff --git a/src/nplinker/strain_collection.py b/src/nplinker/strain_collection.py index a210d431..1f430d4d 100644 --- a/src/nplinker/strain_collection.py +++ b/src/nplinker/strain_collection.py @@ -37,6 +37,16 @@ def __eq__(self, other) -> bool: and self._strain_dict_name == other._strain_dict_name) return NotImplemented + def __add__(self, other) -> 'StrainCollection': + if isinstance(other, StrainCollection): + sc = StrainCollection() + for strain in self._strains: + sc.add(strain) + for strain in other._strains: + sc.add(strain) + return sc + return NotImplemented + def __contains__(self, item: Strain) -> bool: """Check if the strain collection contains the given Strain object. """ diff --git a/tests/test_strain_collection.py b/tests/test_strain_collection.py index 01d5134d..0a4ec786 100644 --- a/tests/test_strain_collection.py +++ b/tests/test_strain_collection.py @@ -29,6 +29,34 @@ def test_eq(collection: StrainCollection, strain: Strain): assert collection == other +def test_magic_add(collection: StrainCollection, strain: Strain): + other = StrainCollection() + # same id, same alias + other.add(strain) + # same id, different alias + strain1 = Strain("strain_1") + strain1.add_alias("strain_1_b") + other.add(strain1) + # different id, same alias + strain2 = Strain("strain_2") + strain2.add_alias("strain_2_a") + other.add(strain2) + + assert collection + other == other + collection + + actual = collection + other + assert len(actual) == 2 + assert strain in actual + assert strain1 in actual + assert strain2 in actual + assert len(actual._strain_dict_name) == 5 + assert actual._strain_dict_name["strain_1"] == [strain] + assert actual._strain_dict_name["strain_1_a"] == [strain] + assert actual._strain_dict_name["strain_1_b"] == [strain] + assert actual._strain_dict_name["strain_2"] == [strain2] + assert actual._strain_dict_name["strain_2_a"] == [strain2] + + def test_contains(collection: StrainCollection, strain: Strain): assert strain in collection strain2 = Strain("strain_2") From 62291ddf82acd13b1872f6894df27c1c2d7f8e0c Mon Sep 17 00:00:00 2001 From: Cunliang Geng Date: Tue, 31 Oct 2023 15:38:44 +0100 Subject: [PATCH 3/3] update type hints for class itself When using `from __future__ import annotations`, the class name can be directly used as type hint; Otherwise, use string like "ClassName" as type hint. --- src/nplinker/strain_collection.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/nplinker/strain_collection.py b/src/nplinker/strain_collection.py index 1f430d4d..86993f6d 100644 --- a/src/nplinker/strain_collection.py +++ b/src/nplinker/strain_collection.py @@ -1,3 +1,4 @@ +from __future__ import annotations import json from os import PathLike from typing import Iterator @@ -37,7 +38,7 @@ def __eq__(self, other) -> bool: and self._strain_dict_name == other._strain_dict_name) return NotImplemented - def __add__(self, other) -> 'StrainCollection': + def __add__(self, other) -> StrainCollection: if isinstance(other, StrainCollection): sc = StrainCollection() for strain in self._strains: