Skip to content

Commit

Permalink
allows the user to set the mlflow tracking uri
Browse files Browse the repository at this point in the history
  • Loading branch information
Simon Bartels committed Nov 21, 2024
1 parent 331e319 commit fdfe39a
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions src/poli/core/util/observers/mlflow_observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from poli.core.black_box_information import BlackBoxInformation
from poli.core.util.abstract_observer import AbstractObserver

TRACKING_URI = "tracking_uri"
OBJECTIVE = "OBJECTIVE"
SEQUENCE = "SEQUENCE"
SEED = "SEED"
Expand All @@ -16,9 +17,10 @@ class MLFlowObserver(AbstractObserver):
This observer uses mlflow as a backend.
"""

def __init__(self, tracking_uri: Path):
def __init__(self, tracking_uri: Path = None):
self.step = 0
mlflow.set_tracking_uri(tracking_uri)
if tracking_uri is not None:
mlflow.set_tracking_uri(tracking_uri)

def observe(self, x: np.ndarray, y: np.ndarray, context=None) -> None:
for n in range(y.shape[0]):
Expand All @@ -36,9 +38,13 @@ def log(self, algorithm_info: dict):
def initialize_observer(
self,
problem_setup_info: BlackBoxInformation,
caller_info: object,
caller_info: dict,
seed: int,
) -> object:
tracking_uri = caller_info.pop(TRACKING_URI, None)
if tracking_uri is not None:
mlflow.set_tracking_uri(tracking_uri)

experiment = mlflow.set_experiment(
experiment_name=problem_setup_info.get_problem_name()
)
Expand Down

0 comments on commit fdfe39a

Please sign in to comment.