Skip to content

Commit

Permalink
cli: centralise parsing of yaml options and defaults
Browse files Browse the repository at this point in the history
  • Loading branch information
richardjgowers committed Nov 14, 2023
1 parent aab5620 commit 0d57408
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 79 deletions.
36 changes: 5 additions & 31 deletions openfecli/commands/plan_rbfe_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,15 +121,6 @@ def plan_rbfe_network(

write("Parsing in Files: ")

from gufe import SolventComponent
from openfe.setup.atom_mapping.lomap_scorers import (
default_lomap_score,
)
from openfe.setup import LomapAtomMapper
from openfe.setup.ligand_network_planning import (
generate_minimal_spanning_network,
)

# INPUT
write("\tGot input: ")

Expand All @@ -148,28 +139,11 @@ def plan_rbfe_network(
cofactors = []
write("\t\tCofactors: " + str(cofactors))

# Initially set these to None,
# possible changed via yaml input
# otherwise given default values
mapper_obj = None
mapping_scorer = None
ligand_network_planner = None
solvent = None

if yaml_settings is not None:
yaml_options = YAML_OPTIONS.get(yaml_settings)
mapper_obj = yaml_options.get('mapper', None)
ligand_network_planner = yaml_options.get('network', None)

if mapper_obj is None:
mapper_obj = LomapAtomMapper(time=20, threed=True, element_change=False,
max3d=1)
if mapping_scorer is None:
mapping_scorer = default_lomap_score
if ligand_network_planner is None:
ligand_network_planner = generate_minimal_spanning_network
if solvent is None:
solvent = SolventComponent()
yaml_options = YAML_OPTIONS.get(yaml_settings)
mapper_obj = yaml_options.mapper
mapping_scorer = yaml_options.scorer
ligand_network_planner = yaml_options.ligand_network_planner
solvent = yaml_options.solvent

write("\t\tSolvent: " + str(solvent))
write("")
Expand Down
32 changes: 5 additions & 27 deletions openfecli/commands/plan_rhfe_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,15 +106,6 @@ def plan_rhfe_network(molecules: List[str], yaml_settings: str, output_dir: str)

write("Parsing in Files: ")

from gufe import SolventComponent
from openfe.setup.atom_mapping.lomap_scorers import (
default_lomap_score,
)
from openfe.setup import LomapAtomMapper
from openfe.setup.ligand_network_planning import (
generate_minimal_spanning_network,
)

# INPUT
write("\tGot input: ")

Expand All @@ -124,24 +115,11 @@ def plan_rhfe_network(molecules: List[str], yaml_settings: str, output_dir: str)
+ " ".join([str(sm) for sm in small_molecules])
)

mapper_obj = None
mapping_scorer = None
ligand_network_planner = None
solvent = None

if yaml_settings is not None:
yaml_options = YAML_OPTIONS.get(yaml_settings)
mapper_obj = yaml_options.get('mapper', None)
ligand_network_planner = yaml_options.get('network', None)

if mapper_obj is None:
mapper_obj = LomapAtomMapper(time=20, threed=True, element_change=False, max3d=1)
if mapping_scorer is None:
mapping_scorer = default_lomap_score
if ligand_network_planner is None:
ligand_network_planner = generate_minimal_spanning_network
if solvent is None:
solvent = SolventComponent()
yaml_options = YAML_OPTIONS.get(yaml_settings)
mapper_obj = yaml_options.mapper
mapping_scorer = yaml_options.scorer
ligand_network_planner = yaml_options.ligand_network_planner
solvent = yaml_options.solvent

write("\t\tSolvent: " + str(solvent))
write("")
Expand Down
70 changes: 49 additions & 21 deletions openfecli/parameters/plan_network_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""
import click
from collections import namedtuple
try:
# todo; once we're fully v2, we can use ConfigDict not nested class
from pydantic.v1 import BaseModel # , ConfigDict
Expand All @@ -15,13 +16,18 @@
import warnings


PlanNetworkOptions = namedtuple('PlanNetworkOptions',
['mapper', 'scorer',
'ligand_network_planner', 'solvent'])


class MapperSelection(BaseModel):
# model_config = ConfigDict(extra='allow', str_to_lower=True)
class Config:
extra = 'allow'
anystr_lower = True

method: str = 'LomapAtomMapper'
method: Optional[str] = None
settings: dict[str, Any] = {}


Expand All @@ -31,11 +37,11 @@ class Config:
extra = 'allow'
anystr_lower = True

method: str = 'generate_minimal_spanning_network'
method: Optional[str] = None
settings: dict[str, Any] = {}


class CliOptions(BaseModel):
class CliYaml(BaseModel):
# model_config = ConfigDict(extra='allow')
class Config:
extra = 'allow'
Expand All @@ -44,7 +50,7 @@ class Config:
network: Optional[NetworkSelection] = None


def parse_yaml_planner_options(contents: str) -> CliOptions:
def parse_yaml_planner_options(contents: str) -> CliYaml:
"""Parse and minimally validate a user provided yaml
Parameters
Expand Down Expand Up @@ -72,10 +78,10 @@ def parse_yaml_planner_options(contents: str) -> CliOptions:
continue
warnings.warn(f"Ignoring unexpected section: '{field}'")

return CliOptions(**raw)
return CliYaml(**raw)


def load_yaml_planner_options(path: str, context) -> dict:
def load_yaml_planner_options(path: Optional[str], context) -> PlanNetworkOptions:
"""Load cli options from yaml file path and resolve these to objects
Parameters
Expand All @@ -87,12 +93,12 @@ def load_yaml_planner_options(path: str, context) -> dict:
Returns
-------
options : dict
dict optionally containing 'mapper' and 'network' keys:
'mapper' key holds a AtomMapper object.
'network' key holds a curried network planner function, whose signature
matches generate_minimum_spanning_network.
PlanNetworkOptions : namedtuple
a namedtuple with fields 'mapper', 'scorer', 'network_planning_algorithm',
and 'solvent' fields.
these fields each hold appropriate objects ready for use
"""
from gufe import SolventComponent
from openfe.setup.ligand_network_planning import (
generate_radial_network,
generate_minimal_spanning_network,
Expand All @@ -102,16 +108,22 @@ def load_yaml_planner_options(path: str, context) -> dict:
from openfe.setup import (
LomapAtomMapper,
)
from openfe.setup.atom_mapping.lomap_scorers import (
default_lomap_score,
)
from functools import partial

with open(path, 'r') as f:
raw = f.read()

opt = parse_yaml_planner_options(raw)
if path is not None:
with open(path, 'r') as f:
raw = f.read()

choices = {}
# convert raw yaml to normalised pydantic model
opt = parse_yaml_planner_options(raw)
else:
opt = None

if opt.mapper:
# convert normalised inputs to objects
if opt and opt.mapper:
mapper_choices = {
'lomap': LomapAtomMapper,
'lomapatommapper': LomapAtomMapper,
Expand All @@ -121,9 +133,15 @@ def load_yaml_planner_options(path: str, context) -> dict:
cls = mapper_choices[opt.mapper.method]
except KeyError:
raise KeyError(f"Bad mapper choice: '{opt.mapper.method}'")
mapper_obj = cls(**opt.mapper.settings)
else:
mapper_obj = LomapAtomMapper(time=20, threed=True, element_change=False,
max3d=1)

choices['mapper'] = cls(**opt.mapper.settings)
if opt.network:
# todo: choice of scorer goes here
mapping_scorer = default_lomap_score

if opt and opt.network:
network_choices = {
'generate_radial_network': generate_radial_network,
'radial': generate_radial_network,
Expand All @@ -138,9 +156,19 @@ def load_yaml_planner_options(path: str, context) -> dict:
except KeyError:
raise KeyError(f"Bad network algorithm choice: '{opt.network.method}'")

choices['network'] = partial(func, **opt.network.settings)
ligand_network_planner = partial(func, **opt.network.settings)
else:
ligand_network_planner = generate_minimal_spanning_network

# todo: choice of solvent goes here
solvent = SolventComponent()

return choices
return PlanNetworkOptions(
mapper_obj,
mapping_scorer,
ligand_network_planner,
solvent,
)


YAML_OPTIONS = Option(
Expand Down

0 comments on commit 0d57408

Please sign in to comment.