[Feature] DataLoadingPrimer handling of dataloader with batch-size > 0
ghstack-source-id: cf1942ece8dfbd6506f91939561df7443bd840ab
Pull Request resolved: #2821
vmoens committed Mar 3, 2025
1 parent 40b147e commit f4713f9
Showing 3 changed files with 126 additions and 13 deletions.
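In short: `DataLoadingPrimer` can now consume a dataloader whose batch-size is greater than zero. Batches are buffered internally and dispensed one element per reset, so an environment whose batch-size differs from the dataloader's still receives dataset elements in order. A minimal sketch of the intended usage, assuming import paths from the files touched in this commit; `StringBatchLoader` is a hypothetical stand-in for the tests' `DummyDataLoader`:

    from tensordict import TensorDict
    from torchrl.envs.custom.llm import LLMEnv
    from torchrl.envs.transforms.rlhf import DataLoadingPrimer

    class StringBatchLoader:
        # Hypothetical dataloader: yields batches of 4 strings at a time.
        batch_size = 4

        def __iter__(self):
            while True:
                yield ["a string!"] * self.batch_size

    env = LLMEnv(str2str=True)
    primer = DataLoadingPrimer(
        dataloader=StringBatchLoader(),
        data_keys=["observation"],
        example_data="a string!",
    )
    env = env.append_transform(primer)
    # Three sub-envs; each reset pops one buffered string from the batch.
    td = env.reset(TensorDict(batch_size=[3]))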
71 changes: 65 additions & 6 deletions test/test_env.py
@@ -33,6 +33,7 @@
TensorDictBase,
)
from tensordict.nn import TensorDictModuleBase
from tensordict.tensorclass import NonTensorStack
from tensordict.utils import _unravel_key_to_tuple
from torch import nn

@@ -4577,20 +4578,23 @@ def __next__(self):
],
)
@pytest.mark.parametrize("batched", [True, False])
@pytest.mark.parametrize("batch_size", [0, 4])
@pytest.mark.parametrize("device", [None, "cpu"])
def test_llm_env(self, str2str, batched, stack_method, device):
def test_llm_env(self, str2str, batched, stack_method, device, batch_size):
env = LLMEnv(str2str=str2str, device=device)
if str2str:
primer = DataLoadingPrimer(
dataloader=self.DummyDataLoader(),
dataloader=self.DummyDataLoader(batch_size=batch_size),
data_keys=["observation"],
example_data="a string!",
)
else:
if stack_method is None:
stack_method = as_padded_tensor
primer = DataLoadingPrimer(
dataloader=self.DummyTensorDataLoader(padding=True),
dataloader=self.DummyTensorDataLoader(
batch_size=batch_size, padding=True
),
data_keys=["observation"],
data_specs=[Unbounded(shape=(-1,), dtype=torch.int64)],
stack_method=stack_method,
@@ -4601,6 +4605,7 @@ def test_llm_env(self, str2str, batched, stack_method, device):
if batched:
td = env.reset(TensorDict(batch_size=[3]))
env.check_env_specs(break_when_any_done="both", tensordict=td)
r = env.rollout(10, tensordict=TensorDict(batch_size=[3]))
else:
env.check_env_specs(break_when_any_done="both")

@@ -4616,18 +4621,23 @@ def test_llm_env(self, str2str, batched, stack_method, device):
)
@pytest.mark.parametrize("batched", [True, False])
@pytest.mark.parametrize("device", [None, "cpu"])
def test_llm_from_dataloader(self, str2str, batched, stack_method, device):
@pytest.mark.parametrize("batch_size", [0, 4])
def test_llm_from_dataloader(
self, str2str, batched, stack_method, device, batch_size
):
if str2str:
kwargs = {
"dataloader": self.DummyDataLoader(),
"dataloader": self.DummyDataLoader(batch_size=batch_size),
"data_keys": ["observation"],
"example_data": "a string!",
}
else:
if stack_method is None:
stack_method = as_padded_tensor
kwargs = {
"dataloader": self.DummyTensorDataLoader(padding=True),
"dataloader": self.DummyTensorDataLoader(
padding=True, batch_size=batch_size
),
"data_keys": ["observation"],
"data_specs": [Unbounded(shape=(-1,), dtype=torch.int64)],
"stack_method": stack_method,
@@ -4640,6 +4650,55 @@ def test_llm_from_dataloader(self, str2str, batched, stack_method, device):
env.check_env_specs(break_when_any_done="both", tensordict=td)
else:
env.check_env_specs(break_when_any_done="both")
if batch_size > 0:

def policy(td):
if str2str:
if not td.shape:
td["action"] = "<nothing>"
else:
td["action"] = NonTensorStack(
*["<nothing>" for _ in range(td.shape[0])]
)
else:
td["action"] = torch.ones(td.shape + (1,), dtype=torch.int64)
return td

if batched:
# Tell the env that we want 3 sub-envs
r = env.rollout(10, policy, tensordict=TensorDict(batch_size=[3]))
assert r.ndim == 2
if str2str:
assert isinstance(r[0, 0]["observation"], str)
assert isinstance(r[0, 1]["observation"], str)
assert (
r[0, 0]["observation"]
== r[0, 1]["observation"][: -len(r[0, 0]["action"])]
)
assert (
r[0, 1]["observation"]
== r[0, 2]["observation"][: -len(r[0, 1]["action"])]
)
assert (
r[-1, 0]["observation"]
== r[-1, 1]["observation"][: -len(r[-1, 0]["action"])]
)
assert (
r[-1, 1]["observation"]
== r[-1, 2]["observation"][: -len(r[-1, 1]["action"])]
)
else:
assert (r[0, 0]["observation"] == r[0, 1]["observation"][:-1]).all()
assert (r[0, 1]["observation"] == r[0, 2]["observation"][:-1]).all()
assert (
r[-1, 0]["observation"] == r[-1, 1]["observation"][:-1]
).all()
assert (
r[-1, 1]["observation"] == r[-1, 2]["observation"][:-1]
).all()
else:
r = env.rollout(10, policy, tensordict=TensorDict(batch_size=[]))
assert r.ndim == 1


if __name__ == "__main__":
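The chained-slicing assertions in the rollout test above encode a single invariant: at every step, the next observation is the previous observation with the action appended (string concatenation when `str2str=True`, a one-token extension otherwise). With illustrative values assumed:

    prev_obs = "a string!"
    action = "<nothing>"
    next_obs = prev_obs + action  # what LLMEnv._step produces for str2str=True
    assert prev_obs == next_obs[: -len(action)]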
19 changes: 15 additions & 4 deletions torchrl/envs/custom/llm.py
@@ -188,6 +188,14 @@ def from_dataloader(
)
return env.append_transform(primer)

@staticmethod
def _check_obs_act_and_cat(obs, action):
if not isinstance(obs, str):
raise TypeError(f"Observation must be a string, got {type(obs)}.")
if not isinstance(action, str):
raise TypeError(f"Action must be a string, got {type(action)}.")
return obs + action

def _step(
self,
tensordict: TensorDictBase,
@@ -202,11 +210,14 @@ def _step(
"The tensordict is batchless, yet the action and/or observations are not "
f"strings but {type(action)} and {type(obs)}, respectivly."
)
observation = obs + action
observation = self._check_obs_act_and_cat(obs, action)
else:
observation = [
_obs + _action for (_obs, _action) in _zip_strict(obs, action)
]
observation = NonTensorStack(
*[
self._check_obs_act_and_cat(_obs, _action)
for (_obs, _action) in _zip_strict(obs, action)
]
)
else:
try:
obs: torch.Tensor = tensordict.get(self.observation_key)
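In the batched branch, `_step` now type-checks each observation/action pair and returns a `NonTensorStack` instead of a plain list. A standalone sketch of that behavior, with assumed example strings (the real code pairs elements with `_zip_strict`):

    from tensordict.tensorclass import NonTensorStack

    obs = ["hello ", "hi "]
    action = ["world", "there"]
    # Each pair is validated as str, concatenated, then stacked.
    observation = NonTensorStack(*[o + a for o, a in zip(obs, action)])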
49 changes: 46 additions & 3 deletions torchrl/envs/transforms/rlhf.py
@@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

from collections import deque
from collections.abc import Mapping
from copy import copy, deepcopy
from typing import Any, Callable, Iterable, Literal
@@ -87,11 +88,21 @@ class DataLoadingPrimer(TensorDictPrimer):
Args:
dataloader (Iterable[Any]): The dataloader to load data from.
Keyword Args:
primers (Composite | None, optional): The primers to use for each key in the dataloader. Defaults to None.
data_keys (List[NestedKey] | None, optional): The keys to use for each item in the dataloader. Defaults to None.
data_specs (List[TensorSpec] | None, optional): The specs to use for each item in the dataloader. Defaults to None.
example_data (Any, optional): Example data to use for initializing the primer. Defaults to None.
stack_method (Callable[[Any], Any] | Literal["as_nested_tensor", "as_padded_tensor"], optional): The method to use for stacking the data. Defaults to ``maybe_dense_stack``.
use_buffer (bool, optional): Whether to use a buffer to load the batches. When an environment has a batch-size
that differs from the dataloader's, or when partial resets are to be expected, using a buffer to store data
ensures that `next()` is called on the dataloader only when necessary, and that elements of the dataset
are loaded in order.
Defaults to ``True`` whenever the batch-size of the dataloader is greater than 1.
auto_batch_size (bool, optional): If ``True`` (default if `dataloader.batch_size > 0`), the batch size of the
tensordict returned by the transform will be automatically determined assuming that there is a single batch
dimension.
Attributes:
dataloader (Iterable[Any]): The dataloader to load data from.
@@ -339,14 +350,25 @@ class DataLoadingPrimer(TensorDictPrimer):
def __init__(
self,
dataloader: Iterable[Any],
*,
primers: Composite | None = None,
data_keys: list[NestedKey] | None = None,
data_specs: list[TensorSpec] | None = None,
example_data: Any = None,
stack_method: Callable[[Any], Any]
| Literal["as_nested_tensor", "as_padded_tensor"] | None = None,
use_buffer: bool | None = None,
auto_batch_size: bool = True,
):
self.dataloader = dataloader
if getattr(dataloader, "batch_size", 1) > 1 and use_buffer is None:
use_buffer = True

self.use_buffer = use_buffer
# No auto_batch_size if we know the dataloader yields single elements
self.auto_batch_size = auto_batch_size and (
getattr(dataloader, "batch_size", 1) > 0
)
self.endless_dataloader = self._endless_iter(self.dataloader)
if primers is None:
if data_keys is None:
@@ -381,34 +403,55 @@ def __init__(
single_default_value=True,
call_before_env_reset=True,
)
if self.use_buffer:
self._queue = deque()

@classmethod
def _endless_iter(cls, obj):
while True:
yield from obj

def _load_from_dataloader(self, reset: torch.Tensor | None = None):
"""Loads a single element from the dataloader, or alternatively from the buffer.
If `reset` is passed, one element per reset entry will be loaded.
"""
if reset is not None:
if not reset.any():
raise RuntimeError("reset must have at least one True value.")
if reset.ndim > 0:
return self.stack_method(
[self._load_from_dataloader() for _ in range(reset.sum())]
)
if self.use_buffer and len(self._queue) > 0:
return self._queue.popleft()
data = next(self.endless_dataloader)
# Some heuristic here:
# if data is a map, assume its keys match the keys in spec
# TODO: one could rename the keys too
if isinstance(data, Mapping):
out = TensorDict(data)
out = TensorDict.from_dict(
data, auto_batch_size=self.auto_batch_size, batch_dims=1
)
elif len(self.data_keys) > 1 and isinstance(data, (list, tuple)):
out = TensorDict({k: val for k, val in _zip_strict(self.data_keys, data)})
out = TensorDict.from_dict(
{k: val for k, val in _zip_strict(self.data_keys, data)},
auto_batch_size=self.auto_batch_size,
batch_dims=1,
)
elif len(self.data_keys) == 1:
out = TensorDict({self.data_keys[0]: data})
out = TensorDict.from_dict(
{self.data_keys[0]: data},
auto_batch_size=self.auto_batch_size,
batch_dims=1,
)
else:
raise ValueError(
f"Unrecognized data type: {type(data)} with keys {self.data_keys}."
)
if self.use_buffer:
self._queue.extend(out.unbind(0))
return self._queue.popleft()
return out


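The buffering added to `_load_from_dataloader` boils down to: draw one batch, split it along its first dimension, queue the pieces, and serve them one reset at a time. A standalone sketch of that logic (not TorchRL code), assuming the dataloader yields mapping-style batches:

    from collections import deque

    import torch
    from tensordict import TensorDict

    def endless(dataloader):
        while True:
            yield from dataloader

    batches = [{"observation": torch.arange(4)}]  # one batch of 4 elements
    it = endless(batches)
    queue = deque()

    def load_one():
        # Serve leftovers from a previous batch before touching the dataloader.
        if queue:
            return queue.popleft()
        out = TensorDict.from_dict(next(it), auto_batch_size=True, batch_dims=1)
        # Split the batch along dim 0 and buffer the remainder for later resets.
        queue.extend(out.unbind(0))
        return queue.popleft()

    print(load_one()["observation"])  # tensor(0)
    print(load_one()["observation"])  # tensor(1): buffered, no new next() call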
