Skip to content

Commit

Permalink
Adds a test for initializing with csv observer init info
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelgondu committed Dec 2, 2024
1 parent 85e4d09 commit 204b76c
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 11 deletions.
6 changes: 6 additions & 0 deletions src/poli/core/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,9 @@ class FoldXNotFoundException(PoliException):
"""Exception raised when FoldX wasn't found in ~/foldx/foldx."""

pass


class ObserverNotInitializedError(PoliException):
    """Exception raised when an observer is used before being initialized.

    It carries no state of its own beyond what `PoliException` provides.
    """
66 changes: 56 additions & 10 deletions src/poli/core/util/observers/csv_observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

import numpy as np

from poli.core.black_box_information import BlackBoxInformation
from poli.core.exceptions import ObserverNotInitializedError
from poli.core.util.abstract_observer import AbstractObserver


Expand All @@ -19,30 +21,63 @@ class CSVObserverInitInfo:


class CSVObserver(AbstractObserver):
"""
A simple observer that logs to a CSV file, appending rows on each query.
"""

def __init__(self):
    """Create an observer that is flagged as not yet initialized."""
    # `observe` checks this flag; `initialize_observer` flips it to True.
    self.has_been_initialized = False
    super().__init__()

def initialize_observer(
    self,
    problem_setup_info: BlackBoxInformation,
    caller_info: CSVObserverInitInfo | dict,
    seed: int,
) -> object:
    """
    Initializes the observer with the given information.

    Parameters
    ----------
    problem_setup_info : BlackBoxInformation
        The information about the black box. Its `name` is used for the
        results sub-folder and for the default experiment id.
    caller_info : dict | CSVObserverInitInfo
        Information used for logging. If a dictionary, it may contain the
        keys `experiment_id` and `experiment_path`; a missing
        `experiment_path` falls back to `./poli_results` and a missing
        `experiment_id` falls back to a generated identifier.
    seed : int
        The seed used for the experiment. It is stored and embedded in the
        default experiment id; it is not used to seed anything here.
    """
    self.info = problem_setup_info
    self.seed = seed
    self.unique_id = f"{uuid4()}"[:8]

    # Normalize the init-info object to a plain mapping so that the
    # `.get(...)` lookups below work for both accepted input types.
    if isinstance(caller_info, CSVObserverInitInfo):
        caller_info = vars(caller_info)

    # Results live in <experiment_path>/<black box name>/<experiment_id>.csv.
    self.all_results_path = Path(
        caller_info.get("experiment_path", "./poli_results")
    )
    self.experiment_path = self.all_results_path / problem_setup_info.name
    # The folders must exist before the .gitignore can be written.
    self.experiment_path.mkdir(exist_ok=True, parents=True)
    self._write_gitignore()

    self.experiment_id = caller_info.get(
        "experiment_id",
        f"{int(time())}_experiment_{problem_setup_info.name}_{seed}_{self.unique_id}",
    )

    self.csv_file_path = self.experiment_path / f"{self.experiment_id}.csv"
    self.save_header()
    # Only now may `observe` be called without raising.
    self.has_been_initialized = True

def _write_gitignore(self):
if not (self.all_results_path / ".gitignore").exists():
with open(self.all_results_path / ".gitignore", "w") as f:
f.write("*\n")

def _make_folder_for_experiment(self):
self.experiment_path.mkdir(exist_ok=True, parents=True)

def _validate_input(self, x: np.ndarray, y: np.ndarray) -> None:
if x.ndim != 2:
Expand All @@ -54,11 +89,22 @@ def _validate_input(self, x: np.ndarray, y: np.ndarray) -> None:
f"x and y should have the same number of samples, got {x.shape[0]} and {y.shape[0]} respectively."
)

def _ensure_proper_shape(self, x: np.ndarray) -> np.ndarray:
if x.ndim == 1:
return x.reshape(-1, 1)
return x

def observe(self, x: np.ndarray, y: np.ndarray, context=None) -> None:
    """Append one CSV row per sample in `x`, paired with its value in `y`.

    Each row of `x` is joined into a single string, and `y` is flattened,
    before both are handed to `append_results`. `context` is accepted but
    not used here.

    Raises
    ------
    ObserverNotInitializedError
        If `initialize_observer` has not been called yet.
    """
    if not self.has_been_initialized:
        raise ObserverNotInitializedError(
            "The observer has not been initialized. Please call `initialize_observer` first."
        )

    x = self._ensure_proper_shape(x)
    self._validate_input(x, y)

    # One string per row of x, one value per entry of the flattened y.
    sequences = list(map("".join, x))
    values = list(y.flatten())
    self.append_results(sequences, values)

def save_header(self):
    """Create (or overwrite) the CSV file with just the `x,y` header row.

    The experiment folder is created first so that opening the file for
    writing cannot fail on a missing directory.
    """
    self._make_folder_for_experiment()
    with open(self.csv_file_path, "w") as csv_file:
        csv_file.write("x,y\n")

Expand Down
41 changes: 40 additions & 1 deletion src/poli/tests/observers/test_csv_observer.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import csv

import numpy as np
import pytest

from poli.core.util.observers.csv_observer import CSVObserver
from poli.core.exceptions import ObserverNotInitializedError
from poli.core.util.observers.csv_observer import CSVObserver, CSVObserverInitInfo
from poli.repository import AlohaBlackBox


Expand Down Expand Up @@ -63,3 +65,40 @@ def test_csv_observer_works_with_incomplete_caller_info():
assert results[2][0] == "ALOOF" and float(results[2][1]) == 3.0
assert results[3][0] == "ALOHA" and float(results[3][1]) == 5.0
assert results[4][0] == "OMAHA" and float(results[4][1]) == 2.0


def test_observer_without_initialization():
    """Querying with an uninitialized CSV observer attached must raise."""
    black_box = AlohaBlackBox()
    uninitialized_observer = CSVObserver()
    black_box.set_observer(uninitialized_observer)

    query = np.array([list("MIGUE")])
    with pytest.raises(ObserverNotInitializedError):
        black_box(query)


def test_works_with_csv_init_object():
    """The observer can be initialized with a `CSVObserverInitInfo` object."""
    black_box = AlohaBlackBox()
    init_info = CSVObserverInitInfo(
        experiment_id="test_csv_observer_logs_on_aloha",
        experiment_path="./poli_results",
    )
    observer = CSVObserver()
    observer.initialize_observer(black_box.info, init_info, seed=0)
    black_box.set_observer(observer)

    # Three separate queries; the last one is a batch of two sequences.
    for batch in (["MIGUE"], ["ALOOF"], ["ALOHA", "OMAHA"]):
        black_box(np.array([list(word) for word in batch]))

    assert observer.csv_file_path.exists()

    # Loading up the csv and checking results
    with open(observer.csv_file_path, "r") as csv_file:
        results = list(csv.reader(csv_file))

    assert results[0] == ["x", "y"]
    expected_rows = [
        ("MIGUE", 0.0),
        ("ALOOF", 3.0),
        ("ALOHA", 5.0),
        ("OMAHA", 2.0),
    ]
    for row_index, (sequence, value) in enumerate(expected_rows, start=1):
        assert (
            results[row_index][0] == sequence
            and float(results[row_index][1]) == value
        )

0 comments on commit 204b76c

Please sign in to comment.