Skip to content

Commit

Permalink
InputTransform list broadcasted over batch shapes
Browse files Browse the repository at this point in the history
Summary: This commit adds `BatchBroadcastedTransformList`, which broadcasts a list of input transforms across the first batch dimension of the input tensor, applying the i-th transform to `X[i]`.

Differential Revision: D63660807
  • Loading branch information
SebastianAment authored and facebook-github-bot committed Sep 30, 2024
1 parent e29e30a commit edf1776
Show file tree
Hide file tree
Showing 2 changed files with 186 additions and 1 deletion.
96 changes: 95 additions & 1 deletion botorch/models/transforms/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from abc import ABC, abstractmethod
from collections import OrderedDict
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, List, Optional, Union
from warnings import warn

import numpy as np
Expand Down Expand Up @@ -155,6 +155,100 @@ def preprocess_transform(self, X: Tensor) -> Tensor:
return X


class BatchBroadcastedTransformList(InputTransform, ModuleDict):
    r"""An input transform representing a list of transforms to be broadcasted."""

    def __init__(
        self,
        transforms: List[InputTransform],
    ) -> None:
        r"""A transform list that is broadcasted across the input's first dimension.

        This allows using a batched Gaussian process model in cases where the
        input transforms for each batch dimension are different.

        Args:
            transforms: The transforms to broadcast across the first batch dimension.
                The transform at position i in the list will be applied to `X[i]` for
                a given input tensor `X` in the forward pass. Must be non-empty.

        Raises:
            ValueError: If `transforms` is empty, or if the transforms do not all
                share the same `is_one_to_many` property.

        Example:
            >>> tf1 = Normalize(d=2)
            >>> tf2 = InputStandardize(d=2)
            >>> tf = BatchBroadcastedTransformList(transforms=[tf1, tf2])
        """
        super().__init__()
        # Guard explicitly: an empty list would otherwise surface as an opaque
        # IndexError at `self.transforms[0]` below.
        if not transforms:
            raise ValueError("transforms must contain at least one transform.")
        self.transform_on_train = False
        self.transform_on_eval = False
        self.transform_on_fantasize = False
        self.transforms = transforms
        # The per-batch outputs are stacked in `transform`, so all member
        # transforms must agree on whether they are one-to-many.
        self.is_one_to_many = self.transforms[0].is_one_to_many
        if not all(tf.is_one_to_many == self.is_one_to_many for tf in self.transforms):
            raise ValueError(  # output shapes of transforms must be the same
                "All transforms must have the same is_one_to_many property."
            )
        # The broadcasted list transforms in a given mode if any member does.
        for tf in self.transforms:
            self.transform_on_train |= tf.transform_on_train
            self.transform_on_eval |= tf.transform_on_eval
            self.transform_on_fantasize |= tf.transform_on_fantasize

    def transform(self, X: Tensor) -> Tensor:
        r"""Transform the inputs to a model.

        The i-th transform is applied to `X[i]` and the results are stacked
        along the first dimension.

        Args:
            X: A `batch_shape x n x d`-dim tensor of inputs.

        Returns:
            A `batch_shape x n x d`-dim tensor of transformed inputs.
        """
        return torch.stack([tf.forward(Xi) for Xi, tf in zip(X, self.transforms)])

    def untransform(self, X: Tensor) -> Tensor:
        r"""Un-transform the inputs to a model.

        The i-th transform's `untransform` is applied to `X[i]` and the results
        are stacked along the first dimension.

        Args:
            X: A `batch_shape x n x d`-dim tensor of transformed inputs.

        Returns:
            A `batch_shape x n x d`-dim tensor of un-transformed inputs.
        """
        return torch.stack([tf.untransform(Xi) for Xi, tf in zip(X, self.transforms)])

    def equals(self, other: InputTransform) -> bool:
        r"""Check if another input transform is equivalent.

        Args:
            other: Another input transform.

        Returns:
            A boolean indicating if the other transform is equivalent.
        """
        # Compare lengths explicitly: zip would silently truncate to the
        # shorter list and could report equality for lists of unequal length.
        return (
            super().equals(other=other)
            and len(self.transforms) == len(other.transforms)
            and all(t1.equals(t2) for t1, t2 in zip(self.transforms, other.transforms))
        )

    def preprocess_transform(self, X: Tensor) -> Tensor:
        r"""Apply transforms for preprocessing inputs.

        The main use cases for this method are 1) to preprocess training data
        before calling `set_train_data` and 2) preprocess `X_baseline` for noisy
        acquisition functions so that `X_baseline` is "preprocessed" with the
        same transformations as the cached training inputs.

        Args:
            X: A `batch_shape x n x d`-dim tensor of inputs.

        Returns:
            A `batch_shape x n x d`-dim tensor of (transformed) inputs.
        """
        return torch.stack(
            [tf.preprocess_transform(Xi) for Xi, tf in zip(X, self.transforms)]
        )


