From bee454476ae6dd0b0c2aebdf626a7ee79b8f11b5 Mon Sep 17 00:00:00 2001 From: jgallowa07 Date: Tue, 11 Jun 2024 11:56:14 -0700 Subject: [PATCH] Finished all TODO's --- multidms/__init__.py | 8 +++--- multidms/model.py | 35 ++++++++++----------------- multidms/model_collection.py | 18 ++++++++------ tests/test_data.py | 47 ++++++++++++++++++++++++++++++------ 4 files changed, 67 insertions(+), 41 deletions(-) diff --git a/multidms/__init__.py b/multidms/__init__.py index d4d6001..def7b2c 100644 --- a/multidms/__init__.py +++ b/multidms/__init__.py @@ -50,10 +50,10 @@ class works to compose, compile, and optimize the model parameters __version__ = "0.4.0" __url__ = "https://github.com/matsengrp/multidms" -from polyclonal.alphabets import AAS # noqa: F401 -from polyclonal.alphabets import AAS_WITHGAP # noqa: F401 -from polyclonal.alphabets import AAS_WITHSTOP # noqa: F401 -from polyclonal.alphabets import AAS_WITHSTOP_WITHGAP # noqa: F401 +from binarymap.binarymap import AAS_NOSTOP as AAS # noqa: F401 +from binarymap.binarymap import AAS_WITHGAP # noqa: F401 +from binarymap.binarymap import AAS_WITHSTOP # noqa: F401 +from binarymap.binarymap import AAS_WITHSTOP_WITHGAP # noqa: F401 from multidms.data import Data # noqa: F401 from multidms.model import Model # noqa: F401 diff --git a/multidms/model.py b/multidms/model.py index 19375d1..eb19bc6 100644 --- a/multidms/model.py +++ b/multidms/model.py @@ -26,11 +26,9 @@ from jax.experimental import sparse from jaxopt import ProximalGradient -# from jaxopt.linear_solve import solve_normal_cg import seaborn as sns from matplotlib import pyplot as plt - jax.config.update("jax_enable_x64", True) @@ -594,9 +592,9 @@ def get_mutations_df( for condition in self.data.conditions: single_mut_binary = self.data.single_mut_encodings[condition] - mutations_df[f"predicted_func_score_{condition}"] = ( - self.phenotype_frombinary(single_mut_binary, condition=condition) - ) + mutations_df[ + f"predicted_func_score_{condition}" + ] = self.phenotype_frombinary(single_mut_binary, condition=condition) if phenotype_as_effect: wt_func_score = self.wildtype_df.loc[condition, "predicted_func_score"] @@ -800,6 +798,8 @@ def add_phenotypes_to_df( raise ValueError(f"`df` already contains column {col}") ret[col] = onp.nan + # if the user has provided a name for the converted subs, then + # we need to add it to the dataframe if converted_substitutions_col is not None: ret[converted_substitutions_col] = "" @@ -813,7 +813,6 @@ def add_phenotypes_to_df( axis=1, ) - # TODO, why convert above if this may be provided by user? if converted_substitutions_col is not None: ret.loc[condition_df.index, converted_substitutions_col] = variant_subs @@ -852,9 +851,9 @@ def add_phenotypes_to_df( if phenotype_as_effect: latent_predictions -= wildtype_df.loc[condition, "predicted_latent"] latent_predictions[nan_variant_indices] = onp.nan - ret.loc[condition_df.index.values, latent_phenotype_col] = ( - latent_predictions - ) + ret.loc[ + condition_df.index.values, latent_phenotype_col + ] = latent_predictions # func_score predictions on binary variants, X phenotype_predictions = onp.array( @@ -866,9 +865,9 @@ def add_phenotypes_to_df( condition, "predicted_func_score" ] phenotype_predictions[nan_variant_indices] = onp.nan - ret.loc[condition_df.index.values, observed_phenotype_col] = ( - phenotype_predictions - ) + ret.loc[ + condition_df.index.values, observed_phenotype_col + ] = phenotype_predictions return ret @@ -1002,8 +1001,7 @@ def fit( If True, use FISTA acceleration. Defaults to True. lock_params : dict Dictionary of parameters, and desired value to constrain - them at during optimization. By default, none of the parameters - reference-condition latent offset (TODO math? beta0[ref]) are locked. + them at during optimization. By default, no parameters are locked. admm_niter : int Number of iterations to perform during the ADMM optimization. Defaults to 50. Note that in the case of single-condition models, @@ -1045,15 +1043,11 @@ def fit( if upper_bound_ge_scale < 0: raise ValueError("upper_bound_theta_ge_scale must be non-negative") - # TODO I wonder if this should be rounded? if upper_bound_ge_scale == "infer": y = jnp.concatenate(list(self.data.training_data["y"].values())) y_range = y.max() - y.min() upper_bound_ge_scale = 2 * y_range - # box constraints for the reference beta0 parameter. - # lock_params[("beta0", self.data.reference)] = 0.0 - compiled_proximal = self._model_components["proximal"] compiled_objective = jax.jit(self._model_components["objective"]) @@ -1118,14 +1112,12 @@ def fit( jit=False, ) - # GET DATA + # get training data scaled_training_data = ( self._data.scaled_training_data["X"], self._data.scaled_training_data["y"], ) - # TODO get validation data if it exists? - self._state = solver.init_state( self._scaled_data_params, hyperparams_prox=hyperparams_prox, @@ -1137,7 +1129,6 @@ def fit( index=range(0, maxiter + 1, convergence_trajectory_resolution) ).assign(loss=onp.nan, error=onp.nan) - # TODO should step be the index? convergence_trajectory.index.name = "step" # record initial loss and error diff --git a/multidms/model_collection.py b/multidms/model_collection.py index 0da8d1b..700433a 100644 --- a/multidms/model_collection.py +++ b/multidms/model_collection.py @@ -1,5 +1,9 @@ """ -Contains the ModelCollection class, which takes a collection of models +================ +model_collection +================ + +Contains the :class:`ModelCollection` class, which takes a collection of models and merges the results for comparison and visualization. """ @@ -556,7 +560,10 @@ def convergence_trajectory_df( query=None, id_vars=("dataset_name", "scale_coeff_lasso_shift"), ): - """TODO""" + """ + Combine the converence trajectory dataframes of + all fits in the queried collection. + """ queried_fits = ( self.fit_models.query(query) if query is not None else self.fit_models ) @@ -577,7 +584,6 @@ def convergence_trajectory_df( ] ) - # TODO make altair chart return convergence_trajectory_data def mut_param_heatmap( @@ -726,7 +732,6 @@ def mut_param_heatmap( # melt conditions and stats cols, beta is already "tall" # note that we must rename conditions with "." in the # name to "_" to avoid altair errors - # TODO let's just make sure we don't have "." in the condition names if mut_param == "beta": muts_df_tall = muts_df.assign(condition=self.reference.replace(".", "_")) else: @@ -883,12 +888,11 @@ def mut_type(mut): column=alt.Column("condition", title="Experiment"), ) - # TODO fix height scalar def shift_sparsity( self, x="scale_coeff_lasso_shift", width_scalar=100, - height_scalar=10, + height_scalar=100, return_data=False, **kwargs, ): @@ -991,7 +995,7 @@ def mut_type(mut): def mut_param_dataset_correlation( self, x="scale_coeff_lasso_shift", - width_scalar=150, + width_scalar=200, height=200, return_data=False, r=2, diff --git a/tests/test_data.py b/tests/test_data.py index f68e2b9..143a37f 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -1,10 +1,5 @@ """Tests for the Data class and its methods.""" -# import traceback -# import warnings -# import sys -# import os - import pytest import multidms import numpy as np @@ -15,7 +10,6 @@ import multidms.utils -# TODO test non numeric sites? TEST_FUNC_SCORES = pd.read_csv( StringIO( """ @@ -34,8 +28,6 @@ ) ) -# TODO figure out correct way to setup the data -# using pytest fixtures. data = multidms.Data( TEST_FUNC_SCORES, alphabet=multidms.AAS_WITHSTOP, @@ -652,6 +644,45 @@ def test_fit_models(): assert list(tall_combined.index.names) == ["scale_coeff_lasso_shift"] +def test_ModelCollection_mut_param_dataset_correlation(): + """ + Test that the correlation between the mutational + parameter estimates across conditions is correct. + by correlating two deterministic model fits from identical + datasets, meaning they should have a correlation of 1.0. + """ + data_rep1 = multidms.Data( + TEST_FUNC_SCORES, + alphabet=multidms.AAS_WITHSTOP, + reference="a", + assert_site_integrity=False, + name="rep1", + ) + + data_rep2 = multidms.Data( + TEST_FUNC_SCORES, + alphabet=multidms.AAS_WITHSTOP, + reference="a", + assert_site_integrity=False, + name="rep2", + ) + + params = { + "dataset": [data_rep1, data_rep2], + "maxiter": [1], + "scale_coeff_lasso_shift": [0.0], + } + + _, _, fit_models_df = multidms.model_collection.fit_models( + params, + n_threads=-1, + ) + mc = multidms.model_collection.ModelCollection(fit_models_df) + chart, data = mc.mut_param_dataset_correlation(return_data=True) + + assert np.all(data["correlation"] == 1.0) + + def test_ModelCollection_charts(): """ Test fitting two different models in