Skip to content

Commit

Permalink
removed int cast and added test (#127)
Browse files Browse the repository at this point in the history
# closes #127
  • Loading branch information
jgallowa07 authored Dec 6, 2023
1 parent ee8e688 commit 8e44d35
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 2 deletions.
2 changes: 1 addition & 1 deletion multidms/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1282,7 +1282,7 @@ def mut_param_heatmap(
continue
addtl_tooltip_stats.append(f"wildtype_{condition}")
muts_df[f"wildtype_{condition}"] = muts_df.site.apply(
lambda site: site_map.loc[int(site), condition]
lambda site: site_map.loc[site, condition]
)

# melt conditions and stats cols, beta is already "tall"
Expand Down
2 changes: 1 addition & 1 deletion multidms/model_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,7 +660,7 @@ def mut_param_heatmap(
continue
addtl_tooltip_stats.append(f"wildtype_{condition}")
muts_df[f"wildtype_{condition}"] = muts_df.site.apply(
lambda site: self.site_map_union.loc[int(site), condition]
lambda site: self.site_map_union.loc[site, condition]
)

# melt conditions and stats cols, beta is already "tall"
Expand Down
26 changes: 26 additions & 0 deletions tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,32 @@ def test_fit_models():
assert list(tall_combined.index.names) == ["scale_coeff_lasso_shift"]


def test_ModelCollection_charts():
"""
Test fitting two different models in
parallel using multidms.model_collection.fit_models
"""
data = multidms.Data(
TEST_FUNC_SCORES,
alphabet=multidms.AAS_WITHSTOP,
reference="a",
assert_site_integrity=False,
)
params = {
"dataset": [data],
"iterations_per_step": [2],
"scale_coeff_lasso_shift": [0.0, 1e-5],
}
_, _, fit_models_df = multidms.model_collection.fit_models(
params,
n_threads=min(os.cpu_count(), 4),
)
mc = multidms.model_collection.ModelCollection(fit_models_df)

mc.mut_param_heatmap(query="scale_coeff_lasso_shift == 0.0")
mc.shift_sparsity()


def test_data_names():
"""
Test that the default data names are correctly
Expand Down

0 comments on commit 8e44d35

Please sign in to comment.