Merge pull request #267 from Jhsmit/cli_process
Jobfiles and process from the command line
Jhsmit authored Apr 8, 2022
2 parents d949888 + de28097 commit 89edb87
Showing 13 changed files with 294 additions and 60 deletions.
2 changes: 1 addition & 1 deletion docs/installation.rst
@@ -142,7 +142,7 @@ Generate conda requirements files `from setup.cfg`:
$ python _requirements.py
-First, if you would like a specific PyTorch version to use with PyHDX (ie CUDA/ROCm support), you should install this first.
+If you would like a specific PyTorch version to use with PyHDX (ie CUDA/ROCm support), you should install this first.
Installation instructions are on the Pytorch_ website.

Then, install the other base dependencies and optional extras. For example, to install PyHDX with web app:
190 changes: 178 additions & 12 deletions pyhdx/batch_processing.py
@@ -1,20 +1,31 @@
import warnings
+from functools import reduce
from pathlib import Path
import os
-from pyhdx.models import PeptideMasterTable, HDXMeasurement, HDXMeasurementSet
-from pyhdx.fileIO import read_dynamx
+import re

+from pyhdx import TorchFitResult
+from pyhdx.models import PeptideMasterTable, HDXMeasurement, HDXMeasurementSet
+from pyhdx.fileIO import read_dynamx, csv_to_dataframe, save_fitresult
+from pyhdx.fitting import fit_rates_half_time_interpolate, fit_rates_weighted_average, \
+    fit_gibbs_global, fit_gibbs_global_batch, RatesFitResult, GenericFitResult
+import param
+import pandas as pd
+from pyhdx.support import gen_subclasses
+import yaml

time_factors = {"s": 1, "m": 60.0, "min": 60.0, "h": 3600, "d": 86400}
temperature_offsets = {"c": 273.15, "celsius": 273.15, "k": 0, "kelvin": 0}

-# todo add data filters in yaml spec
+# todo add data filters in state spec?
+# todo add proline, n_term options

-class YamlParser(object):
-    """object used to parse yaml data input files into a PyHDX HDXMeasurement object"""
+class StateParser(object):
+    """object used to parse yaml state input files into a PyHDX HDXMeasurement object"""

-    def __init__(self, yaml_dict, data_src=None, data_filters=None):
-        self.yaml_dict = yaml_dict
+    # todo yaml_dict -> state_spec
+    def __init__(self, state_spec, data_src=None, data_filters=None):
+        self.state_spec = state_spec
        if isinstance(data_src, (os.PathLike, str)):
            self.data_src = Path(data_src)
        elif isinstance(data_src, dict):
@@ -44,7 +55,7 @@ def load_data(self, *filenames, reader='dynamx'):
    def load_hdxmset(self):
        """Batch read the full yaml spec into an HDXMeasurementSet"""
        hdxm_list = []
-        for state in self.yaml_dict.keys():
+        for state in self.state_spec.keys():
            hdxm = self.load_hdxm(state, name=state)
            hdxm_list.append(hdxm)

@@ -55,7 +66,7 @@ def load_hdxm(self, state, **kwargs):
        kwargs: additional kwargs passed to the HDXMeasurement
        """

-        state_dict = self.yaml_dict[state]
+        state_dict = self.state_spec[state]

        filenames = state_dict["filenames"]
        df = self.load_data(*filenames)
@@ -95,8 +106,8 @@ def load_hdxm(self, state, **kwargs):
            raise ValueError("Must specify either 'c_term' or 'sequence'")

        state_data = pmt.get_state(state_dict["state"])
-        for filter in self.data_filters:
-            state_data = filter(state_data)
+        for flt in self.data_filters:
+            state_data = flt(state_data)

        hdxm = HDXMeasurement(
            state_data,
@@ -111,16 +122,169 @@ def load_hdxm(self, state, **kwargs):
        return hdxm


process_functions = {
    'csv_to_dataframe': csv_to_dataframe,
    'fit_rates_half_time_interpolate': fit_rates_half_time_interpolate,
    'fit_rates_weighted_average': fit_rates_weighted_average,
    'fit_gibbs_global': fit_gibbs_global,
}


# task objects should be param
class Task(param.Parameterized):

    scheduler_address = param.String(doc='Optional scheduler address for dask task')

    cwd = param.ClassSelector(Path, doc='Path of the current working directory')


class LoadHDMeasurementSetTask(Task):
    _type = 'load_hdxm_set'

    state_file = param.String()  # = string path

    out = param.ClassSelector(HDXMeasurementSet)

    def execute(self, *args, **kwargs):
        state_spec = yaml.safe_load((self.cwd / self.state_file).read_text())
        parser = StateParser(state_spec, self.cwd, default_filters)
        hdxm_set = parser.load_hdxmset()

        self.out = hdxm_set


class EstimateRates(Task):
    _type = 'estimate_rates'

    hdxm_set = param.ClassSelector(HDXMeasurementSet)

    select_state = param.String(doc='If set, only use this state for creating initial guesses')

    out = param.ClassSelector((RatesFitResult, GenericFitResult))

    def execute(self, *args, **kwargs):
        if self.select_state:  # refactor to 'state' ?
            hdxm = self.hdxm_set.get(self.select_state)
            result = fit_rates_half_time_interpolate(hdxm)
        else:
            results = []
            for hdxm in self.hdxm_set:
                r = fit_rates_half_time_interpolate(hdxm)
                results.append(r)
            result = RatesFitResult(results)

        self.out = result


# todo allow guesses from deltaG
class ProcessGuesses(Task):
    _type = 'create_guess'

    hdxm_set = param.ClassSelector(HDXMeasurementSet)

    select_state = param.String(doc='If set, only use this state for creating initial guesses')

    rates_df = param.ClassSelector(pd.DataFrame)

    out = param.ClassSelector((pd.Series, pd.DataFrame))

    def execute(self, *args, **kwargs):
        if self.select_state:
            hdxm = self.hdxm_set.get(self.select_state)
            if self.rates_df.columns.nlevels == 2:
                rates_series = self.rates_df[(self.select_state, 'rate')]
            else:
                rates_series = self.rates_df['rate']

            guess = hdxm.guess_deltaG(rates_series)

        else:
            rates = self.rates_df.xs('rate', level=-1, axis=1)
            guess = self.hdxm_set.guess_deltaG(rates)

        self.out = guess


class FitGlobalBatch(Task):
    _type = 'fit_global_batch'

    hdxm_set = param.ClassSelector(HDXMeasurementSet)

    initial_guess = param.ClassSelector(
        (pd.Series, pd.DataFrame), doc='Initial guesses for fits')

    out = param.ClassSelector(TorchFitResult)

    def execute(self, *args, **kwargs):
        result = fit_gibbs_global_batch(self.hdxm_set, self.initial_guess, **kwargs)

        self.out = result


class SaveFitResult(Task):
    _type = 'save_fit_result'

    fit_result = param.ClassSelector(TorchFitResult)

    output_dir = param.String()

    def execute(self, *args, **kwargs):
        save_fitresult(self.cwd / self.output_dir, self.fit_result)


class JobParser(object):

    cwd = param.ClassSelector(Path, doc='Path of the current working directory')

    def __init__(self, job_spec, cwd=None):
        self.job_spec = job_spec
        self.cwd = cwd or Path().cwd()

        self.tasks = {}
        self.task_classes = {cls._type: cls for cls in gen_subclasses(Task) if getattr(cls, "_type", None)}

    def resolve_var(self, var_string):
        task_name, *attrs = var_string.split('.')

        return reduce(getattr, attrs, self.tasks[task_name])

    def execute(self):

        for task_spec in self.job_spec['steps']:
            task_klass = self.task_classes[task_spec['task']]
            skip = {'args', 'kwargs', 'task'}

            resolved_params = {}
            for par_name in task_spec.keys() - skip:
                value = task_spec[par_name]
                if isinstance(value, str):
                    m = re.findall(r'\$\((.*?)\)', value)
                    if m:
                        value = self.resolve_var(m[0])
                resolved_params[par_name] = value
            task = task_klass(cwd=self.cwd, **resolved_params)
            task.execute(*task_spec.get('args', []), **task_spec.get('kwargs', {}))

            self.tasks[task.name] = task


def yaml_to_hdxmset(yaml_dict, data_dir=None, **kwargs):
    """Reads files according to the `yaml_dict` spec from `data_dir` into an HDXMeasurementSet"""

    warnings.warn("yaml_to_hdxmset is deprecated, use 'StateParser'")
    hdxm_list = []
    for k, v in yaml_dict.items():
        hdxm = yaml_to_hdxm(v, data_dir=data_dir, name=k)
        hdxm_list.append(hdxm)

    return HDXMeasurementSet(hdxm_list)


# todo configurable
default_filters = [
    lambda df: df.query('exposure > 0')
]


def yaml_to_hdxm(yaml_dict, data_dir=None, data_filters=None, **kwargs):
    # todo perhaps classmethod on HDXMeasurement object?
@@ -142,7 +306,7 @@ def yaml_to_hdxm(yaml_dict, data_dir=None, data_filters=None, **kwargs):
    Output data object as specified by `yaml_dict`.

-    warnings.warn('This method is deprecated in favor of YamlParser', DeprecationWarning)
+    warnings.warn('This method is deprecated in favor of StateParser', DeprecationWarning)

    if data_dir is not None:
        input_files = [Path(data_dir) / fname for fname in yaml_dict["filenames"]]
@@ -270,3 +434,5 @@ def load_from_yaml_v040b2(yaml_dict, data_dir=None, **kwargs): # pragma: no cov
    )

    return hdxm
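For reference, a state spec of the kind StateParser consumes might look as follows. This is a sketch, not part of the commit: the state name, CSV filename, and c_term value are invented placeholders; only the filenames, state, and c_term/sequence keys are taken from the load_hdxm code above.

import yaml
from pyhdx.batch_processing import StateParser

# Hypothetical single-state spec; keys mirror the lookups in load_hdxm above.
spec_text = """
SecB_tetramer:
  filenames:
    - ecSecB_apo.csv
  state: SecB WT apo
  c_term: 165
"""

parser = StateParser(yaml.safe_load(spec_text), data_src='./data')
hdxm_set = parser.load_hdxmset()  # one HDXMeasurement per top-level state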


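A jobfile for JobParser can be sketched the same way. Each step's task key selects a Task subclass by its _type, the remaining keys become params, and $(name.attribute) strings are resolved against earlier tasks by resolve_var. This example is hypothetical: the step names and output directory are placeholders, each step sets an explicit name so later steps can reference it, and rates_df: $(rates.out.output) assumes the rates fit result exposes its combined table as .output, which the truncated @property on RatesFitResult below suggests but this diff does not show.

import yaml
from pathlib import Path
from pyhdx.batch_processing import JobParser

# Hypothetical jobfile: five steps wired together via $(...) references.
job_text = """
steps:
  - task: load_hdxm_set
    name: load
    state_file: states.yaml
  - task: estimate_rates
    name: rates
    hdxm_set: $(load.out)
  - task: create_guess
    name: guess
    hdxm_set: $(load.out)
    rates_df: $(rates.out.output)
  - task: fit_global_batch
    name: fit
    hdxm_set: $(load.out)
    initial_guess: $(guess.out)
  - task: save_fit_result
    fit_result: $(fit.out)
    output_dir: fit_result
"""

parser = JobParser(yaml.safe_load(job_text), cwd=Path('my_project'))
parser.execute()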
58 changes: 34 additions & 24 deletions pyhdx/cli.py
@@ -1,32 +1,30 @@
-import argparse
import time
-from ipaddress import ip_address
-from pyhdx.web import serve
-from pyhdx.config import cfg
-from pyhdx.local_cluster import verify_cluster, default_cluster
+from typing import Union, Optional
+from pathlib import Path

+import typer
+from ipaddress import ip_address
+import yaml

+# todo add check to see if the web module requirements are installed

+app = typer.Typer()

-def main():
-    parser = argparse.ArgumentParser(prog="pyhdx", description="PyHDX Launcher")
+@app.command()
+def serve(scheduler_address: Optional[str] = typer.Option(None, help="Address for dask scheduler to use")):
+    """Launch the PyHDX web application"""

-    parser.add_argument("serve", help="Runs PyHDX Dashboard")
-    parser.add_argument(
-        "--scheduler_address", help="Run with local cluster <ip>:<port>"
-    )
-    args = parser.parse_args()
+    from pyhdx.config import cfg
+    from pyhdx.local_cluster import verify_cluster, default_cluster

-    if args.scheduler_address:
-        ip, port = args.scheduler_address.split(":")
+    if scheduler_address is not None:
+        ip, port = scheduler_address.split(":")
        if not ip_address(ip):
            print("Invalid IP Address")
            return
        elif not 0 <= int(port) < 2 ** 16:
            print("Invalid port, must be 0-65535")
            return
-        cfg.set("cluster", "scheduler_address", args.scheduler_address)
+        cfg.set("cluster", "scheduler_address", scheduler_address)

    scheduler_address = cfg.get("cluster", "scheduler_address")
    if not verify_cluster(scheduler_address):
@@ -37,8 +35,9 @@ def main():
        scheduler_address = f"{ip}:{port}"
        print(f"Started new Dask LocalCluster at {scheduler_address}")

-    if args.serve:
-        serve.run_apps()
+    # Start the PyHDX web application
+    from pyhdx.web import serve as serve_pyhdx
+    serve_pyhdx.run_apps()

    loop = True
    while loop:
@@ -49,11 +48,22 @@ def main():
            loop = False


-if __name__ == "__main__":
-    import sys
+@app.command()
+def process(
+    jobfile: Path = typer.Argument(..., help="Path to .yaml jobfile"),
+    cwd: Optional[Path] = typer.Option(None, help="Optional path to working directory")
+):
+    """
+    Process an HDX dataset according to a jobfile
+    """

+    from pyhdx.batch_processing import JobParser

-    sys.argv.append("serve")
-    sys.argv.append("--scheduler_address")
-    sys.argv.append("127.0.0.1:53270")
+    job_spec = yaml.safe_load(jobfile.read_text())
+    parser = JobParser(job_spec, cwd=cwd)

-    main()
+    parser.execute()


+if __name__ == "__main__":
+    app()
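Assuming the pyhdx console script is wired to this typer app, the two commands would be invoked roughly as follows (illustrative values; typer renders the scheduler_address parameter as a dashed option by default):

$ pyhdx serve --scheduler-address 127.0.0.1:53270
$ pyhdx process job.yaml --cwd ./my_project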
2 changes: 1 addition & 1 deletion pyhdx/fileIO.py
@@ -371,7 +371,7 @@ def save_fitresult(output_dir, fit_result, log_lines=None):
    dataframe_to_file(output_dir / "losses.csv", fit_result.losses)
    dataframe_to_file(output_dir / "losses.txt", fit_result.losses, fmt="pprint")

-    if isinstance(fit_result.hdxm_set, pyhdx.HDXMeasurement):
+    if isinstance(fit_result.hdxm_set, pyhdx.HDXMeasurement):  # check, but this should always be hdxm_set
        fit_result.hdxm_set.to_file(output_dir / "HDXMeasurement.csv")
    if isinstance(fit_result.hdxm_set, pyhdx.HDXMeasurementSet):
        fit_result.hdxm_set.to_file(output_dir / "HDXMeasurements.csv")
1 change: 1 addition & 0 deletions pyhdx/fitting.py
@@ -929,6 +929,7 @@ class GenericFitResult:

@dataclass
class RatesFitResult:
+    """Accumulates multiple Generic/KineticsFit Results"""
    results: list

    @property
6 changes: 6 additions & 0 deletions pyhdx/models.py
@@ -1219,6 +1219,12 @@ def __iter__(self):
    def __getitem__(self, item):
        return self.hdxm_list.__getitem__(item)

+    def get(self, name):
+        """Find an HDXMeasurement by name"""
+
+        idx = self.names.index(name)
+        return self[idx]

    @property
    def Ns(self):
        return len(self.hdxm_list)
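A one-line usage sketch of the new lookup (the measurement name is a placeholder):

hdxm = hdxm_set.get('SecB_tetramer')  # same as hdxm_set[hdxm_set.names.index('SecB_tetramer')]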
4 changes: 2 additions & 2 deletions pyhdx/web/controllers.py
@@ -20,7 +20,7 @@
from proplot import to_hex
from skimage.filters import threshold_multiotsu

-from pyhdx.batch_processing import YamlParser
+from pyhdx.batch_processing import StateParser
from pyhdx.config import cfg
from pyhdx.fileIO import read_dynamx, csv_to_dataframe, dataframe_to_stringio
from pyhdx.fitting import (
@@ -499,7 +499,7 @@ def _add_dataset_batch(self):
        ios = {name: StringIO(byte_content.decode("UTF-8")) for name, byte_content in zip(self.widgets['input_files'].filename, self.input_files)}
        filters = [lambda df: df.query('exposure > 0')]

-        parser = YamlParser(yaml_dict, data_src=ios, data_filters=filters)
+        parser = StateParser(yaml_dict, data_src=ios, data_filters=filters)

        for state in yaml_dict.keys():
            hdxm = parser.load_hdxm(state, name=state)