From 204b76c185a3dce22135d482c22ec5f5beb12c25 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miguel=20Gonz=C3=A1lez=20Duque?= Date: Mon, 2 Dec 2024 23:24:06 +0100 Subject: [PATCH] Adds a test for initializing with csv observer init info --- src/poli/core/exceptions.py | 6 ++ src/poli/core/util/observers/csv_observer.py | 66 ++++++++++++++++--- src/poli/tests/observers/test_csv_observer.py | 41 +++++++++++- 3 files changed, 102 insertions(+), 11 deletions(-) diff --git a/src/poli/core/exceptions.py b/src/poli/core/exceptions.py index 38c372c4..d960033c 100644 --- a/src/poli/core/exceptions.py +++ b/src/poli/core/exceptions.py @@ -17,3 +17,9 @@ class FoldXNotFoundException(PoliException): """Exception raised when FoldX wasn't found in ~/foldx/foldx.""" pass + + +class ObserverNotInitializedError(PoliException): + """Exception raised when the observer is not initialized.""" + + pass diff --git a/src/poli/core/util/observers/csv_observer.py b/src/poli/core/util/observers/csv_observer.py index 84e68e79..2f39778a 100644 --- a/src/poli/core/util/observers/csv_observer.py +++ b/src/poli/core/util/observers/csv_observer.py @@ -7,6 +7,8 @@ import numpy as np +from poli.core.black_box_information import BlackBoxInformation +from poli.core.exceptions import ObserverNotInitializedError from poli.core.util.abstract_observer import AbstractObserver @@ -19,30 +21,63 @@ class CSVObserverInitInfo: class CSVObserver(AbstractObserver): + """ + A simple observer that logs to a CSV file, appending rows on each query. + """ + + def __init__(self): + self.has_been_initialized = False + super().__init__() + def initialize_observer( self, - problem_setup_info: object, - caller_info: CSVObserverInitInfo, + problem_setup_info: BlackBoxInformation, + caller_info: CSVObserverInitInfo | dict, seed: int, ) -> object: + """ + Initializes the observer with the given information. + + Parameters + ---------- + black_box_info : BlackBoxInformation + The information about the black box. + caller_info : dict | CSVObserverInitInfo + Information used for logging. If a dictionary, it should contain the + keys `experiment_id` and `experiment_path`. + seed : int + The seed used for the experiment. This is only logged, not used. + """ self.info = problem_setup_info self.seed = seed self.unique_id = f"{uuid4()}"[:8] - self.experiment_id = caller_info.get( - "experiment_id", - f"{int(time())}_experiment_{problem_setup_info.name}_{seed}_{self.unique_id}", - ) - self.experiment_path = Path( + + if isinstance(caller_info, CSVObserverInitInfo): + caller_info = caller_info.__dict__ + + self.all_results_path = Path( caller_info.get("experiment_path", "./poli_results") ) + self.experiment_path = self.all_results_path / problem_setup_info.name self.experiment_path.mkdir(exist_ok=True, parents=True) + self._write_gitignore() - if not (self.experiment_path / ".gitignore").exists(): - with open(self.experiment_path / ".gitignore", "w") as f: - f.write("*\n") + self.experiment_id = caller_info.get( + "experiment_id", + f"{int(time())}_experiment_{problem_setup_info.name}_{seed}_{self.unique_id}", + ) self.csv_file_path = self.experiment_path / f"{self.experiment_id}.csv" self.save_header() + self.has_been_initialized = True + + def _write_gitignore(self): + if not (self.all_results_path / ".gitignore").exists(): + with open(self.all_results_path / ".gitignore", "w") as f: + f.write("*\n") + + def _make_folder_for_experiment(self): + self.experiment_path.mkdir(exist_ok=True, parents=True) def _validate_input(self, x: np.ndarray, y: np.ndarray) -> None: if x.ndim != 2: @@ -54,11 +89,22 @@ def _validate_input(self, x: np.ndarray, y: np.ndarray) -> None: f"x and y should have the same number of samples, got {x.shape[0]} and {y.shape[0]} respectively." ) + def _ensure_proper_shape(self, x: np.ndarray) -> np.ndarray: + if x.ndim == 1: + return x.reshape(-1, 1) + return x + def observe(self, x: np.ndarray, y: np.ndarray, context=None) -> None: + if not self.has_been_initialized: + raise ObserverNotInitializedError( + "The observer has not been initialized. Please call `initialize_observer` first." + ) + x = self._ensure_proper_shape(x) self._validate_input(x, y) self.append_results(["".join(x_i) for x_i in x], [y_i for y_i in y.flatten()]) def save_header(self): + self._make_folder_for_experiment() with open(self.csv_file_path, "w") as f: f.write("x,y\n") diff --git a/src/poli/tests/observers/test_csv_observer.py b/src/poli/tests/observers/test_csv_observer.py index 354042d5..a34cbfd8 100644 --- a/src/poli/tests/observers/test_csv_observer.py +++ b/src/poli/tests/observers/test_csv_observer.py @@ -1,8 +1,10 @@ import csv import numpy as np +import pytest -from poli.core.util.observers.csv_observer import CSVObserver +from poli.core.exceptions import ObserverNotInitializedError +from poli.core.util.observers.csv_observer import CSVObserver, CSVObserverInitInfo from poli.repository import AlohaBlackBox @@ -63,3 +65,40 @@ def test_csv_observer_works_with_incomplete_caller_info(): assert results[2][0] == "ALOOF" and float(results[2][1]) == 3.0 assert results[3][0] == "ALOHA" and float(results[3][1]) == 5.0 assert results[4][0] == "OMAHA" and float(results[4][1]) == 2.0 + + +def test_observer_without_initialization(): + f = AlohaBlackBox() + observer = CSVObserver() + + f.set_observer(observer) + + with pytest.raises(ObserverNotInitializedError): + f(np.array([list("MIGUE")])) + + +def test_works_with_csv_init_object(): + f = AlohaBlackBox() + observer = CSVObserver() + observer.initialize_observer( + f.info, + CSVObserverInitInfo( + experiment_id="test_csv_observer_logs_on_aloha", + experiment_path="./poli_results", + ), + seed=0, + ) + f.set_observer(observer) + f(np.array([list("MIGUE")])) + f(np.array([list("ALOOF")])) + f(np.array([list("ALOHA"), list("OMAHA")])) + assert observer.csv_file_path.exists() + # Loading up the csv and checking results + with open(observer.csv_file_path, "r") as f: + reader = csv.reader(f) + results = list(reader) + assert results[0] == ["x", "y"] + assert results[1][0] == "MIGUE" and float(results[1][1]) == 0.0 + assert results[2][0] == "ALOOF" and float(results[2][1]) == 3.0 + assert results[3][0] == "ALOHA" and float(results[3][1]) == 5.0 + assert results[4][0] == "OMAHA" and float(results[4][1]) == 2.0