Skip to content

Commit

Permalink
Merge pull request #231 from Jhsmit/cuda
Browse files Browse the repository at this point in the history
Allow CUDA and float dtype
  • Loading branch information
Jhsmit authored Sep 27, 2021
2 parents 24c813c + 6e16ad2 commit 64a63aa
Show file tree
Hide file tree
Showing 11 changed files with 113 additions and 36 deletions.
6 changes: 5 additions & 1 deletion pyhdx/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from .models import PeptideMasterTable, PeptideMeasurements, HDXMeasurement, Coverage, HDXMeasurementSet
from .fileIO import read_dynamx
from .fitting_torch import TorchSingleFitResult, TorchBatchFitResult
from .output import Output, Report
from ._version import get_versions

try:
from .output import Output, Report
except ModuleNotFoundError:
pass


__version__ = get_versions()['version']

Expand Down
4 changes: 1 addition & 3 deletions pyhdx/cli.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import argparse
from ipaddress import ip_address
from pyhdx.web import serve
from pyhdx.config import ConfigurationSettings
from pyhdx.config import cfg
from pyhdx.local_cluster import verify_cluster, default_cluster


Expand All @@ -15,8 +15,6 @@ def main():
parser.add_argument('--scheduler_address', help="Run with local cluster <ip>:<port>")
args = parser.parse_args()

cfg = ConfigurationSettings()

if args.scheduler_address:
ip, port = args.scheduler_address.split(':')
if not ip_address(ip):
Expand Down
4 changes: 4 additions & 0 deletions pyhdx/config.ini
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
[cluster]
scheduler_address = 127.0.0.1:52123
n_workers = 10

[fitting]
dtype = float64
device = cpu
26 changes: 24 additions & 2 deletions pyhdx/config.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
import configparser
from pathlib import Path
from pyhdx import __version__
from pyhdx._version import get_versions
from packaging import version
import torch
import warnings


__version__ = get_versions()['version']
del get_versions


def read_config(path):
"""read .ini config file at path, return configparser.ConfigParser object"""
config = configparser.ConfigParser()
Expand Down Expand Up @@ -86,6 +91,21 @@ def write_config(self, path=None):
with open(pth, 'w') as config_file:
self._config.write(config_file)

@property
def TORCH_DTYPE(self):
    """torch dtype selected by the ``dtype`` option of the ``fitting`` config section.

    'float64' / 'double' resolve to ``torch.float64`` and 'float32' / 'float'
    resolve to ``torch.float32``.

    Raises
    ------
    ValueError
        If the configured value is none of the names listed above.
    """
    dtype = self.get('fitting', 'dtype')
    # Each entry pairs the accepted config spellings with the torch dtype they select.
    for names, torch_dtype in (
        (('float64', 'double'), torch.float64),
        (('float32', 'float'), torch.float32),
    ):
        if dtype in names:
            return torch_dtype
    raise ValueError(f'Unsupported data type: {dtype}')

@property
def TORCH_DEVICE(self):
    """torch device built from the ``device`` option of the ``fitting`` config section."""
    return torch.device(self.get('fitting', 'device'))


