Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

383 get connected subgraphs in an alchemicalnetwork #409

Merged
merged 9 commits into from
Nov 19, 2024
23 changes: 20 additions & 3 deletions gufe/network.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/gufe

from typing import Iterable, Optional
from typing import Generator, Iterable, Optional
from typing_extensions import Self # Self is included in typing as of python 3.11

import networkx as nx
from .tokenization import GufeTokenizable
Expand All @@ -10,6 +11,7 @@
from .transformations import Transformation



atravitz marked this conversation as resolved.
Show resolved Hide resolved
class AlchemicalNetwork(GufeTokenizable):
_edges: frozenset[Transformation]
_nodes: frozenset[ChemicalSystem]
Expand Down Expand Up @@ -102,7 +104,7 @@ def _to_dict(self) -> dict:
"name": self.name}

@classmethod
def _from_dict(cls, d: dict):
def _from_dict(cls, d: dict) -> Self:
atravitz marked this conversation as resolved.
Show resolved Hide resolved
return cls(nodes=frozenset(d['nodes']),
edges=frozenset(d['edges']),
name=d.get('name'))
Expand All @@ -116,6 +118,21 @@ def to_graphml(self) -> str:
raise NotImplementedError

@classmethod
def from_graphml(cls, str):
def from_graphml(cls, str) -> Self:
"""Currently not implemented"""
raise NotImplementedError

@classmethod
def _from_nx_graph(cls, nx_graph) -> Self:
"""Create an alchemical network from a networkx representation."""
chemical_systems = [n for n in nx_graph.nodes()]
transformations = [e[2]['object'] for e in nx_graph.edges(data=True)]
return cls(nodes=chemical_systems, edges=transformations)

def connected_subgraphs(self) -> Generator[Self, None, None]:
"""Return a generator of all connected subgraphs of the alchemical network."""
node_groups = nx.weakly_connected_components(self.graph)
for node_group in node_groups:
nx_subgraph = self.graph.subgraph(node_group)
alc_subgraph = self._from_nx_graph(nx_subgraph)
yield(alc_subgraph)
19 changes: 13 additions & 6 deletions gufe/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,9 +256,8 @@ def complex_equilibrium(solvated_complex):
protocol=DummyProtocol(settings=DummyProtocol.default_settings())
)


@pytest.fixture
def benzene_variants_star_map(
def benzene_variants_star_map_transformations(
benzene,
toluene,
phenol,
Expand Down Expand Up @@ -320,7 +319,15 @@ def benzene_variants_star_map(
mapping=None,
)

return gufe.AlchemicalNetwork(
list(solvated_ligand_transformations.values())
+ list(solvated_complex_transformations.values())
)
return list(solvated_ligand_transformations.values()), list(solvated_complex_transformations.values())


@pytest.fixture
def benzene_variants_star_map(benzene_variants_star_map_transformations):
solvated_ligand_transformations, solvated_complex_transformations = benzene_variants_star_map_transformations
return gufe.AlchemicalNetwork(solvated_ligand_transformations+solvated_complex_transformations)

@pytest.fixture
def benzene_variants_ligand_star_map(benzene_variants_star_map_transformations):
solvated_ligand_transformations, _ = benzene_variants_star_map_transformations
return gufe.AlchemicalNetwork(solvated_ligand_transformations)
25 changes: 25 additions & 0 deletions gufe/tests/test_alchemicalnetwork.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,28 @@ def test_connectivity(self, benzene_variants_star_map):
else:
edges = alnet.graph.edges(node)
assert len(edges) == 0

def test_connected_subgraphs_multiple_subgraphs(self, benzene_variants_star_map):
"""Identify two separate networks and one floating nodes as subgraphs."""
# remove an edge to create a network w/ two subnetworks and one floating node
edge_list = [e for e in benzene_variants_star_map.edges]
alnet = benzene_variants_star_map.copy_with_replacements(edges=edge_list[:-1])

subgraphs = [subgraph for subgraph in alnet.connected_subgraphs()]

assert set([len(subgraph.nodes) for subgraph in subgraphs]) == {6,7,1}

# which graph has the removed node is not deterministic, so we just
# check that one graph is all-solvent and the other is all-protein
for subgraph in subgraphs:
components = [frozenset(n.components.keys()) for n in subgraph.nodes]
if {'solvent','protein','ligand'} in components:
assert set(components) == {frozenset({'solvent','protein','ligand'})}
else:
assert set(components) == {frozenset({'solvent','ligand'})}

def test_connected_subgraphs_one_subgraph(self, benzene_variants_ligand_star_map):
"""Return the same network if it only contains one connected component."""
alnet = benzene_variants_ligand_star_map
subgraphs = [subgraph for subgraph in alnet.connected_subgraphs()]
assert subgraphs == [alnet]
Loading