diff --git a/test/test_env.py b/test/test_env.py
index 8a2642efb05..f4759f9a119 100644
--- a/test/test_env.py
+++ b/test/test_env.py
@@ -33,6 +33,7 @@
     TensorDictBase,
 )
 from tensordict.nn import TensorDictModuleBase
+from tensordict.tensorclass import NonTensorStack
 from tensordict.utils import _unravel_key_to_tuple
 from torch import nn
@@ -4577,12 +4578,13 @@ def __next__(self):
         ],
     )
     @pytest.mark.parametrize("batched", [True, False])
+    @pytest.mark.parametrize("batch_size", [0, 4])
     @pytest.mark.parametrize("device", [None, "cpu"])
-    def test_llm_env(self, str2str, batched, stack_method, device):
+    def test_llm_env(self, str2str, batched, stack_method, device, batch_size):
         env = LLMEnv(str2str=str2str, device=device)
         if str2str:
             primer = DataLoadingPrimer(
-                dataloader=self.DummyDataLoader(),
+                dataloader=self.DummyDataLoader(batch_size=batch_size),
                 data_keys=["observation"],
                 example_data="a string!",
             )
@@ -4590,7 +4592,9 @@ def test_llm_env(self, str2str, batched, stack_method, device):
             if stack_method is None:
                 stack_method = as_padded_tensor
             primer = DataLoadingPrimer(
-                dataloader=self.DummyTensorDataLoader(padding=True),
+                dataloader=self.DummyTensorDataLoader(
+                    batch_size=batch_size, padding=True
+                ),
                 data_keys=["observation"],
                 data_specs=[Unbounded(shape=(-1,), dtype=torch.int64)],
                 stack_method=stack_method,
@@ -4601,6 +4605,7 @@ def test_llm_env(self, str2str, batched, stack_method, device):
         if batched:
             td = env.reset(TensorDict(batch_size=[3]))
             env.check_env_specs(break_when_any_done="both", tensordict=td)
+            r = env.rollout(10, tensordict=TensorDict(batch_size=[3]))
         else:
             env.check_env_specs(break_when_any_done="both")
@@ -4616,10 +4621,13 @@ def test_llm_env(self, str2str, batched, stack_method, device):
     )
     @pytest.mark.parametrize("batched", [True, False])
     @pytest.mark.parametrize("device", [None, "cpu"])
-    def test_llm_from_dataloader(self, str2str, batched, stack_method, device):
+    @pytest.mark.parametrize("batch_size", [0, 4])
+    def test_llm_from_dataloader(
+        self, str2str, batched, stack_method, device, batch_size
+    ):
         if str2str:
             kwargs = {
-                "dataloader": self.DummyDataLoader(),
+                "dataloader": self.DummyDataLoader(batch_size=batch_size),
                 "data_keys": ["observation"],
                 "example_data": "a string!",
             }
@@ -4627,7 +4635,9 @@ def test_llm_from_dataloader(self, str2str, batched, stack_method, device):
             if stack_method is None:
                 stack_method = as_padded_tensor
             kwargs = {
-                "dataloader": self.DummyTensorDataLoader(padding=True),
+                "dataloader": self.DummyTensorDataLoader(
+                    padding=True, batch_size=batch_size
+                ),
                 "data_keys": ["observation"],
                 "data_specs": [Unbounded(shape=(-1,), dtype=torch.int64)],
                 "stack_method": stack_method,
@@ -4640,6 +4650,56 @@ def test_llm_from_dataloader(self, str2str, batched, stack_method, device):
             env.check_env_specs(break_when_any_done="both", tensordict=td)
         else:
             env.check_env_specs(break_when_any_done="both")
+        if batch_size > 0:
+
+            def policy(td):
+                if str2str:
+                    # use a non-empty action: the checks below strip len(action) chars
+                    if not td.shape:
+                        td["action"] = "<nothing>"
+                    else:
+                        td["action"] = NonTensorStack(
+                            *["<nothing>" for _ in range(td.shape[0])]
+                        )
+                else:
+                    td["action"] = torch.ones(td.shape + (1,), dtype=torch.int64)
+                return td
+
+            if batched:
+                # Tell the env that we want 3 sub-envs
+                r = env.rollout(10, policy, tensordict=TensorDict(batch_size=[3]))
+                assert r.ndim == 2
+                if str2str:
+                    assert isinstance(r[0, 0]["observation"], str)
+                    assert isinstance(r[0, 1]["observation"], str)
+                    assert (
+                        r[0, 0]["observation"]
+                        == r[0, 1]["observation"][: -len(r[0, 0]["action"])]
+                    )
+                    assert (
+                        r[0, 1]["observation"]
+                        == r[0, 2]["observation"][: -len(r[0, 1]["action"])]
+                    )
+                    assert (
+                        r[-1, 0]["observation"]
+                        == r[-1, 1]["observation"][: -len(r[-1, 0]["action"])]
+                    )
+                    assert (
+                        r[-1, 1]["observation"]
+                        == r[-1, 2]["observation"][: -len(r[-1, 1]["action"])]
+                    )
+                else:
+                    assert (r[0, 0]["observation"] == r[0, 1]["observation"][:-1]).all()
+                    assert (r[0, 1]["observation"] == r[0, 2]["observation"][:-1]).all()
+                    assert (
+                        r[-1, 0]["observation"] == r[-1, 1]["observation"][:-1]
+                    ).all()
+                    assert (
+                        r[-1, 1]["observation"] == r[-1, 2]["observation"][:-1]
+                    ).all()
+            else:
+                r = env.rollout(10, policy, tensordict=TensorDict(batch_size=[]))
+                assert r.ndim == 1


 if __name__ == "__main__":
diff --git a/torchrl/envs/custom/llm.py b/torchrl/envs/custom/llm.py
index dd70a8c2598..c62f2c015a2 100644
--- a/torchrl/envs/custom/llm.py
+++ b/torchrl/envs/custom/llm.py
@@ -188,6 +188,14 @@ def from_dataloader(
         )
         return env.append_transform(primer)
 
+    @staticmethod
+    def _check_obs_act_and_cat(obs, action):
+        if not isinstance(obs, str):
+            raise TypeError(f"Observation must be a string, got {type(obs)}.")
+        if not isinstance(action, str):
+            raise TypeError(f"Action must be a string, got {type(action)}.")
+        return obs + action
+
     def _step(
         self,
         tensordict: TensorDictBase,
@@ -202,11 +210,14 @@ def _step(
                         "The tensordict is batchless, yet the action and/or observations are not "
                         f"strings but {type(action)} and {type(obs)}, respectively."
                     )
-                observation = obs + action
+                observation = self._check_obs_act_and_cat(obs, action)
             else:
-                observation = [
-                    _obs + _action for (_obs, _action) in _zip_strict(obs, action)
-                ]
+                observation = NonTensorStack(
+                    *[
+                        self._check_obs_act_and_cat(_obs, _action)
+                        for (_obs, _action) in _zip_strict(obs, action)
+                    ]
+                )
         else:
             try:
                 obs: torch.Tensor = tensordict.get(self.observation_key)
diff --git a/torchrl/envs/transforms/rlhf.py b/torchrl/envs/transforms/rlhf.py
index 752aaea573d..997b1af039a 100644
--- a/torchrl/envs/transforms/rlhf.py
+++ b/torchrl/envs/transforms/rlhf.py
@@ -4,6 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 from __future__ import annotations
 
+from collections import deque
 from collections.abc import Mapping
 from copy import copy, deepcopy
 from typing import Any, Callable, Iterable, Literal
@@ -87,11 +88,21 @@ class DataLoadingPrimer(TensorDictPrimer):
 
     Args:
         dataloader (Iterable[Any]): The dataloader to load data from.
+
+    Keyword Args:
         primers (Composite | None, optional): The primers to use for each key in the dataloader. Defaults to None.
         data_keys (List[NestedKey] | None, optional): The keys to use for each item in the dataloader. Defaults to None.
         data_specs (List[TensorSpec] | None, optional): The specs to use for each item in the dataloader. Defaults to None.
         example_data (Any, optional): Example data to use for initializing the primer. Defaults to None.
         stack_method (Callable[[Any], Any] | Literal["as_nested_tensor", "as_padded_tensor"], optional): The method to use for stacking the data. Defaults to ``maybe_dense_stack``.
+        use_buffer (bool, optional): Whether to use a buffer to load the batches. When an environment has a batch size
+            that differs from the dataloader's, or when partial resets are to be expected, using a buffer to store data
+            ensures that `next()` is called on the dataloader only when necessary, and that elements of the dataset
+            are loaded in order.
+            Defaults to ``True`` whenever the batch size of the dataloader is greater than 1.
+        auto_batch_size (bool, optional): If ``True`` (the default whenever ``dataloader.batch_size > 0``), the batch
+            size of the tensordict returned by the transform is determined automatically, assuming a single leading
+            batch dimension.
 
     Attributes:
         dataloader (Iterable[Any]): The dataloader to load data from.
@@ -339,14 +350,25 @@ class DataLoadingPrimer(TensorDictPrimer):
     def __init__(
         self,
         dataloader: Iterable[Any],
+        *,
         primers: Composite | None = None,
         data_keys: list[NestedKey] | None = None,
         data_specs: list[TensorSpec] | None = None,
         example_data: Any = None,
         stack_method: Callable[[Any], Any]
         | Literal["as_nested_tensor", "as_padded_tensor"] = None,
+        use_buffer: bool | None = None,
+        auto_batch_size: bool = True,
     ):
         self.dataloader = dataloader
+        if getattr(dataloader, "batch_size", 1) > 1 and use_buffer is None:
+            use_buffer = True
+
+        self.use_buffer = use_buffer
+        # No auto_batch_size if we know we have a single element
+        self.auto_batch_size = auto_batch_size and (
+            getattr(dataloader, "batch_size", 1) > 0
+        )
         self.endless_dataloader = self._endless_iter(self.dataloader)
         if primers is None:
             if data_keys is None:
@@ -381,6 +403,8 @@ def __init__(
             single_default_value=True,
             call_before_env_reset=True,
         )
+        if self.use_buffer:
+            self._queue = deque()
 
     @classmethod
     def _endless_iter(self, obj):
@@ -388,6 +412,10 @@ def _endless_iter(self, obj):
         yield from obj
 
     def _load_from_dataloader(self, reset: torch.Tensor | None = None):
+        """Load a single element from the buffer if one is in use, otherwise from the dataloader.
+
+        If `reset` is passed, one element is loaded per ``True`` entry in the reset mask.
+        """
         if reset is not None:
             if not reset.any():
                 raise RuntimeError("reset must have at least one True value.")
@@ -395,20 +423,35 @@
             return self.stack_method(
                 [self._load_from_dataloader() for i in range(reset.sum())]
             )
+        if self.use_buffer and len(self._queue) > 0:
+            return self._queue.popleft()
         data = next(self.endless_dataloader)
         # Some heuristic here:
         # if data is a map, assume its keys match the keys in spec
         # TODO: one could rename the keys too
         if isinstance(data, Mapping):
-            out = TensorDict(data)
+            out = TensorDict.from_dict(
+                data, auto_batch_size=self.auto_batch_size, batch_dims=1
+            )
         elif len(self.data_keys) > 1 and isinstance(data, (list, tuple)):
-            out = TensorDict({k: val for k, val in _zip_strict(self.data_keys, data)})
+            out = TensorDict.from_dict(
+                {k: val for k, val in _zip_strict(self.data_keys, data)},
+                auto_batch_size=self.auto_batch_size,
+                batch_dims=1,
+            )
         elif len(self.data_keys) == 1:
-            out = TensorDict({self.data_keys[0]: data})
+            out = TensorDict.from_dict(
+                {self.data_keys[0]: data},
+                auto_batch_size=self.auto_batch_size,
+                batch_dims=1,
+            )
         else:
             raise ValueError(
                 f"Unrecognized data type: {type(data)} with keys {self.data_keys}."
             )
+        if self.use_buffer:
+            self._queue.extend(out.unbind(0))
+            return self._queue.popleft()
         return out
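
Since the batched `_step` branch above now returns a `NonTensorStack` rather than a plain list, here is a tiny standalone illustration of the concatenation it performs (the strings are made up; only the construct itself comes from the patch):

from tensordict.tensorclass import NonTensorStack

obs = ["hello ", "goodbye "]
action = ["world", "moon"]
# Per-element string concatenation, stacked into a tensordict-compatible leaf.
observation = NonTensorStack(*[o + a for o, a in zip(obs, action)])
assert observation.tolist() == ["hello world", "goodbye moon"]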
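
A minimal sketch of the contract the buffered branch of `_load_from_dataloader` relies on, assuming `TensorDict.from_dict` with `auto_batch_size=True` and `batch_dims=1` behaves as the patch expects (it infers a single leading batch dimension); the toy tensor is hypothetical:

import torch
from collections import deque

from tensordict import TensorDict

# One dataloader batch of 4 elements, each a length-2 token tensor.
out = TensorDict.from_dict(
    {"observation": torch.arange(8).reshape(4, 2)},
    auto_batch_size=True,  # infer the leading batch dimension...
    batch_dims=1,  # ...capped at a single dim, as in the transform
)
assert out.batch_size == torch.Size([4])

queue = deque()
queue.extend(out.unbind(0))  # what the buffered branch does with `out`
elem = queue.popleft()  # one batchless element is served per reset entry
assert elem["observation"].shape == torch.Size([2])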
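
And an end-to-end sketch of how the primer composes with `LLMEnv`, mirroring the tests above. `ToyDataLoader` is a hypothetical stand-in for the tests' `DummyDataLoader`; its `batch_size` attribute being greater than 1 is what makes `use_buffer` default to ``True``:

from tensordict import TensorDict
from torchrl.envs.custom.llm import LLMEnv
from torchrl.envs.transforms.rlhf import DataLoadingPrimer

class ToyDataLoader:
    # Endlessly yields batches of 4 strings, like DummyDataLoader(batch_size=4).
    batch_size = 4

    def __iter__(self):
        i = 0
        while True:
            yield [f"prompt {i + j}" for j in range(self.batch_size)]
            i += self.batch_size

primer = DataLoadingPrimer(
    dataloader=ToyDataLoader(),
    data_keys=["observation"],
    example_data="a string!",
)
env = LLMEnv(str2str=True).append_transform(primer)
# Ask for 3 sub-envs: the buffer hands out one dataset element per reset
# entry, so a single dataloader batch covers several (partial) resets.
td = env.reset(TensorDict(batch_size=[3]))
assert td.shape == (3,)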