Skip to content

Commit

Permalink
[AIR] Allow users to pass Callable[[torch.Tensor], torch.Tensor] to…
Browse files Browse the repository at this point in the history
… `TorchVisionTransform` (ray-project#32383)

Transforms like RandomHorizontalFlip expect Torch tensors as input, but if you're applying the transform per-epoch, then you can't use ToTensor. To fix the problem, this PR updates TorchVisionPreprocessor to convert ndarray inputs to Torch tensors.

You can't use ToTensor to convert the ndarrays to Torch tensors because then you'd be applying ToTensor twice, and your images would get scaled incorrectly.

Signed-off-by: Balaji Veeramani <[email protected]>
  • Loading branch information
bveeramani authored Feb 10, 2023
1 parent 613f4b0 commit faeb2cc
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 19 deletions.
53 changes: 36 additions & 17 deletions python/ray/data/preprocessors/torch.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import TYPE_CHECKING, Callable, Dict, List, Union

import numpy as np
from ray.air.util.tensor_extensions.utils import _create_possibly_ragged_ndarray

from ray.air.util.tensor_extensions.utils import _create_possibly_ragged_ndarray
from ray.data.preprocessor import Preprocessor
from ray.util.annotations import PublicAPI

Expand All @@ -21,8 +21,9 @@ class TorchVisionPreprocessor(Preprocessor):
>>> dataset # doctest: +ellipsis
Dataset(num_blocks=..., num_rows=..., schema={image: ArrowTensorType(shape=(..., 3), dtype=float)})
:class:`TorchVisionPreprocessor` passes ndarrays to your transform. To convert
ndarrays to Torch tensors, add ``ToTensor`` to your pipeline.
Torch models expect inputs of shape :math:`(B, C, H, W)` in the range
:math:`[0.0, 1.0]`. To convert images to this format, add ``ToTensor`` to your
preprocessing pipeline.
>>> from torchvision import transforms
>>> from ray.data.preprocessors import TorchVisionPreprocessor
Expand Down Expand Up @@ -57,7 +58,8 @@ class TorchVisionPreprocessor(Preprocessor):
Args:
columns: The columns to apply the TorchVision transform to.
transform: The TorchVision transform you want to apply. This transform should
accept an ``np.ndarray`` as input and return a ``torch.Tensor`` as output.
accept a ``np.ndarray`` or ``torch.Tensor`` as input and return a
``torch.Tensor`` as output.
batched: If ``True``, apply ``transform`` to batches of shape
:math:`(B, H, W, C)`. Otherwise, apply ``transform`` to individual images.
""" # noqa: E501
Expand All @@ -67,37 +69,54 @@ class TorchVisionPreprocessor(Preprocessor):
def __init__(
self,
columns: List[str],
transform: Callable[["np.ndarray"], "torch.Tensor"],
transform: Callable[[Union["np.ndarray", "torch.Tensor"]], "torch.Tensor"],
batched: bool = False,
):
self._columns = columns
self._fn = transform
self._torchvision_transform = transform
self._batched = batched

def __repr__(self) -> str:
return (
f"{self.__class__.__name__}(columns={self._columns}, "
f"transform={self._fn!r})"
f"transform={self._torchvision_transform!r})"
)

def _transform_numpy(
self, np_data: Union["np.ndarray", Dict[str, "np.ndarray"]]
) -> Union["np.ndarray", Dict[str, "np.ndarray"]]:
def transform(batch: np.ndarray) -> np.ndarray:
import torch
from ray.air._internal.torch_utils import convert_ndarray_to_torch_tensor

def apply_torchvision_transform(array: np.ndarray) -> np.ndarray:
try:
tensor = convert_ndarray_to_torch_tensor(array)
output = self._torchvision_transform(tensor)
except TypeError:
# Transforms like `ToTensor` expect a `np.ndarray` as input.
output = self._torchvision_transform(array)

if not isinstance(output, torch.Tensor):
raise ValueError(
"`TorchVisionPreprocessor` expected your transform to return a "
"`torch.Tensor`, but your transform returned a "
f"`{type(output).__name__}` instead."
)

return output.numpy()

def transform_batch(batch: np.ndarray) -> np.ndarray:
if self._batched:
return self._fn(batch).numpy()
return apply_torchvision_transform(batch)
return _create_possibly_ragged_ndarray(
[self._fn(array).numpy() for array in batch],
[apply_torchvision_transform(array) for array in batch]
)

if isinstance(np_data, dict):
outputs = {}
for column, batch in np_data.items():
if column in self._columns:
outputs[column] = transform(batch)
else:
outputs[column] = batch
outputs = np_data
for column in self._columns:
outputs[column] = transform_batch(np_data[column])
else:
outputs = transform(np_data)
outputs = transform_batch(np_data)

return outputs
25 changes: 23 additions & 2 deletions python/ray/data/tests/preprocessors/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,20 @@ def __repr__(self):
== "TorchVisionPreprocessor(columns=['spam'], transform=StubTransform())"
)

def test_transform_images(self):
@pytest.mark.parametrize(
"transform",
[
transforms.ToTensor(), # `ToTensor` accepts an `np.ndarray` as input
transforms.Lambda(lambda tensor: tensor.permute(2, 0, 1)),
],
)
def test_transform_images(self, transform):
dataset = ray.data.from_items(
[
{"image": np.zeros((32, 32, 3)), "label": 0},
{"image": np.zeros((32, 32, 3)), "label": 1},
]
)
transform = transforms.ToTensor()
preprocessor = TorchVisionPreprocessor(columns=["image"], transform=transform)

transformed_dataset = preprocessor.transform(dataset)
Expand Down Expand Up @@ -99,6 +105,21 @@ def test_transform_ragged_images(self):
labels = {record["label"] for record in transformed_dataset.take_all()}
assert labels == {0, 1}

def test_invalid_transform_raises_value_error(self):
dataset = ray.data.from_items(
[
{"image": np.zeros((32, 32, 3)), "label": 0},
{"image": np.zeros((32, 32, 3)), "label": 1},
]
)
# `TorchVisionPreprocessor` expects transforms to return `torch.Tensor`s, but
# this `transform` returns a `np.ndarray`.
transform = transforms.Lambda(lambda tensor: tensor.numpy())
preprocessor = TorchVisionPreprocessor(columns=["image"], transform=transform)

with pytest.raises(ValueError):
preprocessor.transform(dataset)


if __name__ == "__main__":
import sys
Expand Down

0 comments on commit faeb2cc

Please sign in to comment.