Skip to content

Commit

Permalink
strategy methods act on copies of models if desired
Browse files Browse the repository at this point in the history
Summary: Add extra flag to strategy to allow all model-related operations to act on a copy (then save that copy if needed). Used for multi-client server.

Test Plan:
  • Loading branch information
JasonKChow committed Feb 6, 2025
1 parent 950a7a3 commit 075e40d
Showing 1 changed file with 36 additions and 17 deletions.
53 changes: 36 additions & 17 deletions aepsych/strategy/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import warnings
from typing import List, Mapping, Optional, Tuple, Union
from copy import deepcopy

import numpy as np
import torch
Expand Down Expand Up @@ -56,6 +57,7 @@ def __init__(
name: str = "",
run_indefinitely: bool = False,
transforms: ChainedInputTransform = ChainedInputTransform(**{}),
copy_model: bool = False,
) -> None:
"""Initialize the strategy object.
Expand Down Expand Up @@ -90,6 +92,9 @@ def __init__(
should be defined in raw parameter space for initialization. However,
if the lb/ub attribute are access from an initialized Strategy object,
it will be returned in transformed space.
copy_model (bool): Whether to do any model-related methods on a
copy or the original. Used for multi-client strategies. Defaults
to False.
"""
self.is_finished = False

Expand All @@ -116,8 +121,8 @@ def __init__(
len(outcome_types) == 1 and outcome_types[0] == model.outcome_type
), f"Strategy outcome types is {outcome_types} but model outcome type is {model.outcome_type}!"
else:
assert (
set(outcome_types) == set(model.outcome_type)
assert set(outcome_types) == set(
model.outcome_type
), f"Strategy outcome types is {outcome_types} but model outcome type is {model.outcome_type}!"

if use_gpu_modeling:
Expand Down Expand Up @@ -160,6 +165,7 @@ def __init__(
self.min_total_outcome_occurrences = min_total_outcome_occurrences
self.max_asks = max_asks or generator.max_asks
self.keep_most_recent = keep_most_recent
self.copy_model = copy_model

self.transforms = transforms
if self.transforms is not None:
Expand Down Expand Up @@ -267,7 +273,8 @@ def gen(self, num_points: int = 1, **kwargs) -> torch.Tensor:
self.model.to(self.generator_device) # type: ignore

self._count = self._count + num_points
points = self.generator.gen(num_points, self.model, **kwargs)
model = deepcopy(self.model) if self.copy_model else self.model
points = self.generator.gen(num_points, model, **kwargs)

if original_device is not None:
self.model.to(original_device) # type: ignore
Expand Down Expand Up @@ -295,9 +302,9 @@ def get_max(
self.model is not None
), "model is None! Cannot get the max without a model!"
self.model.to(self.model_device)

model = deepcopy(self.model) if self.copy_model else self.model
val, arg = get_max(
self.model,
model,
self.bounds,
locked_dims=constraints,
probability_space=probability_space,
Expand All @@ -324,9 +331,9 @@ def get_min(
self.model is not None
), "model is None! Cannot get the min without a model!"
self.model.to(self.model_device)

model = deepcopy(self.model) if self.copy_model else self.model
val, arg = get_min(
self.model,
model,
self.bounds,
locked_dims=constraints,
probability_space=probability_space,
Expand Down Expand Up @@ -358,9 +365,9 @@ def inv_query(
self.model is not None
), "model is None! Cannot get the inv_query without a model!"
self.model.to(self.model_device)

model = deepcopy(self.model) if self.copy_model else self.model
val, arg = inv_query(
model=self.model,
model=model,
y=y,
bounds=self.bounds,
locked_dims=constraints,
Expand All @@ -383,7 +390,8 @@ def predict(self, x: torch.Tensor, probability_space: bool = False) -> torch.Ten
"""
assert self.model is not None, "model is None! Cannot predict without a model!"
self.model.to(self.model_device)
return self.model.predict(x=x, probability_space=probability_space)
model = deepcopy(self.model) if self.copy_model else self.model
return model.predict(x=x, probability_space=probability_space)

@ensure_model_is_fresh
def get_jnd(
Expand All @@ -398,8 +406,9 @@ def get_jnd(
self.model is not None
), "model is None! Cannot get the get jnd without a model!"
self.model.to(self.model_device)
model = deepcopy(self.model) if self.copy_model else self.model
return get_jnd( # type: ignore
model=self.model, lb=self.lb, ub=self.ub, dim=self.dim, *args, **kwargs
model=model, lb=self.lb, ub=self.ub, dim=self.dim, *args, **kwargs
)

@ensure_model_is_fresh
Expand All @@ -417,7 +426,8 @@ def sample(
"""
assert self.model is not None, "model is None! Cannot sample without a model!"
self.model.to(self.model_device)
return self.model.sample(x, num_samples=num_samples)
model = deepcopy(self.model) if self.copy_model else self.model
return model.sample(x, num_samples=num_samples)

def finish(self) -> None:
"""Finish the strategy."""
Expand Down Expand Up @@ -459,7 +469,8 @@ def finished(self) -> bool:
assert (
self.model is not None
), "model is None! Cannot predict without a model!"
fmean, _ = self.model.predict(self.eval_grid, probability_space=True)
model = deepcopy(self.model) if self.copy_model else self.model
fmean, _ = model.predict(self.eval_grid, probability_space=True)
meets_post_range = (
(fmean.max() - fmean.min()) >= self.min_post_range
).item()
Expand Down Expand Up @@ -534,9 +545,10 @@ def fit(self) -> None:
"""Fit the model."""
if self.can_fit:
self.model.to(self.model_device) # type: ignore
model = deepcopy(self.model) if self.copy_model else self.model
if self.keep_most_recent is not None:
try:
self.model.fit( # type: ignore
model.fit( # type: ignore
self.x[-self.keep_most_recent :], # type: ignore
self.y[-self.keep_most_recent :], # type: ignore
)
Expand All @@ -546,21 +558,23 @@ def fit(self) -> None:
)
else:
try:
self.model.fit(self.x, self.y) # type: ignore
model.fit(self.x, self.y) # type: ignore
except ModelFittingError:
logger.warning(
"Failed to fit model! Predictions may not be accurate!"
)
self.model = model
else:
warnings.warn("Cannot fit: no model has been initialized!", RuntimeWarning)

def update(self) -> None:
"""Update the model."""
if self.can_fit:
self.model.to(self.model_device) # type: ignore
model = deepcopy(self.model) if self.copy_model else self.model
if self.keep_most_recent is not None:
try:
self.model.update( # type: ignore
model.update( # type: ignore
self.x[-self.keep_most_recent :], # type: ignore
self.y[-self.keep_most_recent :], # type: ignore
)
Expand All @@ -570,11 +584,13 @@ def update(self) -> None:
)
else:
try:
self.model.update(self.x, self.y) # type: ignore
model.update(self.x, self.y) # type: ignore
except ModelFittingError:
logger.warning(
"Failed to fit model! Predictions may not be accurate!"
)

self.model = model
else:
warnings.warn("Cannot fit: no model has been initialized!", RuntimeWarning)

Expand Down Expand Up @@ -656,6 +672,8 @@ def from_config(cls, config: Config, name: str) -> Strategy:
)
min_asks = n_trials

copy_model = config.getboolean(name, "copy_model", fallback=False)

return cls(
lb=lb,
ub=ub,
Expand All @@ -673,5 +691,6 @@ def from_config(cls, config: Config, name: str) -> Strategy:
min_post_range=min_post_range,
keep_most_recent=keep_most_recent,
min_total_tells=min_total_tells,
copy_model=copy_model,
name=name,
)

0 comments on commit 075e40d

Please sign in to comment.