class ChainedInputTransform(InputTransform, ModuleDict):
r"""An input transform representing the chaining of individual transforms."""

Expand Down
91 changes: 91 additions & 0 deletions test/models/transforms/test_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from botorch.models.transforms.input import (
AffineInputTransform,
AppendFeatures,
BatchBroadcastedTransformList,
ChainedInputTransform,
FilterFeatures,
InputPerturbation,
Expand Down Expand Up @@ -652,6 +653,96 @@ def test_chained_input_transform(self) -> None:
tf = ChainedInputTransform(stz=tf1, pert=tf2)
self.assertTrue(tf.is_one_to_many)

def test_batch_broadcasted_transform_list(self) -> None:
    r"""Test BatchBroadcastedTransformList: forward/untransform round trips,
    train/eval mode gating, equality, preprocessing, and one-to-many handling.
    """
    ds = (1, 2)
    # batch size 2 must match the number of transforms in the list below
    batch_shapes = [torch.Size([2])]
    dtypes = (torch.float, torch.double)
    # set seed to range where this is known to not be flaky
    torch.manual_seed(randint(0, 1000))

    for d, batch_shape, dtype in itertools.product(ds, batch_shapes, dtypes):
        bounds = torch.tensor(
            [[-2.0] * d, [2.0] * d], device=self.device, dtype=dtype
        )
        tf1 = Normalize(d=d, bounds=bounds, batch_shape=torch.Size([]))
        tf2 = InputStandardize(d=d, batch_shape=torch.Size([]))
        transforms = [tf1, tf2]
        tf = BatchBroadcastedTransformList(transforms=transforms)
        # make copies for validation below
        transforms_ = [deepcopy(tf_i) for tf_i in transforms]
        self.assertTrue(tf.training)
        self.assertEqual(tf.transforms[0], tf1)
        self.assertEqual(tf.transforms[1], tf2)
        self.assertFalse(tf.is_one_to_many)

        # forward matches applying each copied transform to its batch slice
        X = torch.rand(*batch_shape, 4, d, device=self.device, dtype=dtype)
        X_tf = tf(X)
        X_tf_ = torch.stack([tf_i_(Xi) for tf_i_, Xi in zip(transforms_, X)], dim=0)
        self.assertTrue(tf1.training)
        self.assertTrue(tf2.training)
        self.assertTrue(torch.equal(X_tf, X_tf_))
        # untransform inverts transform (up to numerical tolerance)
        X_utf = tf.untransform(X_tf)
        self.assertAllClose(X_utf, X, atol=1e-4, rtol=1e-4)

        # test not transformed on eval
        for tf_i in transforms:
            tf_i.transform_on_eval = False

        tf = BatchBroadcastedTransformList(transforms=transforms)
        tf.eval()
        self.assertTrue(torch.equal(tf(X), X))

        # test transformed on eval
        for tf_i in transforms:
            tf_i.transform_on_eval = True

        tf = BatchBroadcastedTransformList(transforms=transforms)
        tf.eval()
        self.assertTrue(torch.equal(tf(X), X_tf))

        # test not transformed on train
        for tf_i in transforms:
            tf_i.transform_on_train = False

        tf = BatchBroadcastedTransformList(transforms=transforms)
        tf.train()
        self.assertTrue(torch.equal(tf(X), X))

        # test __eq__
        other_tf = BatchBroadcastedTransformList(transforms=transforms)
        self.assertTrue(tf.equals(other_tf))
        # change order
        other_tf = BatchBroadcastedTransformList(
            transforms=list(reversed(transforms))
        )
        self.assertFalse(tf.equals(other_tf))
        # Identical transforms but different objects.
        other_tf = BatchBroadcastedTransformList(transforms=deepcopy(transforms))
        self.assertTrue(tf.equals(other_tf))

        # test preprocess_transform: only transforms with
        # transform_on_train=True are applied during preprocessing
        transforms[-1].transform_on_train = False
        transforms[0].transform_on_train = True
        tf = BatchBroadcastedTransformList(transforms=transforms)
        self.assertTrue(
            torch.equal(
                tf.preprocess_transform(X)[0], transforms[0].transform(X[0])
            )
        )

        # mixing one-to-many with non-one-to-many transforms must fail
        tf2 = InputPerturbation(perturbation_set=2 * bounds)
        with self.assertRaisesRegex(ValueError, r".*one_to_many.*"):
            tf = BatchBroadcastedTransformList(transforms=[tf1, tf2])

        # these could technically be batched internally, but we're testing the generic
        # batch broadcasted transform list here. Could change test to use AppendFeatures
        tf1 = InputPerturbation(perturbation_set=bounds)
        tf2 = InputPerturbation(perturbation_set=2 * bounds)
        tf = BatchBroadcastedTransformList(transforms=[tf1, tf2])
        self.assertTrue(tf.is_one_to_many)

def test_round_transform_init(self) -> None:
# basic init
int_idcs = [0, 4]
Expand Down

0 comments on commit edf1776

Please sign in to comment.