diff --git a/src/poli/core/exceptions.py b/src/poli/core/exceptions.py index 38c372c4..d960033c 100644 --- a/src/poli/core/exceptions.py +++ b/src/poli/core/exceptions.py @@ -17,3 +17,9 @@ class FoldXNotFoundException(PoliException): """Exception raised when FoldX wasn't found in ~/foldx/foldx.""" pass + + +class ObserverNotInitializedError(PoliException): + """Exception raised when the observer is not initialized.""" + + pass diff --git a/src/poli/core/util/observers/csv_observer.py b/src/poli/core/util/observers/csv_observer.py new file mode 100644 index 00000000..2f39778a --- /dev/null +++ b/src/poli/core/util/observers/csv_observer.py @@ -0,0 +1,114 @@ +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from time import time +from uuid import uuid4 + +import numpy as np + +from poli.core.black_box_information import BlackBoxInformation +from poli.core.exceptions import ObserverNotInitializedError +from poli.core.util.abstract_observer import AbstractObserver + + +@dataclass +class CSVObserverInitInfo: + """Initialization information for the CSVObserver.""" + + experiment_id: str + experiment_path: str | Path = "./poli_results" + + +class CSVObserver(AbstractObserver): + """ + A simple observer that logs to a CSV file, appending rows on each query. + """ + + def __init__(self): + self.has_been_initialized = False + super().__init__() + + def initialize_observer( + self, + problem_setup_info: BlackBoxInformation, + caller_info: CSVObserverInitInfo | dict, + seed: int, + ) -> object: + """ + Initializes the observer with the given information. + + Parameters + ---------- + black_box_info : BlackBoxInformation + The information about the black box. + caller_info : dict | CSVObserverInitInfo + Information used for logging. If a dictionary, it should contain the + keys `experiment_id` and `experiment_path`. + seed : int + The seed used for the experiment. This is only logged, not used. + """ + self.info = problem_setup_info + self.seed = seed + self.unique_id = f"{uuid4()}"[:8] + + if isinstance(caller_info, CSVObserverInitInfo): + caller_info = caller_info.__dict__ + + self.all_results_path = Path( + caller_info.get("experiment_path", "./poli_results") + ) + self.experiment_path = self.all_results_path / problem_setup_info.name + self.experiment_path.mkdir(exist_ok=True, parents=True) + self._write_gitignore() + + self.experiment_id = caller_info.get( + "experiment_id", + f"{int(time())}_experiment_{problem_setup_info.name}_{seed}_{self.unique_id}", + ) + + self.csv_file_path = self.experiment_path / f"{self.experiment_id}.csv" + self.save_header() + self.has_been_initialized = True + + def _write_gitignore(self): + if not (self.all_results_path / ".gitignore").exists(): + with open(self.all_results_path / ".gitignore", "w") as f: + f.write("*\n") + + def _make_folder_for_experiment(self): + self.experiment_path.mkdir(exist_ok=True, parents=True) + + def _validate_input(self, x: np.ndarray, y: np.ndarray) -> None: + if x.ndim != 2: + raise ValueError(f"x should be 2D, got {x.ndim}D instead.") + if y.ndim != 2: + raise ValueError(f"y should be 2D, got {y.ndim}D instead.") + if x.shape[0] != y.shape[0]: + raise ValueError( + f"x and y should have the same number of samples, got {x.shape[0]} and {y.shape[0]} respectively." + ) + + def _ensure_proper_shape(self, x: np.ndarray) -> np.ndarray: + if x.ndim == 1: + return x.reshape(-1, 1) + return x + + def observe(self, x: np.ndarray, y: np.ndarray, context=None) -> None: + if not self.has_been_initialized: + raise ObserverNotInitializedError( + "The observer has not been initialized. Please call `initialize_observer` first." + ) + x = self._ensure_proper_shape(x) + self._validate_input(x, y) + self.append_results(["".join(x_i) for x_i in x], [y_i for y_i in y.flatten()]) + + def save_header(self): + self._make_folder_for_experiment() + with open(self.csv_file_path, "w") as f: + f.write("x,y\n") + + def append_results(self, x: list[str], y: list[float]): + with open(self.csv_file_path, "a") as f: + for x_i, y_i in zip(x, y): + f.write(f"{x_i},{y_i}\n") diff --git a/src/poli/tests/observers/test_csv_observer.py b/src/poli/tests/observers/test_csv_observer.py new file mode 100644 index 00000000..a34cbfd8 --- /dev/null +++ b/src/poli/tests/observers/test_csv_observer.py @@ -0,0 +1,104 @@ +import csv + +import numpy as np +import pytest + +from poli.core.exceptions import ObserverNotInitializedError +from poli.core.util.observers.csv_observer import CSVObserver, CSVObserverInitInfo +from poli.repository import AlohaBlackBox + + +def test_csv_observer_logs_on_aloha(): + f = AlohaBlackBox() + observer = CSVObserver() + observer.initialize_observer( + f.info, + { + "experiment_id": "test_csv_observer_logs_on_aloha", + "experiment_path": "./poli_results", + }, + seed=0, + ) + + f.set_observer(observer) + f(np.array([list("MIGUE")])) + f(np.array([list("ALOOF")])) + f(np.array([list("ALOHA"), list("OMAHA")])) + + assert observer.csv_file_path.exists() + + # Loading up the csv and checking results + with open(observer.csv_file_path, "r") as f: + reader = csv.reader(f) + results = list(reader) + + assert results[0] == ["x", "y"] + assert results[1][0] == "MIGUE" and float(results[1][1]) == 0.0 + assert results[2][0] == "ALOOF" and float(results[2][1]) == 3.0 + assert results[3][0] == "ALOHA" and float(results[3][1]) == 5.0 + assert results[4][0] == "OMAHA" and float(results[4][1]) == 2.0 + + +def test_csv_observer_works_with_incomplete_caller_info(): + f = AlohaBlackBox() + observer = CSVObserver() + observer.initialize_observer( + f.info, + {}, + seed=0, + ) + + f.set_observer(observer) + f(np.array([list("MIGUE")])) + f(np.array([list("ALOOF")])) + f(np.array([list("ALOHA"), list("OMAHA")])) + + assert observer.csv_file_path.exists() + + # Loading up the csv and checking results + with open(observer.csv_file_path, "r") as f: + reader = csv.reader(f) + results = list(reader) + + assert results[0] == ["x", "y"] + assert results[1][0] == "MIGUE" and float(results[1][1]) == 0.0 + assert results[2][0] == "ALOOF" and float(results[2][1]) == 3.0 + assert results[3][0] == "ALOHA" and float(results[3][1]) == 5.0 + assert results[4][0] == "OMAHA" and float(results[4][1]) == 2.0 + + +def test_observer_without_initialization(): + f = AlohaBlackBox() + observer = CSVObserver() + + f.set_observer(observer) + + with pytest.raises(ObserverNotInitializedError): + f(np.array([list("MIGUE")])) + + +def test_works_with_csv_init_object(): + f = AlohaBlackBox() + observer = CSVObserver() + observer.initialize_observer( + f.info, + CSVObserverInitInfo( + experiment_id="test_csv_observer_logs_on_aloha", + experiment_path="./poli_results", + ), + seed=0, + ) + f.set_observer(observer) + f(np.array([list("MIGUE")])) + f(np.array([list("ALOOF")])) + f(np.array([list("ALOHA"), list("OMAHA")])) + assert observer.csv_file_path.exists() + # Loading up the csv and checking results + with open(observer.csv_file_path, "r") as f: + reader = csv.reader(f) + results = list(reader) + assert results[0] == ["x", "y"] + assert results[1][0] == "MIGUE" and float(results[1][1]) == 0.0 + assert results[2][0] == "ALOOF" and float(results[2][1]) == 3.0 + assert results[3][0] == "ALOHA" and float(results[3][1]) == 5.0 + assert results[4][0] == "OMAHA" and float(results[4][1]) == 2.0