def valid_config():
"""Checks if the current config file in the user home directory is a valid config
Expand All @@ -111,4 +131,6 @@ def valid_config():

config_file_path = config_dir / 'config.ini'
if not valid_config():
reset_config()
reset_config()

cfg = ConfigurationSettings()
9 changes: 5 additions & 4 deletions pyhdx/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@
from tqdm import trange

from pyhdx.fit_models import SingleKineticModel, TwoComponentAssociationModel, TwoComponentDissociationModel
from pyhdx.fitting_torch import DeltaGFit, TorchSingleFitResult, TorchBatchFitResult, TORCH_DTYPE, TORCH_DEVICE
from pyhdx.models import Protein
from pyhdx.fitting_torch import DeltaGFit, TorchSingleFitResult, TorchBatchFitResult
from pyhdx.support import temporary_seed
from pyhdx.models import Protein
from pyhdx.config import cfg

EmptyResult = namedtuple('EmptyResult', ['chi_squared', 'params'])
er = EmptyResult(np.nan, {k: np.nan for k in ['tau1', 'tau2', 'r']})
Expand Down Expand Up @@ -451,7 +452,7 @@ def fit_gibbs_global(hdxm, initial_guess, r1=R1, epochs=EPOCHS, patience=PATIENC
assert len(initial_guess) == hdxm.Nr, "Invalid length of initial guesses"

dtype = torch.float64
deltaG_par = torch.nn.Parameter(torch.tensor(initial_guess, dtype=TORCH_DTYPE, device=TORCH_DEVICE).unsqueeze(-1)) #reshape (nr, 1)
deltaG_par = torch.nn.Parameter(torch.tensor(initial_guess, dtype=cfg.TORCH_DTYPE, device=cfg.TORCH_DEVICE).unsqueeze(-1)) #reshape (nr, 1)

model = DeltaGFit(deltaG_par)
criterion = torch.nn.MSELoss(reduction='mean')
Expand Down Expand Up @@ -580,7 +581,7 @@ def _batch_fit(hdx_set, initial_guess, reg_func, fit_kwargs, optimizer_kwargs):

assert initial_guess.shape == (hdx_set.Ns, hdx_set.Nr), "Invalid shape of initial guesses"

deltaG_par = torch.nn.Parameter(torch.tensor(initial_guess, dtype=TORCH_DTYPE, device=TORCH_DEVICE).reshape(hdx_set.Ns, hdx_set.Nr, 1))
deltaG_par = torch.nn.Parameter(torch.tensor(initial_guess, dtype=cfg.TORCH_DTYPE, device=cfg.TORCH_DEVICE).reshape(hdx_set.Ns, hdx_set.Nr, 1))

model = DeltaGFit(deltaG_par)
criterion = torch.nn.MSELoss(reduction='mean')
Expand Down
10 changes: 6 additions & 4 deletions pyhdx/fitting_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@

from pyhdx.fileIO import dataframe_to_file
from pyhdx.models import Protein
from pyhdx.config import cfg

TORCH_DTYPE = t.double
TORCH_DEVICE = t.device('cpu')
# TORCH_DTYPE = t.double
# TORCH_DEVICE = t.device('cpu')

class DeltaGFit(nn.Module):
def __init__(self, deltaG):
Expand Down Expand Up @@ -46,11 +47,12 @@ def estimate_errors(hdxm, deltaG):
-------
"""
dtype = t.float64
joined = pd.concat([deltaG, hdxm.coverage['exchanges']], axis=1, keys=['dG', 'ex'])
dG = joined.query('ex==True')['dG']
deltaG = t.tensor(dG.to_numpy(), dtype=TORCH_DTYPE)
deltaG = t.tensor(dG.to_numpy(), dtype=dtype)

tensors = {k: v.cpu() for k, v in hdxm.get_tensors(exchanges=True).items()}
tensors = {k: v.cpu() for k, v in hdxm.get_tensors(exchanges=True, dtype=dtype).items()}

def hes_loss(deltaG_input):
criterion = t.nn.MSELoss(reduction='sum')
Expand Down
4 changes: 1 addition & 3 deletions pyhdx/local_cluster.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
from dask.distributed import LocalCluster, Client
import time
from pyhdx.config import ConfigurationSettings
from pyhdx.config import cfg
import argparse

cfg = ConfigurationSettings()

def default_client(timeout='2s'):
"""Return Dask client at scheduler address as defined by the global config"""
scheduler_address = cfg.get('cluster', 'scheduler_address')
Expand Down
13 changes: 7 additions & 6 deletions pyhdx/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from pyhdx.alignment import align_dataframes
from pyhdx.fileIO import dataframe_to_file
from pyhdx.support import reduce_inter, fields_view
from pyhdx.config import cfg


def protein_wrapper(func, *args, **kwargs):
Expand Down Expand Up @@ -748,7 +749,7 @@ def d_exp(self):
df.columns.name = 'exposure'
return df

