diff --git a/plasnet/hub_graph.py b/plasnet/hub_graph.py index ee6ef52..b9ded51 100644 --- a/plasnet/hub_graph.py +++ b/plasnet/hub_graph.py @@ -55,14 +55,17 @@ def _get_hub_plasmids(self, use_cached: bool = False) -> list[str]: return self._hub_plasmids - def remove_hub_plasmids(self) -> None: + def remove_hub_plasmids(self) -> set[str]: + all_hub_plasmids = set() while True: hub_plasmids = self._get_hub_plasmids() + all_hub_plasmids.update(hub_plasmids) there_are_still_hub_plasmids = len(hub_plasmids) > 0 if there_are_still_hub_plasmids: self.remove_nodes_from(hub_plasmids) else: break + return all_hub_plasmids def _get_filters_HTML(self) -> str: nb_of_hubs = len(self._get_hub_plasmids(use_cached=True)) diff --git a/plasnet/plasnet_main.py b/plasnet/plasnet_main.py index fa20209..2162d15 100644 --- a/plasnet/plasnet_main.py +++ b/plasnet/plasnet_main.py @@ -205,8 +205,10 @@ def type( logging.info("Typing communities (i.e. splitting them into subcommunities)") all_subcommunities = Subcommunities() + all_hub_plasmids = set() for community in communities: - community.remove_hub_plasmids() + hub_plasmids = community.remove_hub_plasmids() + all_hub_plasmids.update(hub_plasmids) subcommunities = community.split_graph_into_subcommunities( small_subcommunity_size_threshold ) @@ -229,6 +231,10 @@ def type( original_communities.save(objects_dir / "communities.pkl") all_subcommunities.save(objects_dir / "subcommunities.pkl") all_subcommunities.save_classification(objects_dir / "typing.tsv", "plasmid\ttype") + with open(objects_dir / "hub_plasmids.csv", "w") as hub_plasmids_fh: + print("hub_plasmids", file=hub_plasmids_fh) + for plasmid in all_hub_plasmids: + print(plasmid, file=hub_plasmids_fh) logging.info("All done!") diff --git a/tests/data/hub/hub_plasmids.csv b/tests/data/hub/hub_plasmids.csv new file mode 100644 index 0000000..ed9a15e --- /dev/null +++ b/tests/data/hub/hub_plasmids.csv @@ -0,0 +1,4 @@ +hub_plasmids +pKPC_CAV1320 +pCAV1492-6393 +pCAV1374-6538 diff --git a/tests/test_integration_tests.py b/tests/test_integration_tests.py index cf9505e..ba3cb09 100644 --- a/tests/test_integration_tests.py +++ b/tests/test_integration_tests.py @@ -113,3 +113,11 @@ def test_remove_hub_plasmids_iteratively(self) -> None: sort=True, ) ) + + self.assertTrue( + check_if_files_are_equal( + Path("tests/data/hub/out/type_out/objects/hub_plasmids.csv"), + Path("tests/data/hub/hub_plasmids.csv"), + sort=True, + ) + )