Adds additivity to RaSP (#246)
* Adds additivity to RaSP

* Adds a workflow for testing RaSP

* Fixes matplotlib version for rasp

* Registers rasp and removes testing code in isolated function

* Rewords a test to work without running on rasp env

* Updates the rasp action description

* Cleans up test comments and makes an error more verbose

* Adds link to RaSP action in readme

* Bumps version
miguelgondu authored Aug 20, 2024
1 parent bc8a099 commit c2e52da
Showing 9 changed files with 152 additions and 29 deletions.
29 changes: 29 additions & 0 deletions .github/workflows/python-tox-testing-rasp-env.yml
@@ -0,0 +1,29 @@
name: poli rasp (conda, py3.9)

on:
  push:
  schedule:
    - cron: '0 0 * * 0'

jobs:
  build-linux:
    runs-on: ubuntu-latest
    strategy:
      max-parallel: 5

    steps:
    - uses: actions/checkout@v3
    - name: Set up Python 3.9
      uses: actions/setup-python@v3
      with:
        python-version: '3.9'
    - name: Add conda to system path
      run: |
        # $CONDA is an environment variable pointing to the root of the miniconda directory
        echo $CONDA/bin >> $GITHUB_PATH
    - name: Install dependencies
      run: |
        python -m pip install tox
    - name: Test rasp-related black boxes with tox and pytest
      run: |
        tox -c tox.ini -e poli-rasp-py39
2 changes: 1 addition & 1 deletion README.MD
@@ -13,8 +13,8 @@
| [Ehrlich functions](https://machinelearninglifescience.github.io/poli-docs/using_poli/objective_repository/ehrlich_functions.html) | [(Stanton et al. 2024)](https://arxiv.org/abs/2407.00236) | [![poli base (dev, conda, python 3.9)](https://github.com/MachineLearningLifeScience/poli/actions/workflows/python-tox-testing-base.yml/badge.svg)](https://github.com/MachineLearningLifeScience/poli/actions/workflows/python-tox-testing-base.yml)
| [PMO/GuacaMol benchmark](https://machinelearninglifescience.github.io/poli-docs/#small-molecules) | [(Brown et al. 2019)](https://arxiv.org/abs/1811.09621), [(Gao et al. 2022)](https://openreview.net/forum?id=yCZRdI0Y7G), [(Huang et al. 2021)](https://openreview.net/pdf?id=8nvgnORnoWr) | [![poli tdc (dev, conda, python 3.9)](https://github.com/MachineLearningLifeScience/poli/actions/workflows/python-tox-testing-tdc-env.yml/badge.svg)](https://github.com/MachineLearningLifeScience/poli/actions/workflows/python-tox-testing-tdc-env.yml)
| [Dockstring](https://machinelearninglifescience.github.io/poli-docs/using_poli/objective_repository/dockstring.html) | [(García-Ortegón et al. 2022)](https://pubs.acs.org/doi/full/10.1021/acs.jcim.1c01334) | [![poli dockstring (dev, conda, python 3.9)](https://github.com/MachineLearningLifeScience/poli/actions/workflows/python-tox-testing-dockstring-env.yml/badge.svg)](https://github.com/MachineLearningLifeScience/poli/actions/workflows/python-tox-testing-dockstring-env.yml)
| [RaSP](https://machinelearninglifescience.github.io/poli-docs/using_poli/objective_repository/RaSP.html) | [(Blaabjerg et al. 2023)](https://elifesciences.org/articles/82593) | [![poli rasp (conda, py3.9)](https://github.com/MachineLearningLifeScience/poli/actions/workflows/python-tox-testing-rasp-env.yml/badge.svg)](https://github.com/MachineLearningLifeScience/poli/actions/workflows/python-tox-testing-rasp-env.yml)
| [FoldX stability and SASA](https://machinelearninglifescience.github.io/poli-docs/#proteins) | [(Schymkowitz et al. 2005)](https://academic.oup.com/nar/article/33/suppl_2/W382/2505499?login=true) | - |
| [RaSP](https://machinelearninglifescience.github.io/poli-docs/using_poli/objective_repository/RaSP.html) | [(Blaabjerg et al. 2023)](https://elifesciences.org/articles/82593) | -

## Features
- 🔲 **isolation** of black box function calls inside conda environments. Don't worry about clashes w. black box requirements, poli will create the relevant conda environments for you.
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "poli"
version = "1.0.0.dev3"
version = "1.0.0.dev4"
description = "poli, a library of discrete objective functions"
readme = "README.md"
authors = [{name="Miguel González-Duque", email="[email protected]"}, {name="Simon Bartels"}]
@@ -53,7 +53,7 @@ profile = "black"
exclude = ["src/poli/core/util/proteins/rasp/inner_rasp", "src/poli/objective_repository/gfp_cbas"]

[tool.bumpversion]
current_version = "1.0.0.dev3"
current_version = "1.0.0.dev4"
parse = """(?x)
(?P<major>0|[1-9]\\d*)\\.
(?P<minor>0|[1-9]\\d*)\\.
2 changes: 1 addition & 1 deletion setup.cfg
@@ -1,6 +1,6 @@
[metadata]
name = poli
version = "1.0.0.dev3"
version = "1.0.0.dev4"
author_email = [email protected]
description = Protein Objectives Library
long_description = file: README.md
2 changes: 1 addition & 1 deletion src/poli/__init__.py
@@ -1,6 +1,6 @@
"""poli, a library for discrete black-box objective functions."""

__version__ = "1.0.0.dev3"
__version__ = "1.0.0.dev4"
from .core.util.isolation.instancing import instance_function_as_isolated_process

# from .core import get_problems
2 changes: 1 addition & 1 deletion src/poli/objective_repository/rasp/environment.yml
@@ -20,7 +20,7 @@ dependencies:
- biopython==1.72
- torch
- pandas
- matplotlib
- matplotlib==3.8.1
- pdb-tools
- ptitprince
- scikit-learn
64 changes: 57 additions & 7 deletions src/poli/objective_repository/rasp/isolated_function.py
@@ -24,6 +24,7 @@
from uuid import uuid4

import numpy as np
import torch

from poli.core.abstract_isolated_function import AbstractIsolatedFunction
from poli.core.util.proteins.mutations import find_closest_wildtype_pdb_file_to_mutant
@@ -37,7 +38,10 @@
)

RASP_NUM_ENSEMBLE = 10
RASP_DEVICE = "cpu"
if torch.cuda.is_available():
    RASP_DEVICE = "cuda"
else:
    RASP_DEVICE = "cpu"

THIS_DIR = Path(__file__).parent.resolve()
HOME_DIR = THIS_DIR.home()
@@ -77,6 +81,11 @@ class RaspIsolatedLogic(AbstractIsolatedFunction):
----------
wildtype_pdb_path : Union[Path, List[Path]]
The path(s) to the wildtype PDB file(s), by default None.
additive : bool, optional
Whether we treat multiple mutations as additive, by default False.
If you are interested in running this black box with multiple
mutations, you should set this to True. Otherwise, it will
raise an error if you pass a sequence with more than one mutation.
chains_to_keep : List[str], optional
The chains to keep in the PDB file(s), by default we
keep the chain "A" for all pdbs passed.
@@ -115,6 +124,7 @@ class RaspIsolatedLogic(AbstractIsolatedFunction):
def __init__(
self,
wildtype_pdb_path: Union[Path, List[Path]],
additive: bool = False,
chains_to_keep: List[str] = None,
experiment_id: str = None,
tmp_folder: Path = None,
@@ -126,6 +136,11 @@ def __init__(
-----------
wildtype_pdb_path : Union[Path, List[Path]]
The path(s) to the wildtype PDB file(s).
additive : bool, optional
Whether we treat multiple mutations as additive, by default False.
If you are interested in running this black box with multiple
mutations, you should set this to True. Otherwise, it will
raise an error if you pass a sequence with more than one mutation.
chains_to_keep : List[str], optional
The chains to keep in the PDB file(s), by default we
keep the chain "A" for all pdbs passed.
@@ -234,6 +249,7 @@ def __init__(
x0_pre_array = [x + [""] * (max_len - len(x)) for x in x0_pre_array]

self.x0 = np.array(x0_pre_array)
self.additive = additive

def _clean_wildtype_pdb_files(self):
"""
@@ -302,10 +318,9 @@ def __call__(self, x, context=None):
of the longest sequence in the batch, and b is the batch size.
We process it by concatenating the array into a single string,
where we assume the padding to be an empty string (if there was any).
Each of these x_i's will be matched to the wildtype in self. wildtype_residue_strings with the lowest Hamming distance.
Each of these x_i's will be matched to the wildtype in
self.wildtype_residue_strings with the lowest Hamming distance.
"""
# Creating an interface for this experiment id

# We need to find the closest wildtype to each of the
# sequences in x. For this, we need to compute the
# Hamming distance between each of the sequences in x
@@ -315,6 +330,7 @@ def __call__(self, x, context=None):
# of the form {wildtype_path: List[str] of mutations}
closest_wildtypes = defaultdict(list)
mutant_residue_strings = []
mutant_residue_to_hamming_distances = dict()
for x_i in x:
# Assuming x_i is an array of strings
mutant_residue_string = "".join(x_i)
@@ -327,11 +343,19 @@ def __call__(self, x, context=None):
return_hamming_distance=True,
)

if hamming_distance > 1:
raise ValueError("RaSP is only able to simulate single mutations.")
if hamming_distance > 1 and not self.additive:
raise ValueError(
"RaSP is only able to simulate single mutations."
" If you want to simulate multiple mutations,"
" you should set additive=True in the create method"
" or in the black box of RaSP."
)

closest_wildtypes[closest_wildtype_pdb_file].append(mutant_residue_string)
mutant_residue_strings.append(mutant_residue_string)
mutant_residue_to_hamming_distances[mutant_residue_string] = (
hamming_distance
)

# Loading the models in preparation for inference
cavity_model_net, ds_model_net = load_cavity_and_downstream_models()
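
The single-mutation gate in this hunk can be illustrated without PDB files. The following is a minimal sketch on plain strings, assuming equal-length sequences. The helpers `hamming_distance`, `closest_wildtype`, and `check_mutant` are hypothetical names written for this note; they are not part of `poli`, whose own lookup goes through `find_closest_wildtype_pdb_file_to_mutant`.

```python
from typing import List, Tuple


def hamming_distance(a: str, b: str) -> int:
    """Counts the positions at which two equal-length sequences differ."""
    if len(a) != len(b):
        raise ValueError("Sequences must have the same length.")
    return sum(res_a != res_b for res_a, res_b in zip(a, b))


def closest_wildtype(mutant: str, wildtypes: List[str]) -> Tuple[str, int]:
    """Returns the wildtype closest to the mutant and its Hamming distance."""
    candidates = [wt for wt in wildtypes if len(wt) == len(mutant)]
    if not candidates:
        raise ValueError("No wildtype has the same length as the mutant.")
    best = min(candidates, key=lambda wt: hamming_distance(mutant, wt))
    return best, hamming_distance(mutant, best)


def check_mutant(mutant: str, wildtypes: List[str], additive: bool = False) -> int:
    """Mirrors the gate in the diff: several mutations require additive=True."""
    _, distance = closest_wildtype(mutant, wildtypes)
    if distance > 1 and not additive:
        raise ValueError(
            "Only single mutations are supported unless additive=True."
        )
    return distance
```

Under these assumptions, `check_mutant("ARC", ["MKC"])` raises a `ValueError`, while passing `additive=True` returns the distance so it can later be compared against the number of per-mutation predictions.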
@@ -367,9 +391,35 @@ def __call__(self, x, context=None):
for (
mutant_residue_string_for_wildtype
) in mutant_residue_strings_for_wildtype:
results[mutant_residue_string_for_wildtype] = df_ml["score_ml"][
sliced_values_for_mutant = df_ml["score_ml"][
df_ml["mutant_residue_string"] == mutant_residue_string_for_wildtype
].values
results[mutant_residue_string_for_wildtype] = sliced_values_for_mutant

if self.additive:
assert (
sliced_values_for_mutant.shape[0]
== mutant_residue_to_hamming_distances[
mutant_residue_string_for_wildtype
]
), (
" The number of predictions made for this mutant"
" is not equal to the Hamming distance between the"
" mutant and the wildtype.\n"
"This is an internal error in `poli`. Please report"
" this issue by referencing the RaSP problem.\n"
" https://github.com/MachineLearningLifeScience/poli/issues"
)

# If we are treating the mutations as additive
# the sliced values for mutant will be an array
# of length equal to the Hamming distance between
# the mutant and the wildtype. These are the individual
# mutation predictions made by RaSP. We need to sum
# them up to get the final score.
results[mutant_residue_string_for_wildtype] = np.sum(
sliced_values_for_mutant
)

# To reconstruct the final score, we rely
# on mutant_residue_strings, which is a list
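
The additive aggregation added in this file can be sketched on toy data. This is an illustration only: the arrays below are hypothetical stand-ins for the `mutant_residue_string` and `score_ml` columns of `df_ml`, the scores are made up rather than RaSP outputs, and `aggregate_scores` is a name invented for this note.

```python
import numpy as np

# Hypothetical stand-ins for two columns of the prediction table:
# one row per single-site substitution of each mutant.
mutant_column = np.array(["ARC", "ARC", "MRC"])
score_column = np.array([0.25, -0.5, 0.125])

# Hamming distance of each mutant to its closest wildtype (here "MKC").
hamming_distances = {"ARC": 2, "MRC": 1}


def aggregate_scores(mutants, scores, distances, additive):
    """Slices the scores per mutant and, if additive, sums them."""
    results = {}
    for mutant, distance in distances.items():
        sliced = scores[mutants == mutant]
        if additive:
            # One prediction is expected per mutated position.
            if sliced.shape[0] != distance:
                raise RuntimeError(
                    "Number of predictions differs from the Hamming distance."
                )
            results[mutant] = float(np.sum(sliced))
        else:
            results[mutant] = sliced
    return results


additive_scores = aggregate_scores(
    mutant_column, score_column, hamming_distances, additive=True
)
```

With this toy input the double mutant "ARC" receives the sum of its two single-substitution scores, and the single mutant "MRC" keeps its one score unchanged.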
34 changes: 27 additions & 7 deletions src/poli/objective_repository/rasp/register.py
@@ -39,6 +39,11 @@ class RaspBlackBox(AbstractBlackBox):
----------
wildtype_pdb_path : Union[Path, List[Path]]
The path(s) to the wildtype PDB file(s), by default None.
additive : bool, optional
Whether we treat multiple mutations as additive, by default False.
If you are interested in running this black box with multiple
mutations, you should set this to True. Otherwise, it will
raise an error if you pass a sequence with more than one mutation.
chains_to_keep : List[str], optional
The chains to keep in the PDB file(s), by default we
keep the chain "A" for all pdbs passed.
@@ -83,6 +88,7 @@ class RaspBlackBox(AbstractBlackBox):
def __init__(
self,
wildtype_pdb_path: Union[Path, List[Path]],
additive: bool = False,
chains_to_keep: List[str] = None,
experiment_id: str = None,
tmp_folder: Path = None,
@@ -99,6 +105,18 @@ def __init__(
-----------
wildtype_pdb_path : Union[Path, List[Path]]
The path(s) to the wildtype PDB file(s).
additive : bool, optional
Whether we treat multiple mutations as additive, by default False.
If you are interested in running this black box with multiple
mutations, you should set this to True. Otherwise, it will
raise an error if you pass a sequence with more than one mutation.
chains_to_keep : List[str], optional
The chains to keep in the PDB file(s), by default we
keep the chain "A" for all pdbs passed.
experiment_id : str, optional
The experiment ID, by default None.
tmp_folder : Path, optional
The temporary folder path, by default None.
batch_size : int, optional
The batch size for parallel evaluation, by default None.
parallelize : bool, optional
@@ -107,13 +125,6 @@ def __init__(
The number of workers for parallel evaluation, by default None.
evaluation_budget : int, optional
The evaluation budget, by default float("inf").
chains_to_keep : List[str], optional
The chains to keep in the PDB file(s), by default we
keep the chain "A" for all pdbs passed.
experiment_id : str, optional
The experiment ID, by default None.
tmp_folder : Path, optional
The temporary folder path, by default None.
Raises:
-------
@@ -147,12 +158,14 @@ def __init__(
self.chains_to_keep = chains_to_keep
self.experiment_id = experiment_id
self.tmp_folder = tmp_folder
self.additive = additive
self.inner_function = get_inner_function(
isolated_function_name="rasp__isolated",
class_name="RaspIsolatedLogic",
module_to_import="poli.objective_repository.rasp.isolated_function",
force_isolation=self.force_isolation,
wildtype_pdb_path=self.wildtype_pdb_path,
additive=self.additive,
chains_to_keep=self.chains_to_keep,
experiment_id=self.experiment_id,
tmp_folder=self.tmp_folder,
@@ -207,6 +220,7 @@ class RaspProblemFactory(AbstractProblemFactory):
def create(
self,
wildtype_pdb_path: Union[Path, List[Path]],
additive: bool = False,
chains_to_keep: List[str] = None,
experiment_id: str = None,
tmp_folder: Path = None,
@@ -225,6 +239,11 @@ def create(
----------
wildtype_pdb_path : Union[Path, List[Path]]
The path(s) to the wildtype PDB file(s).
additive: bool, optional
Whether we treat multiple mutations as additive, by default False.
If you are interested in running this black box with multiple
mutations, you should set this to True. Otherwise, it will
raise an error if you pass a sequence with more than one mutation.
experiment_id : str, optional
The experiment ID, by default None.
tmp_folder : Path, optional
@@ -276,6 +295,7 @@ def create(

f = RaspBlackBox(
wildtype_pdb_path=wildtype_pdb_path,
additive=additive,
chains_to_keep=chains_to_keep,
experiment_id=experiment_id,
tmp_folder=tmp_folder,
42 changes: 33 additions & 9 deletions src/poli/tests/registry/proteins/test_rasp.py
@@ -9,15 +9,7 @@


@pytest.mark.poli__rasp
def test_rasp_on_3ned_against_notebooks_results_on_rasp_env():
import torch

# For us to match what the notebook says, we have
# to run at double precision.
torch.set_default_dtype(torch.float64)

# If the previous import was successful, we can
# create a RaSP problem:
def test_rasp_on_3ned_against_notebooks_results():
problem = objective_factory.create(
name="rasp",
wildtype_pdb_path=THIS_DIR / "3ned.pdb",
@@ -53,6 +45,7 @@ def test_rasp_on_3ned_against_notebooks_results_isolated():
problem = objective_factory.create(
name="rasp",
wildtype_pdb_path=THIS_DIR / "3ned.pdb",
force_isolation=True,
)
f, x0 = problem.black_box, problem.x0

@@ -80,3 +73,34 @@ def test_rasp_on_3ned_against_notebooks_results_isolated():
assert np.isclose(y[0], 0.0365, atol=1e-4)
assert np.isclose(y[1], -0.07091, atol=1e-4)
assert np.isclose(y[2], -0.283559, atol=1e-4)


@pytest.mark.poli__rasp
def test_rasp_using_additive_flag_on_two_mutations():
    problem = objective_factory.create(
        name="rasp",
        wildtype_pdb_path=THIS_DIR / "3ned.pdb",
        additive=True,
    )
    f, x0 = problem.black_box, problem.x0

    wildtype_sequence = "".join(x0[0])
    one_mutant_with_two_mutations = [
        "AR" + wildtype_sequence[2:],
    ]
    two_mutations = [
        "A" + wildtype_sequence[1:],
        wildtype_sequence[:1] + "R" + wildtype_sequence[2:],
    ]

    x = np.array([list(mutation) for mutation in one_mutant_with_two_mutations])
    y = f(x)

    x1 = np.array([list(mutation) for mutation in two_mutations])
    y1 = f(x1)

    assert y == y1.sum()


if __name__ == "__main__":
    test_rasp_using_additive_flag_on_two_mutations()
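
A possible review note on the new test: it compares floating-point scores with `==`, whereas the earlier tests in this file use `np.isclose` with `atol=1e-4`. The sketch below uses made-up numbers, not RaSP outputs, to show how exact equality can fail on sums that differ only by rounding, and how a tolerance-based comparison behaves on the same values. Whether this matters here depends on how reproducible the two evaluations are, which the diff does not state.

```python
import numpy as np

# Hypothetical per-mutation scores and a hypothetical combined score.
single_mutation_scores = np.array([0.1, 0.2])
combined_score = 0.3

# Exact comparison: sensitive to the last bit of the floating-point sum.
exact_match = bool(single_mutation_scores.sum() == combined_score)

# Tolerance-based comparison, using the same atol as the other tests.
close_match = bool(
    np.isclose(single_mutation_scores.sum(), combined_score, atol=1e-4)
)
```

Here `exact_match` is `False` and `close_match` is `True`, since 0.1 + 0.2 is not exactly 0.3 in double precision.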