def get_tensors(self, exchanges=False):
def get_tensors(self, exchanges=False, dtype=None):
"""
Returns a dictionary of tensor variables for fitting to Linderstrøm-Lang kinetics.
Expand Down Expand Up @@ -784,8 +785,8 @@ def get_tensors(self, exchanges=False):
else:
bools = np.ones(self.Nr, dtype=bool)

dtype = pyhdx.fitting_torch.TORCH_DTYPE
device = pyhdx.fitting_torch.TORCH_DEVICE
dtype = dtype or cfg.TORCH_DTYPE
device = cfg.TORCH_DEVICE

tensors = {
'temperature': torch.tensor([self.temperature], dtype=dtype, device=device).unsqueeze(-1),
Expand Down Expand Up @@ -1130,7 +1131,7 @@ def add_alignment(self, alignment, first_r_numbers=None):

self.aligned_indices = df.to_numpy(dtype=int).T

def get_tensors(self):
def get_tensors(self, dtype=None):
#todo create correct shapes as per table X for all
temperature = np.array([kf.temperature for kf in self.hdxm_list])

Expand All @@ -1142,8 +1143,8 @@ def get_tensors(self):
k_int = np.zeros((self.Ns, self.Nr))
k_int[self.masks['sr']] = k_int_values

dtype = pyhdx.fitting_torch.TORCH_DTYPE
device = pyhdx.fitting_torch.TORCH_DEVICE
dtype = dtype or cfg.TORCH_DTYPE
device = cfg.TORCH_DEVICE

tensors = {
'temperature': torch.tensor(temperature, dtype=dtype, device=device).reshape(self.Ns, 1, 1),
Expand Down
3 changes: 1 addition & 2 deletions pyhdx/web/apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import logging
import panel as pn
from pyhdx.web.log import logger
from pyhdx.config import ConfigurationSettings
from pyhdx.config import cfg
from pyhdx.local_cluster import default_client

from pathlib import Path
Expand All @@ -27,7 +27,6 @@
current_dir = Path(__file__).parent
data_dir = current_dir.parent.parent / 'tests' / 'test_data'
global_opts = {'show_grid': True}
cfg = ConfigurationSettings()

@logger('pyhdx')
def main_app(client='default'):
Expand Down
4 changes: 2 additions & 2 deletions pyhdx/web/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np
import torch

from pyhdx.config import ConfigurationSettings
from pyhdx.config import cfg
from pyhdx.local_cluster import verify_cluster

import logging
Expand All @@ -24,7 +24,7 @@ def run_main():
np.random.seed(43)
torch.manual_seed(43)

scheduler_address = ConfigurationSettings().get('cluster', 'scheduler_address')
scheduler_address = cfg.get('cluster', 'scheduler_address')
if not verify_cluster(scheduler_address):
print(f"No valid Dask scheduler found at specified address: '{scheduler_address}'")
return
Expand Down
66 changes: 57 additions & 9 deletions tests/test_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
from pyhdx.fileIO import read_dynamx, csv_to_protein, csv_to_dataframe, save_fitresult, load_fitresult
from pyhdx.fitting import fit_rates_weighted_average, fit_gibbs_global, fit_gibbs_global_batch, fit_gibbs_global_batch_aligned
from pyhdx.models import HDXMeasurementSet
from pyhdx.config import cfg
import numpy as np
import torch
import time
from dask.distributed import LocalCluster
from pathlib import Path

import pandas as pd
from pandas.testing import assert_series_equal

cwd = Path(__file__).parent
input_dir = cwd / 'test_data' / 'input'
Expand Down Expand Up @@ -52,6 +54,35 @@ def test_initial_guess(self):
# todo additional tests:
# result = fit_rates_half_time_interpolate()

def test_dtype_cuda(self):
    """Run the global fit with device / dtype taken from the global ``cfg``.

    First the device is set to 'cuda': when CUDA is available the fit output is
    compared to the reference fit, otherwise the fit is expected to raise an
    AssertionError whose message mentions CUDA. Then the fit is repeated on the
    cpu with float32 and the parameter dtype and output are checked.

    ``cfg`` is a module-level object shared by all tests, so the values found on
    entry are restored in a ``finally`` clause; a failing assertion must not leave
    cuda / float32 active for the tests that run afterwards.
    """
    check_deltaG = csv_to_protein(output_dir / 'ecSecB_torch_fit.csv')
    initial_rates = csv_to_dataframe(output_dir / 'ecSecB_guess.csv')

    previous_device = cfg.get('fitting', 'device')
    previous_dtype = cfg.get('fitting', 'dtype')
    try:
        cfg.set('fitting', 'device', 'cuda')
        gibbs_guess = self.hdxm_apo.guess_deltaG(initial_rates['rate']).to_numpy()

        if torch.cuda.is_available():
            fr_global = fit_gibbs_global(self.hdxm_apo, gibbs_guess, epochs=1000, r1=2)
            out_deltaG = fr_global.output
            for field in ['deltaG', 'k_obs', 'covariance']:
                assert_series_equal(check_deltaG[field], out_deltaG[field], rtol=0.01, check_dtype=False)
        else:
            with pytest.raises(AssertionError, match=r".* CUDA .*"):
                fit_gibbs_global(self.hdxm_apo, gibbs_guess, epochs=1000, r1=2)

        cfg.set('fitting', 'device', 'cpu')
        cfg.set('fitting', 'dtype', 'float32')

        fr_global = fit_gibbs_global(self.hdxm_apo, gibbs_guess, epochs=1000, r1=2)
        dg = fr_global.model.deltaG
        assert dg.dtype == torch.float32

        out_deltaG = fr_global.output
        # covariance is not compared for the float32 fit
        for field in ['deltaG', 'k_obs']:
            assert_series_equal(check_deltaG[field], out_deltaG[field], rtol=0.01, check_dtype=False)
    finally:
        # Always hand the shared configuration back in the state it was found in.
        cfg.set('fitting', 'device', previous_device)
        cfg.set('fitting', 'dtype', previous_dtype)

def test_global_fit(self):
initial_rates = csv_to_dataframe(output_dir / 'ecSecB_guess.csv')

Expand All @@ -64,33 +95,50 @@ def test_global_fit(self):
out_deltaG = fr_global.output
check_deltaG = csv_to_protein(output_dir / 'ecSecB_torch_fit.csv')

assert np.allclose(check_deltaG['deltaG'], out_deltaG['deltaG'], equal_nan=True, rtol=0.01)
assert np.allclose(check_deltaG['covariance'], out_deltaG['covariance'], equal_nan=True, rtol=0.01)
assert np.allclose(check_deltaG['k_obs'], out_deltaG['k_obs'], equal_nan=True, rtol=0.01)
for field in ['deltaG', 'covariance', 'k_obs']:
assert_series_equal(check_deltaG[field], out_deltaG[field], rtol=0.01)

mse = fr_global.get_mse()
assert mse.shape == (self.hdxm_apo.Np, self.hdxm_apo.Nt)

@pytest.mark.skip(reason="Longer fit is not checked by default due to long computation times")
def test_global_fit_extended(self):
    """Run a 20000 epoch global fit and compare it to the stored reference fit.

    The reference data and the initial guess are prepared once, before the timer
    starts, so that the crude benchmark only covers the fit itself.
    """
    check_deltaG = csv_to_protein(output_dir / 'ecSecB_torch_fit_epochs_20000.csv')
    initial_rates = csv_to_dataframe(output_dir / 'ecSecB_guess.csv')
    gibbs_guess = self.hdxm_apo.guess_deltaG(initial_rates['rate']).to_numpy()

    t0 = time.time()  # Very crude benchmarks
    fr_global = fit_gibbs_global(self.hdxm_apo, gibbs_guess, epochs=20000, r1=2)
    t1 = time.time()

    assert t1 - t0 < 20
    out_deltaG = fr_global.output

    for field in ['deltaG', 'k_obs', 'covariance']:
        assert_series_equal(check_deltaG[field], out_deltaG[field], rtol=0.01, check_dtype=False)

    mse = fr_global.get_mse()
    assert mse.shape == (self.hdxm_apo.Np, self.hdxm_apo.Nt)

@pytest.mark.skip(reason="Longer fit is not checked by default due to long computation times")
def test_global_fit_extended_cuda(self):
    """Run a 20000 epoch global fit with device 'cuda' and dtype 'float32'.

    The 'deltaG' and 'k_obs' outputs are compared to the stored reference fit.
    The shared ``cfg`` values found on entry are restored in a ``finally`` clause
    so that a failing fit or assertion does not leak cuda / float32 into other tests.
    """
    check_deltaG = csv_to_protein(output_dir / 'ecSecB_torch_fit_epochs_20000.csv')
    initial_rates = csv_to_dataframe(output_dir / 'ecSecB_guess.csv')
    gibbs_guess = self.hdxm_apo.guess_deltaG(initial_rates['rate']).to_numpy()

    # todo allow contextmanager?
    previous_device = cfg.get('fitting', 'device')
    previous_dtype = cfg.get('fitting', 'dtype')
    try:
        cfg.set('fitting', 'device', 'cuda')
        cfg.set('fitting', 'dtype', 'float32')

        fr_global = fit_gibbs_global(self.hdxm_apo, gibbs_guess, epochs=20000, r1=2)
        out_deltaG = fr_global.output

        for field in ['deltaG', 'k_obs']:
            assert_series_equal(check_deltaG[field], out_deltaG[field], rtol=0.01, check_dtype=False)
    finally:
        # Always hand the shared configuration back in the state it was found in.
        cfg.set('fitting', 'device', previous_device)
        cfg.set('fitting', 'dtype', previous_dtype)


def test_batch_fit(self, tmp_path):
hdx_set = HDXMeasurementSet([self.hdxm_apo, self.hdxm_dimer])
guess = csv_to_dataframe(output_dir / 'ecSecB_guess.csv')
Expand Down

0 comments on commit 64a63aa

Please sign in to comment.