Skip to content

Commit

Permalink
Implements RaSP for more than one wildtype
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelgondu committed Oct 25, 2023
1 parent 2becc26 commit 9473941
Show file tree
Hide file tree
Showing 11 changed files with 12,154 additions and 60 deletions.
1,953 changes: 1,953 additions & 0 deletions examples/comparing_rasp_and_foldx/example_pdbs/1uis.pdb

Large diffs are not rendered by default.

1,778 changes: 1,778 additions & 0 deletions examples/comparing_rasp_and_foldx/example_pdbs/1uis_A_Repair.pdb

Large diffs are not rendered by default.

2,081 changes: 2,081 additions & 0 deletions examples/comparing_rasp_and_foldx/example_pdbs/2vae.pdb

Large diffs are not rendered by default.

1,757 changes: 1,757 additions & 0 deletions examples/comparing_rasp_and_foldx/example_pdbs/2vae_A_Repair.pdb

Large diffs are not rendered by default.

2,619 changes: 2,619 additions & 0 deletions examples/comparing_rasp_and_foldx/example_pdbs/3ned.pdb

Large diffs are not rendered by default.

1,830 changes: 1,830 additions & 0 deletions examples/comparing_rasp_and_foldx/example_pdbs/3ned_A_Repair.pdb

Large diffs are not rendered by default.

32 changes: 32 additions & 0 deletions examples/comparing_rasp_and_foldx/simple_rasp_foldx_comparison.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""
In this example, we create both a RaSP and a FoldX objective function
and we compare their predictions of stability.
"""

from pathlib import Path

from poli import objective_factory

THIS_DIR = Path(__file__).parent.resolve()

if __name__ == "__main__":
    # Scan the example directory once. The result is sorted because the
    # order in which a directory listing is returned is filesystem-dependent,
    # and the order of the wildtypes should be reproducible between runs.
    example_pdbs_dir = THIS_DIR / "example_pdbs"
    all_pdb_paths = sorted(example_pdbs_dir.glob("*.pdb"))

    # The "*_Repair.pdb" files are the ones set aside for FoldX; every other
    # PDB file in the directory is handed to RaSP.
    wildtype_pdb_paths_for_foldx = [
        path_ for path_ in all_pdb_paths if path_.name.endswith("_Repair.pdb")
    ]
    wildtype_pdb_paths_for_rasp = [
        path_ for path_ in all_pdb_paths if "_Repair" not in path_.name
    ]

    # Fail early with a readable message instead of passing an empty list
    # of wildtypes to the objective factory.
    if not wildtype_pdb_paths_for_rasp:
        raise FileNotFoundError(
            f"No wildtype PDB files for RaSP were found in {example_pdbs_dir}"
        )

    # NOTE(review): the FoldX half of the comparison is disabled in this
    # example. The snippet is kept for reference, now pointing at the
    # variable that actually exists (wildtype_pdb_paths_for_foldx).
    #
    # _, f_foldx, x0, y0, _ = objective_factory.create(
    #     name="foldx_stability",
    #     wildtype_pdb_path=wildtype_pdb_paths_for_foldx,
    #     batch_size=1,
    # )
    #
    # print(f_foldx(x0))

    # Create the RaSP objective from all the wildtypes and evaluate it
    # on the initial input x0 returned by the factory.
    _, f_rasp, x0, y0, _ = objective_factory.create(
        name="rasp",
        wildtype_pdb_path=wildtype_pdb_paths_for_rasp,
    )

    print(f_rasp(x0))
50 changes: 38 additions & 12 deletions src/poli/core/util/proteins/rasp/rasp_interface.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
"""
This module takes and adapts RaSP's original implementation
(which can be found at [1]), and writes an interface that
handles the preprocessing and inference steps.
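[1] https://github.com/KULL-Centre/_2022_ML-ddG-Blaabjerg
    (RaSP reference implementation; TODO: confirm this link.)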
"""

from typing import List
from pathlib import Path
import os, stat
Expand Down Expand Up @@ -380,7 +388,9 @@ def create_df_structure(
# we avoid the above two-for-loops. But they're
# practically instantaneous, so it's not a big deal.
# (O(nm) doesn't matter all that much when n < 1000,
# m is always 20, and the cost of each operation
# is negligible.)
df_structure["mutant_residue_string"] = [""] * len(df_structure)
if mutant_residue_strings is not None:
# Compute the mutations associated to all strings in
# mutant_residue_strings.
Expand All @@ -397,11 +407,21 @@ def create_df_structure(
position_in_chain = res_info.iloc[0]["pos"]
mutant_residue_as_single_character = wildtype_residue_string[0]

mutations_in_rasp_format.append(
mutation_in_rasp_format = (
original_residue_as_single_character
+ f"{position_in_chain}"
+ mutant_residue_as_single_character
)

mutations_in_rasp_format.append(mutation_in_rasp_format)

mask_for_this_mutation = df_structure["variant"].str.startswith(
mutation_in_rasp_format
)

df_structure.loc[
mask_for_this_mutation, "mutant_residue_string"
] = mutant_residue_string
continue

edits_ = edits_between_strings(
Expand All @@ -412,30 +432,36 @@ def create_df_structure(
original_residue_as_single_character = wildtype_residue_string[i]
position_in_chain = res_info.iloc[i]["pos"]
mutant_residue_as_single_character = mutant_residue_string[i]

mutations_in_rasp_format.append(
mutation_in_rasp_format = (
original_residue_as_single_character
+ f"{position_in_chain}"
+ mutant_residue_as_single_character
)

mutations_in_rasp_format.append(mutation_in_rasp_format)

# Add the mutant residue string to the dataframe.
mask_for_this_mutation = df_structure["variant"].str.startswith(
mutation_in_rasp_format
)

df_structure.loc[
mask_for_this_mutation, "mutant_residue_string"
] = mutant_residue_string

# Filter df_structure to only contain the mutations in
# mutations_in_rasp_format. (i.e. we need to focus on
# only some of the positions, which are in the middle
# of each string in mutations_in_rasp_format).
df_structure = df_structure[
df_structure["variant"].str.startswith(
tuple(
set(
[
mutation_in_rasp_format
for mutation_in_rasp_format in mutations_in_rasp_format
]
)
)
tuple(set(mutations_in_rasp_format))
)
]

# We should attach the original mutant string to each
# row in df_structure.

df_structure.drop(columns="index", inplace=True)

# Load PDB amino acid frequencies used to approximate unfolded states
Expand Down
8 changes: 8 additions & 0 deletions src/poli/objective_repository/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,11 @@
AVAILABLE_PROBLEM_FACTORIES["gfp_select"] = GFPSelectionProblemFactory
except (ImportError, FileNotFoundError):
pass


# Optional registration of the RaSP problem factory: when importing it
# raises ImportError or FileNotFoundError, the "rasp" entry is simply
# left out of AVAILABLE_PROBLEM_FACTORIES.
try:
    from .rasp.register import RaspProblemFactory
except (ImportError, FileNotFoundError):
    pass
else:
    AVAILABLE_PROBLEM_FACTORIES["rasp"] = RaspProblemFactory
2 changes: 1 addition & 1 deletion src/poli/objective_repository/rasp/environment.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: poli__tdc
name: poli__rasp
channels:
- omnia
- conda-forge
Expand Down
104 changes: 57 additions & 47 deletions src/poli/objective_repository/rasp/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from pathlib import Path
from uuid import uuid4
from time import time
from collections import defaultdict

from poli.core.abstract_black_box import AbstractBlackBox
from poli.core.abstract_problem_factory import AbstractProblemFactory
Expand Down Expand Up @@ -225,7 +226,10 @@ def _black_box(self, x, context=None):
# sequences in x. For this, we need to compute the
# Hamming distance between each of the sequences in x
# and each of the wildtypes in self.wildtype_residue_strings.
closest_wildtypes = []

# closest_wildtypes will be a dictionary
# of the form {wildtype_path: List[str] of mutations}
closest_wildtypes = defaultdict(list)
mutant_residue_strings = []
for x_i in x:
# Assuming x_i is an array of strings
Expand All @@ -242,30 +246,57 @@ def _black_box(self, x, context=None):
if hamming_distance > 1:
raise ValueError("RaSP is only able to simulate single mutations.")

closest_wildtypes.append(closest_wildtype_pdb_file)
closest_wildtypes[closest_wildtype_pdb_file].append(mutant_residue_string)
mutant_residue_strings.append(mutant_residue_string)

# STEP 2:
# Loading the models in preparation for inference
cavity_model_net, ds_model_net = load_cavity_and_downstream_models()
dataset_key = "predictions"

# STEP 2 and 3:
# Creating the dataframe with the relevant mutations
df_structure = self.rasp_interface.create_df_structure(
# PER wildtype pdb file.

# We will store the results in a dictionary
# of the form {mutant_string: score}.
results = {}
for (
closest_wildtype_pdb_file,
mutant_residue_strings=mutant_residue_strings,
)
mutant_residue_strings_for_wildtype,
) in closest_wildtypes.items():
df_structure = self.rasp_interface.create_df_structure(
closest_wildtype_pdb_file,
mutant_residue_strings=mutant_residue_strings_for_wildtype,
)

# STEP 3:
# Predicting
cavity_model_net, ds_model_net = load_cavity_and_downstream_models()
dataset_key = "predictions"
df_ml = self.rasp_interface.predict(
cavity_model_net,
ds_model_net,
df_structure,
dataset_key,
RASP_NUM_ENSEMBLE,
RASP_DEVICE,
)
# STEP 3:
# Predicting
df_ml = self.rasp_interface.predict(
cavity_model_net,
ds_model_net,
df_structure,
dataset_key,
RASP_NUM_ENSEMBLE,
RASP_DEVICE,
)

return df_ml["score_ml"].values.reshape(x.shape[0], 1)
for (
mutant_residue_string_for_wildtype
) in mutant_residue_strings_for_wildtype:
results[mutant_residue_string_for_wildtype] = df_ml["score_ml"][
df_ml["mutant_residue_string"] == mutant_residue_string_for_wildtype
].values

# To reconstruct the final score, we rely
# on mutant_residue_strings, which is a list
# of strings IN THE SAME ORDER as the input
# vector x.
return np.array(
[
results[mutant_residue_string]
for mutant_residue_string in mutant_residue_strings
]
).reshape(-1, 1)


class RaspProblemFactory(AbstractProblemFactory):
Expand All @@ -276,7 +307,7 @@ def get_setup_information(self) -> ProblemSetupInformation:
alphabet = AMINO_ACIDS

return ProblemSetupInformation(
name="foldx_stability_and_sasa",
name="rasp",
max_sequence_length=np.inf,
alphabet=alphabet,
aligned=False,
Expand Down Expand Up @@ -349,31 +380,10 @@ def create(


if __name__ == "__main__":
THIS_DIR = Path(__file__).parent.resolve()
wildtype_pdb_path = THIS_DIR / "101m.pdb"
chain_to_keep = "A"

cavity_model_net, ds_model_net = load_cavity_and_downstream_models()
rasp_interface = RaspInterface(THIS_DIR / "tmp")

rasp_interface.raw_pdb_to_unique_chain(wildtype_pdb_path, chain_to_keep)
rasp_interface.unique_chain_to_clean_pdb(wildtype_pdb_path)
rasp_interface.cleaned_to_parsed_pdb(wildtype_pdb_path)

df_structure = rasp_interface.create_df_structure(wildtype_pdb_path)
print(df_structure.head())

# Loading the models

# Predicting
dataset_key = "predictions"
df_ml = rasp_interface.predict(
cavity_model_net,
ds_model_net,
df_structure,
dataset_key,
RASP_NUM_ENSEMBLE,
RASP_DEVICE,
)
from poli.core.registry import register_problem

print(df_ml.head())
rasp_problem_factory = RaspProblemFactory()
register_problem(
rasp_problem_factory,
conda_environment_name="poli__rasp",
)

0 comments on commit 9473941

Please sign in to comment.