Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Connectivity Strategy #1

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .coveragerc
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
[report]
exclude_lines =
@abstractmethod
@abc.abstractmethod
2 changes: 1 addition & 1 deletion devtools/conda-envs/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@ dependencies:
- pytest
- pytest-xdist
- pytest-cov
- coverage
- coverage
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ Issues = "https://github.com/OpenFreeEnergy/stratocaster/issues"
[project.optional-dependencies]
test = [
"pytest",
"pytest-cov",
]
dev = [
"stratocaster[test]",
Expand Down
5 changes: 2 additions & 3 deletions src/stratocaster/base/models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from gufe.settings.models import SettingsBaseModel


# TODO: docstrings
class StrategySettings(SettingsBaseModel):

def __init__(self):
normalize_weights: bool = True
pass
49 changes: 42 additions & 7 deletions src/stratocaster/base/strategy.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
import abc
from typing import Self
from typing import Self, TypeVar

from gufe.tokenization import GufeTokenizable
from gufe import AlchemicalNetwork
from gufe.protocols import ProtocolResult
from gufe.tokenization import GufeTokenizable, GufeKey
from gufe import AlchemicalNetwork, ProtocolResult

from .models import StrategySettings

_ProtocolResult = TypeVar("_ProtocolResult", bound=ProtocolResult)


# TODO: docstrings
class StrategyResult(GufeTokenizable):

def __init__(self, weights):
def __init__(self, weights: dict[GufeKey, float | None]):
self._weights = weights

@classmethod
Expand All @@ -24,13 +26,46 @@ def _to_dict(self) -> dict:
def _from_dict(cls, dct: dict) -> Self:
return cls(**dct)

@property
def weights(self) -> dict[GufeKey, float | None]:
return self._weights

def resolve(self) -> dict[GufeKey, float | None]:
weights = self.weights
weight_sum = sum([weight for weight in weights.values() if weight is not None])
modified_weights = {
key: weight / weight_sum
for key, weight in weights.items()
if weight is not None
}
weights.update(modified_weights)
return weights


# TODO: docstrings
class Strategy(GufeTokenizable):
"""An object that proposes the relative urgency of computing transformations within an AlchemicalNetwork."""

_settings_cls: type[StrategySettings]

def __init__(self, settings: StrategySettings):

if not hasattr(self.__class__, "_settings_cls"):
raise NotImplementedError(
f"class `{self.__class__.__qualname__}` must implement the `_settings_cls` attribute."
)

if not isinstance(settings, self._settings_cls):
raise ValueError(
f"`{self.__class__.__qualname__}` expected a `{self._settings_cls.__qualname__}` instance"
)

self._settings = settings
super().__init__()

@property
def settings(self) -> StrategySettings:
return self._settings

@classmethod
def _defaults(cls):
Expand All @@ -52,13 +87,13 @@ def _default_settings(cls) -> StrategySettings:
def _propose(
self,
alchemical_network: AlchemicalNetwork,
protocol_results: list[ProtocolResult],
protocol_results: dict[GufeKey, _ProtocolResult],
) -> StrategyResult:
raise NotImplementedError

def propose(
self,
alchemical_network: AlchemicalNetwork,
protocol_results: list[ProtocolResult],
protocol_results: dict[GufeKey, _ProtocolResult],
) -> StrategyResult:
return self._propose(alchemical_network, protocol_results)
3 changes: 3 additions & 0 deletions src/stratocaster/strategies/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from stratocaster.strategies.connectivity import ConnectivityStrategy

__all__ = ["ConnectivityStrategy"]
135 changes: 135 additions & 0 deletions src/stratocaster/strategies/connectivity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
from gufe import AlchemicalNetwork, ProtocolResult
from gufe.tokenization import GufeKey

from stratocaster.base import Strategy, StrategyResult
from stratocaster.base.models import StrategySettings

from pydantic import validator, Field, root_validator


# TODO: docstrings
class ConnectivityStrategySettings(StrategySettings):

decay_rate: float = Field(
default=0.5, description="decay rate of the exponential decay penalty factor"
)
cutoff: float | None = Field(
default=None,
description="unnormalized weight cutoff used for termination condition",
)
max_runs: int | None = Field(
default=None,
description="the upper limit of protocol DAG results needed before a transformation is no longer weighed",
)

@validator("cutoff")
def validate_cutoff(cls, value):
if value is not None:
if not (0 < value):
raise ValueError("`cutoff` must be greater than 0")
return value

@validator("decay_rate")
def validate_decay_rate(cls, value):
if not (0 < value < 1):
raise ValueError("`decay_rate` must be between 0 and 1")
return value

@validator("max_runs")
def validate_max_runs(cls, value):
if value is not None:
if not value >= 1:
raise ValueError("`max_runs` must be greater than or equal to 1")
return value

@root_validator
def check_cutoff_or_max_runs(cls, values):
max_runs, cutoff = values.get("max_runs"), values.get("cutoff")

if max_runs is None and cutoff is None:
raise ValueError("At least one of `max_runs` or `cutoff` must be set")

return values


# TODO: docstrings
class ConnectivityStrategy(Strategy):

_settings_cls = ConnectivityStrategySettings

def _exponential_decay_scaling(self, number_of_results: int, decay_rate: float):
return decay_rate**number_of_results

def _propose(
self,
alchemical_network: AlchemicalNetwork,
protocol_results: dict[GufeKey, ProtocolResult],
) -> StrategyResult:
"""Propose `Transformation` weight recommendations based on high connectivity nodes.

Parameters
----------
alchemical_network: AlchemicalNetwork
protocol_results: dict[GufeKey, ProtocolResult]
A dictionary whose keys are the `GufeKey`s of `Transformation`s in the `AlchemicalNetwork`
and whose values are the `ProtocolResult`s for those `Transformation`s.

Returns
-------
StrategyResult
A `StrategyResult` containing the proposed `Transformation` weights.
"""

settings = self.settings

# keep the type checker happy
assert isinstance(settings, ConnectivityStrategySettings)

alchemical_network_mdg = alchemical_network.graph
weights: dict[GufeKey, float | None] = {}

for state_a, state_b in alchemical_network_mdg.edges():
num_neighbors_a = alchemical_network_mdg.degree(state_a)
num_neighbors_b = alchemical_network_mdg.degree(state_b)

# linter-satisfying assertion
assert isinstance(num_neighbors_a, int) and isinstance(num_neighbors_b, int)

transformation_key = alchemical_network_mdg.get_edge_data(state_a, state_b)[
0
]["object"].key

match (protocol_results.get(transformation_key)):
case None:
transformation_n_protcol_dag_results = 0
case pr:
assert isinstance(pr, ProtocolResult)
transformation_n_protcol_dag_results = pr.n_protocol_dag_results

scaling_factor = self._exponential_decay_scaling(
transformation_n_protcol_dag_results, settings.decay_rate
)
weight = scaling_factor * (num_neighbors_a + num_neighbors_b) / 2

match (settings.max_runs, settings.cutoff):
case (None, cutoff) if cutoff is not None:
if weight < cutoff:
weight = None
case (max_runs, None) if max_runs is not None:
if transformation_n_protcol_dag_results >= max_runs:
weight = None
case (max_runs, cutoff) if max_runs is not None and cutoff is not None:
if (
weight < cutoff
or transformation_n_protcol_dag_results >= max_runs
):
weight = None

weights[transformation_key] = weight

results = StrategyResult(weights=weights)
return results

@classmethod
def _default_settings(cls) -> StrategySettings:
return ConnectivityStrategySettings(max_runs=3)
Loading