diff --git a/aepsych/strategy/strategy.py b/aepsych/strategy/strategy.py index 255e810f3..4534f6fc5 100644 --- a/aepsych/strategy/strategy.py +++ b/aepsych/strategy/strategy.py @@ -9,6 +9,7 @@ import warnings from typing import List, Mapping, Optional, Tuple, Union +from copy import deepcopy import numpy as np import torch @@ -56,6 +57,7 @@ def __init__( name: str = "", run_indefinitely: bool = False, transforms: ChainedInputTransform = ChainedInputTransform(**{}), + copy_model: bool = False, ) -> None: """Initialize the strategy object. @@ -90,6 +92,9 @@ def __init__( should be defined in raw parameter space for initialization. However, if the lb/ub attribute are access from an initialized Strategy object, it will be returned in transformed space. + copy_model (bool): Whether to do any model-related methods on a + copy or the original. Used for multi-client strategies. Defaults + to False. """ self.is_finished = False @@ -116,8 +121,8 @@ def __init__( len(outcome_types) == 1 and outcome_types[0] == model.outcome_type ), f"Strategy outcome types is {outcome_types} but model outcome type is {model.outcome_type}!" else: - assert ( - set(outcome_types) == set(model.outcome_type) + assert set(outcome_types) == set( + model.outcome_type ), f"Strategy outcome types is {outcome_types} but model outcome type is {model.outcome_type}!" if use_gpu_modeling: @@ -160,6 +165,7 @@ def __init__( self.min_total_outcome_occurrences = min_total_outcome_occurrences self.max_asks = max_asks or generator.max_asks self.keep_most_recent = keep_most_recent + self.copy_model = copy_model self.transforms = transforms if self.transforms is not None: @@ -267,7 +273,8 @@ def gen(self, num_points: int = 1, **kwargs) -> torch.Tensor: self.model.to(self.generator_device) # type: ignore self._count = self._count + num_points - points = self.generator.gen(num_points, self.model, **kwargs) + model = deepcopy(self.model) if self.copy_model else self.model + points = self.generator.gen(num_points, model, **kwargs) if original_device is not None: self.model.to(original_device) # type: ignore @@ -295,9 +302,9 @@ def get_max( self.model is not None ), "model is None! Cannot get the max without a model!" self.model.to(self.model_device) - + model = deepcopy(self.model) if self.copy_model else self.model val, arg = get_max( - self.model, + model, self.bounds, locked_dims=constraints, probability_space=probability_space, @@ -324,9 +331,9 @@ def get_min( self.model is not None ), "model is None! Cannot get the min without a model!" self.model.to(self.model_device) - + model = deepcopy(self.model) if self.copy_model else self.model val, arg = get_min( - self.model, + model, self.bounds, locked_dims=constraints, probability_space=probability_space, @@ -358,9 +365,9 @@ def inv_query( self.model is not None ), "model is None! Cannot get the inv_query without a model!" self.model.to(self.model_device) - + model = deepcopy(self.model) if self.copy_model else self.model val, arg = inv_query( - model=self.model, + model=model, y=y, bounds=self.bounds, locked_dims=constraints, @@ -383,7 +390,8 @@ def predict(self, x: torch.Tensor, probability_space: bool = False) -> torch.Ten """ assert self.model is not None, "model is None! Cannot predict without a model!" self.model.to(self.model_device) - return self.model.predict(x=x, probability_space=probability_space) + model = deepcopy(self.model) if self.copy_model else self.model + return model.predict(x=x, probability_space=probability_space) @ensure_model_is_fresh def get_jnd( @@ -398,8 +406,9 @@ def get_jnd( self.model is not None ), "model is None! Cannot get the get jnd without a model!" self.model.to(self.model_device) + model = deepcopy(self.model) if self.copy_model else self.model return get_jnd( # type: ignore - model=self.model, lb=self.lb, ub=self.ub, dim=self.dim, *args, **kwargs + model=model, lb=self.lb, ub=self.ub, dim=self.dim, *args, **kwargs ) @ensure_model_is_fresh @@ -417,7 +426,8 @@ def sample( """ assert self.model is not None, "model is None! Cannot sample without a model!" self.model.to(self.model_device) - return self.model.sample(x, num_samples=num_samples) + model = deepcopy(self.model) if self.copy_model else self.model + return model.sample(x, num_samples=num_samples) def finish(self) -> None: """Finish the strategy.""" @@ -459,7 +469,8 @@ def finished(self) -> bool: assert ( self.model is not None ), "model is None! Cannot predict without a model!" - fmean, _ = self.model.predict(self.eval_grid, probability_space=True) + model = deepcopy(self.model) if self.copy_model else self.model + fmean, _ = model.predict(self.eval_grid, probability_space=True) meets_post_range = ( (fmean.max() - fmean.min()) >= self.min_post_range ).item() @@ -534,9 +545,10 @@ def fit(self) -> None: """Fit the model.""" if self.can_fit: self.model.to(self.model_device) # type: ignore + model = deepcopy(self.model) if self.copy_model else self.model if self.keep_most_recent is not None: try: - self.model.fit( # type: ignore + model.fit( # type: ignore self.x[-self.keep_most_recent :], # type: ignore self.y[-self.keep_most_recent :], # type: ignore ) @@ -546,11 +558,12 @@ def fit(self) -> None: ) else: try: - self.model.fit(self.x, self.y) # type: ignore + model.fit(self.x, self.y) # type: ignore except ModelFittingError: logger.warning( "Failed to fit model! Predictions may not be accurate!" ) + self.model = model else: warnings.warn("Cannot fit: no model has been initialized!", RuntimeWarning) @@ -558,9 +571,10 @@ def update(self) -> None: """Update the model.""" if self.can_fit: self.model.to(self.model_device) # type: ignore + model = deepcopy(self.model) if self.copy_model else self.model if self.keep_most_recent is not None: try: - self.model.update( # type: ignore + model.update( # type: ignore self.x[-self.keep_most_recent :], # type: ignore self.y[-self.keep_most_recent :], # type: ignore ) @@ -570,11 +584,13 @@ def update(self) -> None: ) else: try: - self.model.update(self.x, self.y) # type: ignore + model.update(self.x, self.y) # type: ignore except ModelFittingError: logger.warning( "Failed to fit model! Predictions may not be accurate!" ) + + self.model = model else: warnings.warn("Cannot fit: no model has been initialized!", RuntimeWarning) @@ -656,6 +672,8 @@ def from_config(cls, config: Config, name: str) -> Strategy: ) min_asks = n_trials + copy_model = config.getboolean(name, "copy_model", fallback=False) + return cls( lb=lb, ub=ub, @@ -673,5 +691,6 @@ def from_config(cls, config: Config, name: str) -> Strategy: min_post_range=min_post_range, keep_most_recent=keep_most_recent, min_total_tells=min_total_tells, + copy_model=copy_model, name=name, )