Skip to content

Commit

Permalink
starts hiding internal implementation of network
Browse files Browse the repository at this point in the history
instead representations are accessed via to/from methods on FEMap class

adds to/from networkx representation

adds FEMap eq and iter magic methods
  • Loading branch information
richardjgowers committed Nov 20, 2023
1 parent 0a88952 commit 49a683c
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 21 deletions.
83 changes: 67 additions & 16 deletions cinnabar/femap.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pathlib
from typing import Union

import copy
import openff.units
import pandas as pd
from openff.units import unit
Expand Down Expand Up @@ -95,13 +95,64 @@ class FEMap:
>>> fe.add_measurement(experimental_result2)
>>> fe.add_measurement(calculated_result)
"""
# internal representation:
# graph with measurements as edges
# absolute Measurements are an edge between 'ReferenceState' and the label
# all edges are directed, all edges can be multiply defined
graph: nx.MultiDiGraph
# all edges are directed
# all edges can be multiply defined
_graph: nx.MultiDiGraph

def __init__(self):
self.graph = nx.MultiDiGraph()
self._graph = nx.MultiDiGraph()

def __iter__(self):
for a, b, d in self._graph.edges(data=True):
# skip artificial reverse edges
if d['source'] == 'reverse':
continue

yield Measurement(labelA=a, labelB=b, **d)

def __eq__(self, other):
if not isinstance(other, self.__class__):
return NotImplemented

# iter returns hashable Measurements, so this will compare contents
return set(self) == set(other)

def to_networkx(self) -> nx.MultiDiGraph:
"""A *copy* of the FEMap as a networkx Graph
The FEMap is represented as a multi-edged directional graph
Edges have the following attributes:
- DG: the free energy difference of going from the first edge label to
the second edge label
- uncertainty: uncertainty of the DG value
- temperature: the temperature at which DG was measured
- computational: boolean label of the original source of the data
- source: a string describing the source of data.
Note
----
All edges appear twice, once with the attribute source='reverse',
and the DG value flipped. This allows "pathfinding" like approaches,
where the DG values will be correctly summed.
"""
return copy.deepcopy(self._graph)

@classmethod
def from_networkx(cls, graph: nx.MultiDiGraph):
"""Create FEMap from network representation
Note
----
Currently absolutely no validation of the input is done.
"""
m = cls()
m._graph = graph

return m

@classmethod
def from_csv(cls, filename, units: Optional[unit.Quantity] = None):
Expand Down Expand Up @@ -133,8 +184,8 @@ def add_measurement(self, measurement: Measurement):

# add both directions, but flip sign for the other direction
d_backwards = {**d, 'DG': - d['DG'], 'source': 'reverse'}
self.graph.add_edge(measurement.labelA, measurement.labelB, **d)
self.graph.add_edge(measurement.labelB, measurement.labelA, **d_backwards)
self._graph.add_edge(measurement.labelA, measurement.labelB, **d)
self._graph.add_edge(measurement.labelB, measurement.labelA, **d_backwards)

def add_experimental_measurement(self,
label: Union[str, Hashable],
Expand Down Expand Up @@ -262,7 +313,7 @@ def get_relative_dataframe(self) -> pd.DataFrame:
"""
kcpm = unit.kilocalorie_per_mole
data = []
for l1, l2, d in self.graph.edges(data=True):
for l1, l2, d in self._graph.edges(data=True):
if d['source'] == 'reverse':
continue
if isinstance(l1, ReferenceState) or isinstance(l2, ReferenceState):
Expand Down Expand Up @@ -297,7 +348,7 @@ def get_absolute_dataframe(self) -> pd.DataFrame:
"""
kcpm = unit.kilocalorie_per_mole
data = []
for l1, l2, d in self.graph.edges(data=True):
for l1, l2, d in self._graph.edges(data=True):
if d['source'] == 'reverse':
continue
if not isinstance(l1, ReferenceState):
Expand Down Expand Up @@ -325,7 +376,7 @@ def get_absolute_dataframe(self) -> pd.DataFrame:
@property
def n_measurements(self) -> int:
"""Total number of both experimental and computational measurements"""
return len(self.graph.edges) // 2
return len(self._graph.edges) // 2

@property
def n_ligands(self) -> int:
Expand All @@ -336,7 +387,7 @@ def n_ligands(self) -> int:
def ligands(self) -> list:
"""All ligands in the graph"""
# must ignore ReferenceState nodes
return [n for n in self.graph.nodes
return [n for n in self._graph.nodes
if not isinstance(n, ReferenceState)]

@property
Expand All @@ -347,14 +398,14 @@ def degree(self) -> float:
@property
def n_edges(self) -> int:
"""Number of computational edges"""
return sum(1 for _, _, d in self.graph.edges(data=True)
return sum(1 for _, _, d in self._graph.edges(data=True)
if d['computational']) // 2

def check_weakly_connected(self) -> bool:
"""Checks if all results in the graph are reachable from other results"""
# todo; cache
comp_graph = nx.MultiGraph()
for a, b, d in self.graph.edges(data=True):
for a, b, d in self._graph.edges(data=True):
if not d['computational']:
continue
comp_graph.add_edge(a, b)
Expand All @@ -365,7 +416,7 @@ def generate_absolute_values(self):
"""Populate the FEMap with absolute computational values based on MLE"""
# TODO: Make this return a new Graph with computational nodes annotated with DG values
# TODO this could work if either relative or absolute expt values are provided
mes = list(self.graph.edges(data=True))
mes = list(self._graph.edges(data=True))
# for now, we must all be in the same units for this to work
# grab unit of first measurement
u = mes[0][-1]['DG'].u
Expand Down Expand Up @@ -394,7 +445,7 @@ def generate_absolute_values(self):

# find all computational result labels
comp_ligands = set()
for A, B, d in self.graph.edges(data=True):
for A, B, d in self._graph.edges(data=True):
if not d['computational']:
continue
comp_ligands.add(A)
Expand Down Expand Up @@ -433,7 +484,7 @@ def to_legacy_graph(self) -> nx.DiGraph:
# reduces to nx.DiGraph
g = nx.DiGraph()
# add DDG values from computational graph
for a, b, d in self.graph.edges(data=True):
for a, b, d in self._graph.edges(data=True):
if not d['computational']:
continue
if isinstance(a, ReferenceState): # skip absolute measurements
Expand All @@ -444,7 +495,7 @@ def to_legacy_graph(self) -> nx.DiGraph:
g.add_edge(a, b, calc_DDG=d['DG'].magnitude, calc_dDDG=d['uncertainty'].magnitude)
# add DG values from experiment graph
for node, d in g.nodes(data=True):
expt = self.graph.get_edge_data(ReferenceState(), node)
expt = self._graph.get_edge_data(ReferenceState(), node)
if expt is None:
continue
expt = expt[0]
Expand Down
24 changes: 19 additions & 5 deletions cinnabar/tests/test_femap.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,21 @@ def example_map(example_csv):
def test_from_csv(example_map):
assert example_map.n_ligands == 36
assert example_map.n_edges == 58
assert len(example_map.graph.edges) == (58 + 36) * 2
assert len(example_map._graph.edges) == (58 + 36) * 2


def test_eq(example_csv):
m1 = cinnabar.FEMap.from_csv(example_csv)
m2 = cinnabar.FEMap.from_csv(example_csv)
m3 = cinnabar.FEMap.from_csv(example_csv)
m3.add_experimental_measurement(
label='this',
value=4.2 * unit.kilocalorie_per_mole,
uncertainty=0.1 * unit.kilocalorie_per_mole,
)

assert m1 == m2
assert m1 != m3


def test_degree(example_map):
Expand Down Expand Up @@ -77,7 +91,7 @@ def test_femap_add_experimental(ki):
)

assert set(m.ligands) == {'ligA'}
d = m.graph.get_edge_data(cinnabar.ReferenceState(), 'ligA')
d = m._graph.get_edge_data(cinnabar.ReferenceState(), 'ligA')
assert d.keys() == {0}
d = d[0]
assert d['computational'] is False
Expand Down Expand Up @@ -118,7 +132,7 @@ def test_add_ABFE(default_T):
source='ebay', temperature=T)

assert set(m.ligands) == {'foo'}
d = m.graph.get_edge_data(cinnabar.ReferenceState(), 'foo')
d = m._graph.get_edge_data(cinnabar.ReferenceState(), 'foo')
assert len(d) == 1
d = d[0]
assert d['DG'] == v
Expand All @@ -143,7 +157,7 @@ def test_add_RBFE(default_T):
source='ebay', temperature=T)

assert set(m.ligands) == {'foo', 'bar'}
d = m.graph.get_edge_data('foo', 'bar')
d = m._graph.get_edge_data('foo', 'bar')
assert len(d) == 1
d = d[0]
assert d['DG'] == v
Expand All @@ -166,7 +180,7 @@ def test_generate_absolute_values(example_map, ref_mle_results):
example_map.generate_absolute_values()

for e, (y_ref, yerr_ref) in ref_mle_results.items():
data = example_map.graph.get_edge_data(cinnabar.ReferenceState(label='MLE'), e)
data = example_map._graph.get_edge_data(cinnabar.ReferenceState(label='MLE'), e)
# grab the dict containing MLE data
for _, d in data.items():
if d['source'] == 'MLE':
Expand Down

0 comments on commit 49a683c

Please sign in to comment.