Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implements Ehrlich using Holo-bench #274

Merged
merged 14 commits into from
Oct 7, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions .github/workflows/python-tox-testing-ehrlich-holo-env.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
name: poli ehrlich (py3.10)

on:
push:
branches:
- dev
- master
pull_request:
types: [opened, synchronize, reopened, ready_for_review, closed]
branches:
- dev
- master
schedule:
- cron: '0 0 * * 0'

jobs:
build-linux:
runs-on: ubuntu-latest
timeout-minutes: 8
if: github.event.pull_request.draft == false
strategy:
max-parallel: 5

steps:
- uses: actions/checkout@v3
- name: Set up Python 3.9
uses: actions/setup-python@v3
with:
python-version: '3.10'
- name: Add conda to system path
run: |
# $CONDA is an environment variable pointing to the root of the miniconda directory
echo $CONDA/bin >> $GITHUB_PATH
- name: Install dependencies
run: |
python -m pip install tox
- name: Test Ehrlich black boxes with tox and pytest
run: |
tox -c tox.ini -e poli-ehrlich-holo-py310
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ protein = [
"python-levenshtein",
"pdb-tools",
]
ehrlich_holo = [
"pytorch-holo",
]
tdc = [
"pytdc",
]
Expand All @@ -62,6 +65,7 @@ markers = [
"poli__protein: marks tests that run in the poli__protein environment",
"poli__rasp: marks tests that run in the poli__rasp environment",
"poli__rmf: marks tests that run in poli__rmf environment",
"poli__ehrlich_holo: marks tests that run in poli__ehrlich_holo environment",
"unmarked: All other tests, which usually run in the base environment",
]

Expand Down
3 changes: 2 additions & 1 deletion src/poli/core/util/isolation/instancing.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,6 @@ def get_inner_function(
isolated_function_name: str,
class_name: str,
module_to_import: str,
seed: int | None = None,
force_isolation: bool = False,
quiet: bool = False,
**kwargs,
Expand Down Expand Up @@ -360,10 +359,12 @@ class from the sibling isolated_function.py file of each register.py.
InnerFunctionClass = getattr(module, class_name)
inner_function = InnerFunctionClass(**kwargs)
except ImportError:
seed = kwargs.pop("seed", None)
inner_function = instance_function_as_isolated_process(
name=isolated_function_name, seed=seed, quiet=quiet, **kwargs
)
else:
seed = kwargs.pop("seed", None)
inner_function = instance_function_as_isolated_process(
name=isolated_function_name, seed=seed, quiet=quiet, **kwargs
)
Expand Down
3 changes: 3 additions & 0 deletions src/poli/objective_repository/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

# Discrete toy examples
from .ehrlich.register import EhrlichBlackBox, EhrlichProblemFactory
from .ehrlich_holo.register import EhrlichHoloBlackBox, EhrlichHoloProblemFactory
from .fexofenadine_mpo.register import (
FexofenadineMPOBlackBox,
FexofenadineMPOProblemFactory,
Expand Down Expand Up @@ -129,6 +130,7 @@
AVAILABLE_PROBLEM_FACTORIES = {
"aloha": AlohaProblemFactory,
"ehrlich": EhrlichProblemFactory,
"ehrlich_holo": EhrlichHoloProblemFactory,
"dockstring": DockstringProblemFactory,
"drd3_docking": DRD3ProblemFactory,
"foldx_rfp_lambo": FoldXRFPLamboProblemFactory,
Expand Down Expand Up @@ -174,6 +176,7 @@
AVAILABLE_BLACK_BOXES = {
"aloha": AlohaBlackBox,
"ehrlich": EhrlichBlackBox,
"ehrlich_holo": EhrlichHoloBlackBox,
"dockstring": DockstringBlackBox,
"drd3_docking": DRD3BlackBox,
"foldx_rfp_lambo": FoldXRFPLamboBlackBox,
Expand Down
5 changes: 5 additions & 0 deletions src/poli/objective_repository/ehrlich_holo/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""A closed-form black box simulating epistatic effects."""

from .register import EhrlichHoloBlackBox, EhrlichHoloProblemFactory

__all__ = ["EhrlichHoloBlackBox", "EhrlichHoloProblemFactory"]
9 changes: 9 additions & 0 deletions src/poli/objective_repository/ehrlich_holo/environment.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
name: poli__ehrlich
channels:
- defaults
dependencies:
- python=3.10
- pip
- pip:
- "git+https://github.com/MachineLearningLifeScience/poli.git@dev"
- pytorch-holo
84 changes: 84 additions & 0 deletions src/poli/objective_repository/ehrlich_holo/isolated_function.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from __future__ import annotations

import numpy as np
import torch
from holo.test_functions.closed_form._ehrlich import Ehrlich

from poli.core.abstract_isolated_function import AbstractIsolatedFunction
from poli.core.registry import register_isolated_function
from poli.core.util.proteins.defaults import AMINO_ACIDS


class EhrlichIsolatedLogic(AbstractIsolatedFunction):
""" """
miguelgondu marked this conversation as resolved.
Show resolved Hide resolved

def __init__(
self,
sequence_length: int,
motif_length: int,
n_motifs: int,
quantization: int | None = None,
noise_std: float = 0.0,
seed: int | None = None,
epistasis_factor: float = 0.0,
return_value_on_unfeasible: float = -np.inf,
miguelgondu marked this conversation as resolved.
Show resolved Hide resolved
alphabet: list[str] = AMINO_ACIDS,
parallelize: bool = False,
num_workers: int = None,
evaluation_budget: int = float("inf"),
):
self.sequence_length = sequence_length
self.motif_length = motif_length
self.n_motifs = n_motifs
self.epistasis_factor = epistasis_factor

if seed is None:
raise ValueError("The seed parameter must be set.")

# if quantization is None:
# self.quantization = motif_length

# if not (1 <= quantization <= motif_length) or motif_length % quantization != 0:
# raise ValueError(
# "The quantization parameter must be between 1 and the motif length, "
# "and the motif length must be divisible by the quantization."
# )

self.noise_std = noise_std
self.quantization = quantization
self.seed = seed
self.return_value_on_unfeasible = return_value_on_unfeasible
self.alphabet = alphabet
self.parallelize = parallelize
self.num_workers = num_workers
self.evaluation_budget = evaluation_budget

self.inner_ehrlich = Ehrlich(
num_states=len(alphabet),
dim=sequence_length,
num_motifs=n_motifs,
motif_length=motif_length,
quantization=quantization,
noise_std=noise_std,
negate=False, # We aim to maximize the function
random_seed=seed,
)

def __call__(self, x: np.ndarray, context: None) -> np.ndarray:
# First, we transform the strings into integers using the alphabet
batch_size = x.shape[0]
x_ = np.array([[self.alphabet.index(c) for c in s] for s in x.flatten()])

return (
self.inner_ehrlich(torch.from_numpy(x_))
.numpy(force=True)
.reshape(batch_size, 1)
)


if __name__ == "__main__":
register_isolated_function(
EhrlichIsolatedLogic,
name="ehrlich_holo__isolated",
conda_environment_name="poli__ehrlich",
)
Loading
Loading