Skip to content

Commit

Permalink
adds Add operator to FEMap
Browse files Browse the repository at this point in the history
  • Loading branch information
richardjgowers committed Dec 7, 2023
1 parent c3662d9 commit de47c51
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 0 deletions.
18 changes: 18 additions & 0 deletions cinnabar/femap.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ def read_csv(filepath: pathlib.Path, units: Optional[openff.units.Quantity] = No
class FEMap:
"""Free Energy map of both simulations and bench measurements
Contains a set (non-duplicate entries) of different measurements.
Examples
--------
To read from a csv file specifically formatted for this, you can use:
Expand Down Expand Up @@ -120,6 +122,22 @@ def __eq__(self, other):
# iter returns hashable Measurements, so this will compare contents
return set(self) == set(other)

def __add__(self, other):
if not isinstance(other, self.__class__):
return NotImplemented
# deduplicate based on hashing the Measurements from iter
my_items = set(self)
other_items = set(other)

new = self.__class__()
for m in my_items | other_items:
new.add_measurement(m)

return new

def __len__(self):
return len(list(iter(self)))

def to_networkx(self) -> nx.MultiDiGraph:
"""A *copy* of the FEMap as a networkx Graph
Expand Down
68 changes: 68 additions & 0 deletions cinnabar/tests/test_femap.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,3 +225,71 @@ def test_from_networkx(example_map):
m2 = cinnabar.FEMap.from_networkx(g)

assert example_map == m2


def test_add():
m1 = cinnabar.FEMap()
m1.add_experimental_measurement(
label='c1',
value=10.1 * unit.nanomolar,
uncertainty=0.2 * unit.nanomolar,
)
m1.add_experimental_measurement(
label='c2',
value=10.2 * unit.nanomolar,
uncertainty=0.3 * unit.nanomolar,
)

m2 = cinnabar.FEMap()
m2.add_absolute_calculation(
label='c1',
value=-9.5 * unit.kilocalorie_per_mole,
uncertainty=0.4 * unit.kilocalorie_per_mole,
)

m3 = m1 + m2

assert len(m3) == 3
measurements = set(m3)

ref1 = set(m1)
ref2 = set(m2)

assert measurements == ref1 | ref2


def test_add_duplicate():
# adding, but the two maps have a duplicate measurement
m1 = cinnabar.FEMap()
m1.add_experimental_measurement(
label='c1',
value=10.1 * unit.nanomolar,
uncertainty=0.2 * unit.nanomolar,
)
m1.add_experimental_measurement(
label='c2',
value=10.2 * unit.nanomolar,
uncertainty=0.3 * unit.nanomolar,
)

m2 = cinnabar.FEMap()
m2.add_experimental_measurement(
label='c1',
value=10.1 * unit.nanomolar,
uncertainty=0.2 * unit.nanomolar,
)
m2.add_absolute_calculation(
label='c1',
value=-9.5 * unit.kilocalorie_per_mole,
uncertainty=0.4 * unit.kilocalorie_per_mole,
)

m3 = m1 + m2

assert len(m3) == 3
measurements = set(m3)

ref1 = set(m1)
ref2 = set(m2)

assert measurements == ref1 | ref2

0 comments on commit de47c51

Please sign in to comment.