fix StratifiedStandardize dtype/nan issue (#2757)
Summary:
Pull Request resolved: #2757

If the dtype passed to `get_task_value_remapping` is not float or double, an exception is raised, because NaN cannot be represented in an int/long tensor. The docstring states this requirement, but `StratifiedStandardize` passed `torch.long` anyway, which meant that the remapping didn't work. This fixes the issue.
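For illustration, a minimal sketch of the failure mode described above (not part of the diff; values are illustrative):

import torch

# NaN exists only for floating-point dtypes; filling an int/long
# tensor with NaN raises a RuntimeError.
try:
    torch.full((4,), float("nan"), dtype=torch.long)
except RuntimeError as e:
    print(f"long dtype fails: {e}")

# A double tensor can hold NaN placeholders for unmapped task values
# and still be cast to long when used as an index.
mapping = torch.full((4,), float("nan"), dtype=torch.double)
mapping[1] = 0.0
mapping[3] = 1.0
print(mapping[torch.tensor([1, 3])].long())  # tensor([0, 1])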

Reviewed By: esantorella

Differential Revision: D70111086

fbshipit-source-id: c01044dadd2de33c84a346a3ee500dcf504cfe23
sdaulton authored and facebook-github-bot committed Feb 25, 2025
1 parent 9a7c517 commit 0be800e
Showing 4 changed files with 30 additions and 9 deletions.
6 changes: 3 additions & 3 deletions botorch/models/transforms/outcome.py
@@ -532,7 +532,7 @@ def __init__(
         OutcomeTransform.__init__(self)
         self._stratification_idx = stratification_idx
         task_values = task_values.unique(sorted=True)
-        self.strata_mapping = get_task_value_remapping(task_values, dtype=torch.long)
+        self.strata_mapping = get_task_value_remapping(task_values, dtype=torch.double)
         if self.strata_mapping is None:
             self.strata_mapping = task_values
         n_strata = self.strata_mapping.shape[0]
@@ -576,7 +576,7 @@ def forward(
         strata = X[..., self._stratification_idx].long()
         unique_strata = strata.unique()
         for s in unique_strata:
-            mapped_strata = self.strata_mapping[s]
+            mapped_strata = self.strata_mapping[s].long()
             mask = strata != s
             Y_strata = Y.clone()
             Y_strata[..., mask, :] = float("nan")
@@ -616,7 +616,7 @@ def _get_per_input_means_stdvs(
             - The per-input stdvs squared.
         """
         strata = X[..., self._stratification_idx].long()
-        mapped_strata = self.strata_mapping[strata].unsqueeze(-1)
+        mapped_strata = self.strata_mapping[strata].unsqueeze(-1).long()
         # get means and stdvs for each strata
         n_extra_batch_dims = mapped_strata.ndim - 2 - len(self._batch_shape)
         expand_shape = mapped_strata.shape[:n_extra_batch_dims] + self.means.shape
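The pattern the fix relies on, as a standalone sketch (tensor values here are illustrative, not taken from the diff): store the mapping in double so NaN can mark unused slots, and cast to long only where the looked-up values are used as indices.

import torch

# Mapping stored as double so unmapped slots can hold NaN
# (here task values 0 and 3 map to strata indices 0 and 1).
strata_mapping = torch.tensor([0.0, float("nan"), float("nan"), 1.0])

# Raw stratification values observed in X.
strata = torch.tensor([0, 3, 0, 3, 0])

# Look up in double, then cast to long for indexing.
mapped_strata = strata_mapping[strata].long()
print(mapped_strata)  # tensor([0, 1, 0, 1, 0])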
2 changes: 2 additions & 0 deletions botorch/models/utils/assorted.py
@@ -422,6 +422,8 @@ def get_task_value_remapping(task_values: Tensor, dtype: torch.dtype) -> Tensor
         return value will be `None`, when the task values are contiguous
         integers starting from zero.
     """
+    if dtype not in (torch.float, torch.double):
+        raise ValueError(f"dtype must be torch.float or torch.double, but got {dtype}.")
     task_range = torch.arange(
         len(task_values), dtype=task_values.dtype, device=task_values.device
     )
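A usage sketch of the guarded function, inferred from the tests below (the printed mapping follows from the test assertions, not from output shown on this page):

import torch
from botorch.models.utils.assorted import get_task_value_remapping

task_values = torch.tensor([1, 3])
mapping = get_task_value_remapping(task_values, dtype=torch.double)
# Observed task values 1 and 3 map to contiguous indices 0 and 1;
# the unused slots 0 and 2 hold NaN.
print(mapping)  # tensor([nan, 0., nan, 1.])

# Integer dtypes are now rejected up front instead of failing later:
# get_task_value_remapping(task_values, torch.long)  # raises ValueError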
9 changes: 9 additions & 0 deletions test/models/test_multitask.py
@@ -700,3 +700,12 @@ def test_get_task_value_remapping(self) -> None:
             mapping = get_task_value_remapping(task_values, dtype)
             self.assertTrue(torch.equal(mapping[[1, 3]], expected_mapping_no_nan))
             self.assertTrue(torch.isnan(mapping[[0, 2]]).all())
+
+    def test_get_task_value_remapping_invalid_dtype(self) -> None:
+        task_values = torch.tensor([1, 3])
+        for dtype in (torch.int32, torch.long, torch.bool):
+            with self.assertRaisesRegex(
+                ValueError,
+                f"dtype must be torch.float or torch.double, but got {dtype}.",
+            ):
+                get_task_value_remapping(task_values, dtype)
22 changes: 16 additions & 6 deletions test/models/transforms/test_outcome.py
@@ -372,16 +372,24 @@ def test_stratified_standardize(self):
         n = 5
         seed = randint(0, 100)
         torch.manual_seed(seed)
-        for dtype, batch_shape in itertools.product(
-            (torch.float, torch.double), (torch.Size([]), torch.Size([3]))
+        for dtype, batch_shape, task_values in itertools.product(
+            (torch.float, torch.double),
+            (torch.Size([]), torch.Size([3])),
+            (
+                torch.tensor([0, 1], dtype=torch.long, device=self.device),
+                torch.tensor([0, 3], dtype=torch.long, device=self.device),
+            ),
         ):
             torch.manual_seed(seed)
+            tval = task_values[1].item()
             X = torch.rand(*batch_shape, n, 2, dtype=dtype, device=self.device)
-            X[..., -1] = torch.tensor([0, 1, 0, 1, 0], dtype=dtype, device=self.device)
+            X[..., -1] = torch.tensor(
+                [0, tval, 0, tval, 0], dtype=dtype, device=self.device
+            )
             Y = torch.randn(*batch_shape, n, 1, dtype=dtype, device=self.device)
             Yvar = torch.rand(*batch_shape, n, 1, dtype=dtype, device=self.device)
             strata_tf = StratifiedStandardize(
-                task_values=torch.tensor([0, 1], dtype=torch.long, device=self.device),
+                task_values=task_values,
                 stratification_idx=-1,
                 batch_shape=batch_shape,
             )
@@ -400,9 +408,11 @@ def test_stratified_standardize(self):
             tf_Y1, tf_Yvar1 = tf1(Y=Y1, Yvar=Yvar1, X=X1)
             # check that stratified means are expected
             self.assertAllClose(strata_tf.means[..., :1, :], tf0.means)
-            self.assertAllClose(strata_tf.means[..., 1:, :], tf1.means)
+            # use remapped task values to index
+            self.assertAllClose(strata_tf.means[..., 1:2, :], tf1.means)
             self.assertAllClose(strata_tf.stdvs[..., :1, :], tf0.stdvs)
-            self.assertAllClose(strata_tf.stdvs[..., 1:, :], tf1.stdvs)
+            # use remapped task values to index
+            self.assertAllClose(strata_tf.stdvs[..., 1:2, :], tf1.stdvs)
             # check the transformed values
             self.assertAllClose(tf_Y0, tf_Y[mask0].view(*batch_shape, -1, 1))
             self.assertAllClose(tf_Y1, tf_Y[mask1].view(*batch_shape, -1, 1))
