Commit bee4544

Finished all TODO's

jgallowa07 committed Jun 17, 2024
1 parent e20241d commit bee4544
Showing 4 changed files with 67 additions and 41 deletions.
8 changes: 4 additions & 4 deletions multidms/__init__.py
@@ -50,10 +50,10 @@ class works to compose, compile, and optimize the model parameters
__version__ = "0.4.0"
__url__ = "https://github.com/matsengrp/multidms"

-from polyclonal.alphabets import AAS  # noqa: F401
-from polyclonal.alphabets import AAS_WITHGAP  # noqa: F401
-from polyclonal.alphabets import AAS_WITHSTOP  # noqa: F401
-from polyclonal.alphabets import AAS_WITHSTOP_WITHGAP  # noqa: F401
+from binarymap.binarymap import AAS_NOSTOP as AAS  # noqa: F401
+from binarymap.binarymap import AAS_WITHGAP  # noqa: F401
+from binarymap.binarymap import AAS_WITHSTOP  # noqa: F401
+from binarymap.binarymap import AAS_WITHSTOP_WITHGAP  # noqa: F401

from multidms.data import Data # noqa: F401
from multidms.model import Model # noqa: F401
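The change above swaps the source of the amino-acid alphabet constants from polyclonal to binarymap while keeping their original multidms names, so downstream imports are unaffected. A minimal sketch of that preserved surface, assuming multidms is installed and that the constants follow the usual conventions ("*" marks stop codons, "-" marks gaps):

import multidms

# AAS is binarymap's AAS_NOSTOP, re-exported under its original multidms name.
assert "*" not in multidms.AAS
assert "*" in multidms.AAS_WITHSTOP
assert "-" in multidms.AAS_WITHGAP
assert "-" in multidms.AAS_WITHSTOP_WITHGAP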
35 changes: 13 additions & 22 deletions multidms/model.py
@@ -26,11 +26,9 @@
from jax.experimental import sparse
from jaxopt import ProximalGradient

-# from jaxopt.linear_solve import solve_normal_cg
import seaborn as sns
from matplotlib import pyplot as plt


jax.config.update("jax_enable_x64", True)


@@ -594,9 +592,9 @@ def get_mutations_df(

for condition in self.data.conditions:
single_mut_binary = self.data.single_mut_encodings[condition]
mutations_df[f"predicted_func_score_{condition}"] = (
self.phenotype_frombinary(single_mut_binary, condition=condition)
)
mutations_df[
f"predicted_func_score_{condition}"
] = self.phenotype_frombinary(single_mut_binary, condition=condition)

if phenotype_as_effect:
wt_func_score = self.wildtype_df.loc[condition, "predicted_func_score"]
@@ -800,6 +798,8 @@ def add_phenotypes_to_df(
raise ValueError(f"`df` already contains column {col}")
ret[col] = onp.nan

+# if the user has provided a name for the converted subs, then
+# we need to add it to the dataframe
if converted_substitutions_col is not None:
ret[converted_substitutions_col] = ""

@@ -813,7 +813,6 @@
axis=1,
)

-# TODO, why convert above if this may be provided by user?
if converted_substitutions_col is not None:
ret.loc[condition_df.index, converted_substitutions_col] = variant_subs

@@ -852,9 +851,9 @@ def add_phenotypes_to_df(
if phenotype_as_effect:
latent_predictions -= wildtype_df.loc[condition, "predicted_latent"]
latent_predictions[nan_variant_indices] = onp.nan
-ret.loc[condition_df.index.values, latent_phenotype_col] = (
-latent_predictions
-)
+ret.loc[
+condition_df.index.values, latent_phenotype_col
+] = latent_predictions

# func_score predictions on binary variants, X
phenotype_predictions = onp.array(
@@ -866,9 +865,9 @@
condition, "predicted_func_score"
]
phenotype_predictions[nan_variant_indices] = onp.nan
-ret.loc[condition_df.index.values, observed_phenotype_col] = (
-phenotype_predictions
-)
+ret.loc[
+condition_df.index.values, observed_phenotype_col
+] = phenotype_predictions

return ret
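The phenotype_as_effect branches above report predictions relative to the wildtype prediction for the same condition, with NaN preserved for variants whose substitutions could not be converted. A toy illustration of the subtraction (the numbers are hypothetical, not from a real fit):

import numpy as onp

predicted_func_score = onp.array([-1.2, 0.3, -0.4])  # toy variant predictions
wt_func_score = 0.3  # toy wildtype prediction for the same condition
effects = predicted_func_score - wt_func_score  # array([-1.5,  0. , -0.7])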

@@ -1002,8 +1001,7 @@ def fit(
If True, use FISTA acceleration. Defaults to True.
lock_params : dict
Dictionary of parameters, and desired value to constrain
-them at during optimization. By default, none of the parameters
-reference-condition latent offset (TODO math? beta0[ref]) are locked.
+them at during optimization. By default, no parameters are locked.
admm_niter : int
Number of iterations to perform during the ADMM optimization.
Defaults to 50. Note that in the case of single-condition models,
@@ -1045,15 +1043,11 @@
if upper_bound_ge_scale < 0:
raise ValueError("upper_bound_theta_ge_scale must be non-negative")

-# TODO I wonder if this should be rounded?
if upper_bound_ge_scale == "infer":
y = jnp.concatenate(list(self.data.training_data["y"].values()))
y_range = y.max() - y.min()
upper_bound_ge_scale = 2 * y_range

-# box constraints for the reference beta0 parameter.
-# lock_params[("beta0", self.data.reference)] = 0.0

compiled_proximal = self._model_components["proximal"]
compiled_objective = jax.jit(self._model_components["objective"])
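Two fit() behaviors touched in this hunk can be illustrated standalone. With upper_bound_ge_scale="infer", the bound on the global-epistasis scale becomes twice the range of the training functional scores; lock_params pins named parameters, keyed as in the commented-out line above, to fixed values. A hedged sketch with toy numbers (model and data are hypothetical objects, not defined here):

import jax.numpy as jnp

y = jnp.array([-3.0, -0.5, 0.0, 1.2])  # toy training functional scores
upper_bound_ge_scale = 2 * (y.max() - y.min())  # 2 * 4.2 = 8.4

# Pin the reference condition's latent offset at zero during optimization:
# model.fit(lock_params={("beta0", data.reference): 0.0})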

@@ -1118,14 +1112,12 @@
jit=False,
)

-# GET DATA
+# get training data
scaled_training_data = (
self._data.scaled_training_data["X"],
self._data.scaled_training_data["y"],
)

-# TODO get validation data if it exists?

self._state = solver.init_state(
self._scaled_data_params,
hyperparams_prox=hyperparams_prox,
@@ -1137,7 +1129,6 @@
index=range(0, maxiter + 1, convergence_trajectory_resolution)
).assign(loss=onp.nan, error=onp.nan)

-# TODO should step be the index?
convergence_trajectory.index.name = "step"

# record initial loss and error
18 changes: 11 additions & 7 deletions multidms/model_collection.py
@@ -1,5 +1,9 @@
"""
-Contains the ModelCollection class, which takes a collection of models
+================
+model_collection
+================
+Contains the :class:`ModelCollection` class, which takes a collection of models
and merges the results for comparison and visualization.
"""

@@ -556,7 +560,10 @@ def convergence_trajectory_df(
query=None,
id_vars=("dataset_name", "scale_coeff_lasso_shift"),
):
"""TODO"""
"""
Combine the converence trajectory dataframes of
all fits in the queried collection.
"""
queried_fits = (
self.fit_models.query(query) if query is not None else self.fit_models
)
@@ -577,7 +584,6 @@
]
)

-# TODO make altair chart
return convergence_trajectory_data
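A hedged usage sketch for the newly documented method, assuming a fitted ModelCollection named mc; the query string targets a column used elsewhere in this module, and id_vars echoes the default from the signature above:

trajectories = mc.convergence_trajectory_df(
    query="scale_coeff_lasso_shift == 0.0",
    id_vars=("dataset_name", "scale_coeff_lasso_shift"),
)
print(trajectories.head())  # combined loss/error trajectories, indexed by step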

def mut_param_heatmap(
@@ -726,7 +732,6 @@ def mut_param_heatmap(
# melt conditions and stats cols, beta is already "tall"
# note that we must rename conditions with "." in the
# name to "_" to avoid altair errors
-# TODO let's just make sure we don't have "." in the condition names
if mut_param == "beta":
muts_df_tall = muts_df.assign(condition=self.reference.replace(".", "_"))
else:
@@ -883,12 +888,11 @@ def mut_type(mut):
column=alt.Column("condition", title="Experiment"),
)

-# TODO fix height scalar
def shift_sparsity(
self,
x="scale_coeff_lasso_shift",
width_scalar=100,
-height_scalar=10,
+height_scalar=100,
return_data=False,
**kwargs,
):
@@ -991,7 +995,7 @@ def mut_type(mut):
def mut_param_dataset_correlation(
self,
x="scale_coeff_lasso_shift",
-width_scalar=150,
+width_scalar=200,
height=200,
return_data=False,
r=2,
47 changes: 39 additions & 8 deletions tests/test_data.py
@@ -1,10 +1,5 @@
"""Tests for the Data class and its methods."""

-# import traceback
-# import warnings
-# import sys
-# import os
-
import pytest
import multidms
import numpy as np
@@ -15,7 +10,6 @@

import multidms.utils

-# TODO test non numeric sites?
TEST_FUNC_SCORES = pd.read_csv(
StringIO(
"""
@@ -34,8 +28,6 @@
)
)

-# TODO figure out correct way to setup the data
-# using pytest fixtures.
data = multidms.Data(
TEST_FUNC_SCORES,
alphabet=multidms.AAS_WITHSTOP,
@@ -652,6 +644,45 @@ def test_fit_models():
assert list(tall_combined.index.names) == ["scale_coeff_lasso_shift"]


+def test_ModelCollection_mut_param_dataset_correlation():
+"""
+Test that the correlation between the mutational
+parameter estimates across conditions is correct
+by correlating two deterministic model fits from identical
+datasets, meaning they should have a correlation of 1.0.
+"""
+data_rep1 = multidms.Data(
+TEST_FUNC_SCORES,
+alphabet=multidms.AAS_WITHSTOP,
+reference="a",
+assert_site_integrity=False,
+name="rep1",
+)
+
+data_rep2 = multidms.Data(
+TEST_FUNC_SCORES,
+alphabet=multidms.AAS_WITHSTOP,
+reference="a",
+assert_site_integrity=False,
+name="rep2",
+)
+
+params = {
+"dataset": [data_rep1, data_rep2],
+"maxiter": [1],
+"scale_coeff_lasso_shift": [0.0],
+}
+
+_, _, fit_models_df = multidms.model_collection.fit_models(
+params,
+n_threads=-1,
+)
+mc = multidms.model_collection.ModelCollection(fit_models_df)
+chart, data = mc.mut_param_dataset_correlation(return_data=True)
+
+assert np.all(data["correlation"] == 1.0)
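The assertion holds because the two fits are deterministic and see identical data, so each mutation's parameter estimates agree exactly between rep1 and rep2, and a vector correlated with itself has a correlation of exactly 1.0. A toy sketch of that last step:

import numpy as np

a = np.array([0.1, -0.2, 0.5])  # stand-in for one dataset's estimates
assert np.isclose(np.corrcoef(a, a)[0, 1], 1.0)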


def test_ModelCollection_charts():
"""
Test fitting two different models in