Skip to content

Commit

Permalink
Adds additivity to RaSP
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelgondu committed Aug 20, 2024
1 parent bc8a099 commit c2e6ab9
Show file tree
Hide file tree
Showing 3 changed files with 129 additions and 21 deletions.
76 changes: 62 additions & 14 deletions src/poli/objective_repository/rasp/isolated_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from uuid import uuid4

import numpy as np
import torch

from poli.core.abstract_isolated_function import AbstractIsolatedFunction
from poli.core.util.proteins.mutations import find_closest_wildtype_pdb_file_to_mutant
Expand All @@ -37,7 +38,10 @@
)

# Number of ensemble members; the name suggests it sizes the RaSP model
# ensemble, but its use is not visible in this excerpt.
RASP_NUM_ENSEMBLE = 10

# Prefer the GPU when torch reports a usable CUDA device, otherwise fall back
# to the CPU. A single conditional expression replaces the stale unconditional
# "cpu" assignment, which was immediately overwritten by the if/else.
RASP_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

THIS_DIR = Path(__file__).parent.resolve()
HOME_DIR = THIS_DIR.home()
Expand Down Expand Up @@ -77,6 +81,11 @@ class RaspIsolatedLogic(AbstractIsolatedFunction):
----------
wildtype_pdb_path : Union[Path, List[Path]]
The path(s) to the wildtype PDB file(s), by default None.
additive : bool, optional
Whether we treat multiple mutations as additive, by default False.
If you are interested in running this black box with multiple
mutations, you should set this to True. Otherwise, it will
raise an error if you pass a sequence with more than one mutation.
chains_to_keep : List[str], optional
The chains to keep in the PDB file(s), by default we
keep the chain "A" for all pdbs passed.
Expand Down Expand Up @@ -115,6 +124,7 @@ class RaspIsolatedLogic(AbstractIsolatedFunction):
def __init__(
self,
wildtype_pdb_path: Union[Path, List[Path]],
additive: bool = False,
chains_to_keep: List[str] = None,
experiment_id: str = None,
tmp_folder: Path = None,
Expand All @@ -126,6 +136,11 @@ def __init__(
-----------
wildtype_pdb_path : Union[Path, List[Path]]
The path(s) to the wildtype PDB file(s).
additive : bool, optional
Whether we treat multiple mutations as additive, by default False.
If you are interested in running this black box with multiple
mutations, you should set this to True. Otherwise, it will
raise an error if you pass a sequence with more than one mutation.
chains_to_keep : List[str], optional
The chains to keep in the PDB file(s), by default we
keep the chain "A" for all pdbs passed.
Expand Down Expand Up @@ -234,6 +249,7 @@ def __init__(
x0_pre_array = [x + [""] * (max_len - len(x)) for x in x0_pre_array]

self.x0 = np.array(x0_pre_array)
self.additive = additive

def _clean_wildtype_pdb_files(self):
"""
Expand Down Expand Up @@ -302,10 +318,9 @@ def __call__(self, x, context=None):
of the longest sequence in the batch, and b is the batch size.
We process it by concatenating the array into a single string,
where we assume the padding to be an empty string (if there was any).
Each of these x_i's will be matched to the wildtype in self. wildtype_residue_strings with the lowest Hamming distance.
Each of these x_i's will be matched to the wildtype in
self.wildtype_residue_strings with the lowest Hamming distance.
"""
# Creating an interface for this experiment id

# We need to find the closest wildtype to each of the
# sequences in x. For this, we need to compute the
# Hamming distance between each of the sequences in x
Expand All @@ -315,6 +330,7 @@ def __call__(self, x, context=None):
# of the form {wildtype_path: List[str] of mutations}
closest_wildtypes = defaultdict(list)
mutant_residue_strings = []
mutant_residue_to_hamming_distances = dict()
for x_i in x:
# Assuming x_i is an array of strings
mutant_residue_string = "".join(x_i)
Expand All @@ -327,11 +343,14 @@ def __call__(self, x, context=None):
return_hamming_distance=True,
)

if hamming_distance > 1:
if hamming_distance > 1 and not self.additive:
raise ValueError("RaSP is only able to simulate single mutations.")

closest_wildtypes[closest_wildtype_pdb_file].append(mutant_residue_string)
mutant_residue_strings.append(mutant_residue_string)
mutant_residue_to_hamming_distances[mutant_residue_string] = (
hamming_distance
)

# Loading the models in preparation for inference
cavity_model_net, ds_model_net = load_cavity_and_downstream_models()
Expand Down Expand Up @@ -367,9 +386,35 @@ def __call__(self, x, context=None):
for (
mutant_residue_string_for_wildtype
) in mutant_residue_strings_for_wildtype:
results[mutant_residue_string_for_wildtype] = df_ml["score_ml"][
sliced_values_for_mutant = df_ml["score_ml"][
df_ml["mutant_residue_string"] == mutant_residue_string_for_wildtype
].values
results[mutant_residue_string_for_wildtype] = sliced_values_for_mutant

if self.additive:
assert (
sliced_values_for_mutant.shape[0]
== mutant_residue_to_hamming_distances[
mutant_residue_string_for_wildtype
]
), (
" The number of predictions made for this mutant"
" is not equal to the Hamming distance between the"
" mutant and the wildtype.\n"
"This is an internal error in `poli`. Please report"
" this issue by referencing the RaSP problem.\n"
" https://github.com/MachineLearningLifeScience/poli/issues"
)

# If we are treating the mutations as additive
# the sliced values for mutant will be an array
# of length equal to the Hamming distance between
# the mutant and the wildtype. These are the individual
# mutation predictions made by RaSP. We need to sum
# them up to get the final score.
results[mutant_residue_string_for_wildtype] = np.sum(
sliced_values_for_mutant
)

# To reconstruct the final score, we rely
# on mutant_residue_strings, which is a list
Expand All @@ -384,11 +429,14 @@ def __call__(self, x, context=None):


if __name__ == "__main__":
from poli.core.registry import register_isolated_function

register_isolated_function(
RaspIsolatedLogic,
name="rasp__isolated",
conda_environment_name="poli__rasp",
force=True,
)
# from poli.core.registry import register_isolated_function

# register_isolated_function(
# RaspIsolatedLogic,
# name="rasp__isolated",
# conda_environment_name="poli__rasp",
# force=True,
# )

PDB_DIR = THIS_DIR.parent.parent / "tests" / "registry" / "proteins" / "3ned.pdb"
f = RaspIsolatedLogic(wildtype_pdb_path=[PDB_DIR])
34 changes: 27 additions & 7 deletions src/poli/objective_repository/rasp/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ class RaspBlackBox(AbstractBlackBox):
----------
wildtype_pdb_path : Union[Path, List[Path]]
The path(s) to the wildtype PDB file(s), by default None.
additive : bool, optional
Whether we treat multiple mutations as additive, by default False.
If you are interested in running this black box with multiple
mutations, you should set this to True. Otherwise, it will
raise an error if you pass a sequence with more than one mutation.
chains_to_keep : List[str], optional
The chains to keep in the PDB file(s), by default we
keep the chain "A" for all pdbs passed.
Expand Down Expand Up @@ -83,6 +88,7 @@ class RaspBlackBox(AbstractBlackBox):
def __init__(
self,
wildtype_pdb_path: Union[Path, List[Path]],
additive: bool = False,
chains_to_keep: List[str] = None,
experiment_id: str = None,
tmp_folder: Path = None,
Expand All @@ -99,6 +105,18 @@ def __init__(
-----------
wildtype_pdb_path : Union[Path, List[Path]]
The path(s) to the wildtype PDB file(s).
additive : bool, optional
Whether we treat multiple mutations as additive, by default False.
If you are interested in running this black box with multiple
mutations, you should set this to True. Otherwise, it will
raise an error if you pass a sequence with more than one mutation.
chains_to_keep : List[str], optional
The chains to keep in the PDB file(s), by default we
keep the chain "A" for all pdbs passed.
experiment_id : str, optional
The experiment ID, by default None.
tmp_folder : Path, optional
The temporary folder path, by default None.
batch_size : int, optional
The batch size for parallel evaluation, by default None.
parallelize : bool, optional
Expand All @@ -107,13 +125,6 @@ def __init__(
The number of workers for parallel evaluation, by default None.
evaluation_budget : int, optional
The evaluation budget, by default float("inf").
chains_to_keep : List[str], optional
The chains to keep in the PDB file(s), by default we
keep the chain "A" for all pdbs passed.
experiment_id : str, optional
The experiment ID, by default None.
tmp_folder : Path, optional
The temporary folder path, by default None.
Raises:
-------
Expand Down Expand Up @@ -147,12 +158,14 @@ def __init__(
self.chains_to_keep = chains_to_keep
self.experiment_id = experiment_id
self.tmp_folder = tmp_folder
self.additive = additive
self.inner_function = get_inner_function(
isolated_function_name="rasp__isolated",
class_name="RaspIsolatedLogic",
module_to_import="poli.objective_repository.rasp.isolated_function",
force_isolation=self.force_isolation,
wildtype_pdb_path=self.wildtype_pdb_path,
additive=self.additive,
chains_to_keep=self.chains_to_keep,
experiment_id=self.experiment_id,
tmp_folder=self.tmp_folder,
Expand Down Expand Up @@ -207,6 +220,7 @@ class RaspProblemFactory(AbstractProblemFactory):
def create(
self,
wildtype_pdb_path: Union[Path, List[Path]],
additive: bool = False,
chains_to_keep: List[str] = None,
experiment_id: str = None,
tmp_folder: Path = None,
Expand All @@ -225,6 +239,11 @@ def create(
----------
wildtype_pdb_path : Union[Path, List[Path]]
The path(s) to the wildtype PDB file(s).
additive : bool, optional
Whether we treat multiple mutations as additive, by default False.
If you are interested in running this black box with multiple
mutations, you should set this to True. Otherwise, it will
raise an error if you pass a sequence with more than one mutation.
experiment_id : str, optional
The experiment ID, by default None.
tmp_folder : Path, optional
Expand Down Expand Up @@ -276,6 +295,7 @@ def create(

f = RaspBlackBox(
wildtype_pdb_path=wildtype_pdb_path,
additive=additive,
chains_to_keep=chains_to_keep,
experiment_id=experiment_id,
tmp_folder=tmp_folder,
Expand Down
40 changes: 40 additions & 0 deletions src/poli/tests/registry/proteins/test_rasp.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def test_rasp_on_3ned_against_notebooks_results_isolated():
problem = objective_factory.create(
name="rasp",
wildtype_pdb_path=THIS_DIR / "3ned.pdb",
force_isolation=True,
)
f, x0 = problem.black_box, problem.x0

Expand Down Expand Up @@ -80,3 +81,42 @@ def test_rasp_on_3ned_against_notebooks_results_isolated():
assert np.isclose(y[0], 0.0365, atol=1e-4)
assert np.isclose(y[1], -0.07091, atol=1e-4)
assert np.isclose(y[2], -0.283559, atol=1e-4)


@pytest.mark.poli__rasp
def test_rasp_using_additive_flag_on_two_mutations():
    """Check that additive RaSP scores a double mutant as the sum of its parts.

    A sequence with substitutions at positions 0 and 1 is evaluated with
    ``additive=True`` and compared against the sum of the scores of the two
    corresponding single-substitution sequences.
    """
    import torch

    # For us to match what the notebook says, we have
    # to run at double precision.
    torch.set_default_dtype(torch.float64)

    # If the previous import was successful, we can
    # create a RaSP problem:
    problem = objective_factory.create(
        name="rasp",
        wildtype_pdb_path=THIS_DIR / "3ned.pdb",
        additive=True,
    )
    f, x0 = problem.black_box, problem.x0

    wildtype_sequence = "".join(x0[0])

    # One sequence carrying both substitutions at once ...
    # NOTE(review): this assumes the wildtype does not already have "A" at
    # position 0 or "R" at position 1 -- confirm against 3ned.pdb.
    one_mutant_with_two_mutations = [
        "AR" + wildtype_sequence[2:],
    ]
    # ... and the same two substitutions, one per sequence.
    two_mutations = [
        "A" + wildtype_sequence[1:],
        wildtype_sequence[:1] + "R" + wildtype_sequence[2:],
    ]

    x = np.array([list(mutation) for mutation in one_mutant_with_two_mutations])
    y = f(x)

    x1 = np.array([list(mutation) for mutation in two_mutations])
    y1 = f(x1)

    # The two sides are accumulated through different code paths, so compare
    # with a floating-point tolerance rather than exact equality.
    assert np.allclose(y, y1.sum())


# Allows running the additive RaSP test directly as a script, without
# going through the pytest runner.
if __name__ == "__main__":
    test_rasp_using_additive_flag_on_two_mutations()

0 comments on commit c2e6ab9

Please sign in to comment.