From 5dbbd2665654e66024ecd858d25b9a6bf4c5bc97 Mon Sep 17 00:00:00 2001 From: Dimitrios Tsaras Date: Thu, 19 Dec 2024 20:15:55 +0800 Subject: [PATCH 1/3] Added the necessary transforms for Hindsight Experience Replay --- torchrl/envs/transforms/transforms.py | 163 ++++++++++++++++++++++++++ 1 file changed, 163 insertions(+) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index f3329d085df..d9200208843 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -40,6 +40,7 @@ TensorDictBase, unravel_key, unravel_key_list, + pad_sequence, ) from tensordict.nn import dispatch, TensorDictModuleBase from tensordict.utils import ( @@ -9264,3 +9265,165 @@ def transform_observation_spec(self, observation_spec: Composite) -> Composite: high=torch.iinfo(torch.int64).max, ) return super().transform_observation_spec(observation_spec) + + +class HERSubGoalSampler(Transform): + """Returns a TensorDict with a key `subgoal_idx` of shape [batch_size, num_samples] represebting the subgoal index. + Available strategies are: `last` and `future`. The `last` strategy assigns the last state as the subgoal. The `future` strategy samples up to `num_samples` subgoal from the future states. + Args: + num_samples (int): Number of subgoals to sample from each trajectory. Defaults to 4. + out_keys (str): The key to store the subgoal index. Defaults to "subgoal_idx". + """ + def __init__( + self, + num_samples: int = 4, + subgoal_idx_key: str = "subgoal_idx", + strategy: str = "future" + ): + super().__init__( + in_keys=None, + in_keys_inv=None, + out_keys_inv=None, + ) + self.num_samples = num_samples + self.subgoal_idx_key = subgoal_idx_key + self.strategy = strategy + + def forward(self, trajectories: TensorDictBase) -> TensorDictBase: + if len(trajectories.shape) == 1: + trajectories = trajectories.unsqueeze(0) + + batch_size, trajectory_len = trajectories.shape + + if self.strategy == "last": + return TensorDict({"subgoal_idx": torch.full((batch_size, 1), -1)}, batch_size=batch_size) + + else: + subgoal_idxs = [] + for i in range(batch_size): + subgoal_idxs.append( + TensorDict( + {"subgoal_idx": (torch.randperm(trajectory_len-2)+1)[:self.num_samples]}, + batch_size=torch.Size(), + ) + ) + return pad_sequence(subgoal_idxs, pad_dim=0, return_mask=True) + + +class HERSubGoalAssigner(Transform): + """This module assigns the subgoal to the trajectory according to a given subgoal index. + Args: + subgoal_idx_name (str): The key to the subgoal index. Defaults to "subgoal_idx". + subgoal_name (str): The key to assign the observation of the subgoal to the goal. Defaults to "goal". + """ + def __init__( + self, + achieved_goal_key: str = "achieved_goal", + desired_goal_key: str = "desired_goal", + ): + self.achieved_goal_key = achieved_goal_key + self.desired_goal_key = desired_goal_key + + def forward(self, trajectories: TensorDictBase, subgoals_idxs: torch.Tensor) -> TensorDictBase: + batch_size, trajectory_len = trajectories.shape + for i in range(batch_size): + subgoal = trajectories[i][subgoals_idxs[i]][self.achieved_goal_key] + desired_goal_shape = trajectories[i][self.desired_goal_key].shape + trajectories[i][self.desired_goal_key] = subgoal.expand(desired_goal_shape) + trajectories[i][subgoals_idxs[i]]["next", "done"] = True + # trajectories[i][subgoals_idxs[i]+1:]["truncated"] = True + + return trajectories + + +class HERRewardTransform(Transform): + """This module assigns the reward to the trajectory according to the new subgoal. 
+ Args: + reward_name (str): The key to the reward. Defaults to "reward". + """ + def __init__( + self + ): + pass + + def forward(self, trajectories: TensorDictBase) -> TensorDictBase: + return trajectories + + +class HindsightExperienceReplayTransform(Transform): + """Hindsight Experience Replay (HER) is a technique that allows to learn from failure by creating new experiences from the failed ones. + This module is a wrapper that includes the following modules: + - SubGoalSampler: Creates new trajectories by sampling future subgoals from the same trajectory. + - SubGoalAssigner: Assigns the subgoal to the trajectory according to a given subgoal index. + - RewardTransform: Assigns the reward to the trajectory according to the new subgoal. + Args: + SubGoalSampler (Transform): + SubGoalAssigner (Transform): + RewardTransform (Transform): + """ + def __init__( + self, + SubGoalSampler: Transform = HERSubGoalSampler(), + SubGoalAssigner: Transform = HERSubGoalAssigner(), + RewardTransform: Transform = HERRewardTransform(), + assign_subgoal_idxs: bool = False, + ): + super().__init__( + in_keys=None, + in_keys_inv=None, + out_keys_inv=None, + ) + self.SubGoalSampler = SubGoalSampler + self.SubGoalAssigner = SubGoalAssigner + self.RewardTransform = RewardTransform + self.assign_subgoal_idxs = assign_subgoal_idxs + + def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: + augmentation_td = self.her_augmentation(tensordict) + return torch.cat([tensordict, augmentation_td], dim=0) + + def _inv_apply_transform(self, tensordict: TensorDictBase) -> torch.Tensor: + return self.her_augmentation(tensordict) + + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + return tensordict + + def _call(self, tensordict: TensorDictBase) -> TensorDictBase: + raise ValueError(self.ENV_ERR) + + def her_augmentation(self, trajectories: TensorDictBase): + if len(trajectories.shape) == 1: + trajectories = trajectories.unsqueeze(0) + batch_size, trajectory_length = trajectories.shape + + new_trajectories = trajectories.clone(True) + + # Sample subgoal indices + subgoal_idxs = self.SubGoalSampler(new_trajectories) + + # Create new trajectories + augmented_trajectories = [] + list_idxs = [] + for i in range(batch_size): + idxs = subgoal_idxs[i][self.SubGoalSampler.subgoal_idx_key] + + if "masks" in subgoal_idxs.keys(): + idxs = idxs[subgoal_idxs[i]["masks", self.SubGoalSampler.subgoal_idx_key]] + + list_idxs.append(idxs.unsqueeze(-1)) + new_traj = new_trajectories[i].expand((idxs.numel(),trajectory_length)).clone(True) + + if self.assign_subgoal_idxs: + new_traj[self.SubGoalSampler.subgoal_idx_key] = idxs.unsqueeze(-1).repeat(1, trajectory_length) + + augmented_trajectories.append(new_traj) + augmented_trajectories = torch.cat(augmented_trajectories, dim=0) + associated_idxs = torch.cat(list_idxs, dim=0) + + # Assign subgoals to the new trajectories + augmented_trajectories = self.SubGoalAssigner.forward(augmented_trajectories, associated_idxs) + + # Adjust the rewards based on the new subgoals + augmented_trajectories = self.RewardTransform.forward(augmented_trajectories) + + return augmented_trajectories From d40527eeb76d3b17fecefe6d511eded7f913da36 Mon Sep 17 00:00:00 2001 From: Dimitrios Tsaras Date: Wed, 15 Jan 2025 18:02:15 +0800 Subject: [PATCH 2/3] Updated HERRewardTransform and added a test case --- test/test_transforms.py | 100 ++++++++++++++++++++++ torchrl/envs/__init__.py | 4 + torchrl/envs/transforms/__init__.py | 4 + torchrl/envs/transforms/transforms.py | 116 
++++++++++++++++++-------- 4 files changed, 187 insertions(+), 37 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index cc3ca40b059..52581656f42 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -116,6 +116,10 @@ FrameSkipTransform, GrayScale, gSDENoise, + HERRewardTransform, + HERSubGoalAssigner, + HERSubGoalSampler, + HindsightExperienceReplayTransform, InitTracker, MultiStepTransform, NoopResetEnv, @@ -12376,6 +12380,102 @@ def test_transform_inverse(self): pytest.skip("Tested elsewhere") +class TestHERTransform(TransformBase): + @pytest.mark.parametrize("strategy", ["last", "future"]) + @pytest.mark.parametrize("device", get_default_devices()) + def test_transform_inverse(self, strategy, device): + batch = 10 + trajectory_len = 20 + num_samples = 4 + batch_size = [batch, trajectory_len] + torch.manual_seed(0) + + # Let every episode be a random 1D trajectory + velocity = torch.rand((batch, 1), device=device) + time = torch.arange(trajectory_len + 1, device=device).expand(batch, -1) + start_pos = torch.rand((batch, 1), device=device) + pos = start_pos + velocity * time + goal = ( + (torch.rand(batch, device=device) * 10) + .expand(trajectory_len, batch) + .T[:, :, None] + ) + + her = HindsightExperienceReplayTransform( + SubGoalSampler=HERSubGoalSampler( + num_samples=4, + strategy=strategy, + ), + SubGoalAssigner=HERSubGoalAssigner( + achieved_goal_key=("next", "pos"), + desired_goal_key="original_goal", + ), + RewardTransform=HERRewardTransform(), + ) + + done = torch.zeros(*batch_size, 1, dtype=torch.bool, device=device) + done[:, -1] = True + reward = done.clone().float() + + td = TensorDict( + { + "pos": pos[:, :-1], + "original_goal": goal, + "next": { + "done": done, + "reward": reward, + "pos": pos[:, 1:], + "original_goal": goal, + }, + }, + batch_size, + device=device, + ) + + td = her.inv(td) + if strategy == "last": + assert td.shape == (batch * 2, trajectory_len) + elif strategy == "future": + assert td.shape == (batch * (num_samples + 1), trajectory_len) + + # original trajectories are at the top so we can check if the sugoals are part of the positions + augmented_td = td[batch:, :] + new_batch_size, _ = augmented_td.shape + for i in range(new_batch_size): + goal_value = augmented_td["original_goal"][i, 0] + assert (goal_value == augmented_td["next", "pos"][i]).any() + + def test_parallel_trans_env_check(self): + pass + + def test_serial_trans_env_check(self): + pass + + def test_single_trans_env_check(self): + pass + + def test_trans_parallel_env_check(self): + pass + + def test_trans_serial_env_check(self): + pass + + def test_transform_compose(self): + pass + + def test_transform_env(self): + pass + + def test_transform_model(self): + pass + + def test_transform_no_env(self): + pass + + def test_transform_rb(self): + pass + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/torchrl/envs/__init__.py b/torchrl/envs/__init__.py index b863ad0801c..d096a618e3c 100644 --- a/torchrl/envs/__init__.py +++ b/torchrl/envs/__init__.py @@ -67,6 +67,10 @@ FrameSkipTransform, GrayScale, gSDENoise, + HERRewardTransform, + HERSubGoalAssigner, + HERSubGoalSampler, + HindsightExperienceReplayTransform, InitTracker, KLRewardTransform, MultiStepTransform, diff --git a/torchrl/envs/transforms/__init__.py b/torchrl/envs/transforms/__init__.py index 77f6ecc03bf..d7075d6e30f 100644 --- a/torchrl/envs/transforms/__init__.py +++ 
b/torchrl/envs/transforms/__init__.py @@ -31,6 +31,10 @@ FrameSkipTransform, GrayScale, gSDENoise, + HERRewardTransform, + HERSubGoalAssigner, + HERSubGoalSampler, + HindsightExperienceReplayTransform, InitTracker, NoopResetEnv, ObservationNorm, diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index d9200208843..f81562d7f8a 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -34,13 +34,14 @@ from tensordict import ( is_tensor_collection, LazyStackedTensorDict, + NestedKey, NonTensorData, + pad_sequence, set_lazy_legacy, TensorDict, TensorDictBase, unravel_key, unravel_key_list, - pad_sequence, ) from tensordict.nn import dispatch, TensorDictModuleBase from tensordict.utils import ( @@ -9268,17 +9269,18 @@ def transform_observation_spec(self, observation_spec: Composite) -> Composite: class HERSubGoalSampler(Transform): - """Returns a TensorDict with a key `subgoal_idx` of shape [batch_size, num_samples] represebting the subgoal index. - Available strategies are: `last` and `future`. The `last` strategy assigns the last state as the subgoal. The `future` strategy samples up to `num_samples` subgoal from the future states. + """Returns a TensorDict with a key `subgoal_idx` of shape [batch_size, num_samples] representing the subgoal index. Available strategies are: `last` and `future`. The `last` strategy assigns the last state as the subgoal. The `future` strategy samples up to `num_samples` subgoal from the future states. + Args: num_samples (int): Number of subgoals to sample from each trajectory. Defaults to 4. out_keys (str): The key to store the subgoal index. Defaults to "subgoal_idx". - """ + """ + def __init__( self, num_samples: int = 4, subgoal_idx_key: str = "subgoal_idx", - strategy: str = "future" + strategy: str = "future", ): super().__init__( in_keys=None, @@ -9296,14 +9298,20 @@ def forward(self, trajectories: TensorDictBase) -> TensorDictBase: batch_size, trajectory_len = trajectories.shape if self.strategy == "last": - return TensorDict({"subgoal_idx": torch.full((batch_size, 1), -1)}, batch_size=batch_size) + return TensorDict( + {"subgoal_idx": torch.full((batch_size, 1), -2)}, batch_size=batch_size + ) else: subgoal_idxs = [] - for i in range(batch_size): + for _ in range(batch_size): subgoal_idxs.append( TensorDict( - {"subgoal_idx": (torch.randperm(trajectory_len-2)+1)[:self.num_samples]}, + { + "subgoal_idx": (torch.randperm(trajectory_len - 2) + 1)[ + : self.num_samples + ] + }, batch_size=torch.Size(), ) ) @@ -9312,55 +9320,79 @@ def forward(self, trajectories: TensorDictBase) -> TensorDictBase: class HERSubGoalAssigner(Transform): """This module assigns the subgoal to the trajectory according to a given subgoal index. + Args: subgoal_idx_name (str): The key to the subgoal index. Defaults to "subgoal_idx". subgoal_name (str): The key to assign the observation of the subgoal to the goal. Defaults to "goal". 
- """ + """ + def __init__( self, - achieved_goal_key: str = "achieved_goal", - desired_goal_key: str = "desired_goal", + achieved_goal_key: str | tuple = "achieved_goal", + desired_goal_key: str | tuple = "desired_goal", ): self.achieved_goal_key = achieved_goal_key self.desired_goal_key = desired_goal_key - def forward(self, trajectories: TensorDictBase, subgoals_idxs: torch.Tensor) -> TensorDictBase: + def forward( + self, trajectories: TensorDictBase, subgoals_idxs: torch.Tensor + ) -> TensorDictBase: batch_size, trajectory_len = trajectories.shape for i in range(batch_size): + # Assign the subgoal to the desired_goal_key, and ("next", desired_goal_key) of the trajectory subgoal = trajectories[i][subgoals_idxs[i]][self.achieved_goal_key] desired_goal_shape = trajectories[i][self.desired_goal_key].shape - trajectories[i][self.desired_goal_key] = subgoal.expand(desired_goal_shape) - trajectories[i][subgoals_idxs[i]]["next", "done"] = True - # trajectories[i][subgoals_idxs[i]+1:]["truncated"] = True - + trajectories[i].set_( + self.desired_goal_key, subgoal.expand(desired_goal_shape) + ) + trajectories[i].set_( + ("next", self.desired_goal_key), subgoal.expand(desired_goal_shape) + ) + + # Update the done and (next, done) flags + new_done = torch.zeros_like( + trajectories[i]["next", "done"], dtype=torch.bool + ) + new_done[subgoals_idxs[i]] = True + trajectories[i].set_(("next", "done"), new_done) + return trajectories class HERRewardTransform(Transform): - """This module assigns the reward to the trajectory according to the new subgoal. + """This module assigns a reward of `reward_value` where the new trajectory `(next, done)` is `True`. + Args: - reward_name (str): The key to the reward. Defaults to "reward". - """ + reward_value (float): The reward to be assigned to the newly generated trajectories. Defaults to "1.0". + """ + def __init__( - self + self, + reward_value: float = 1.0, ): - pass - + self.reward_value = reward_value + def forward(self, trajectories: TensorDictBase) -> TensorDictBase: + new_reward = torch.zeros_like(trajectories["next", "reward"]) + new_reward[trajectories["next", "done"]] = self.reward_value + trajectories.set_(("next", "reward"), new_reward) return trajectories class HindsightExperienceReplayTransform(Transform): """Hindsight Experience Replay (HER) is a technique that allows to learn from failure by creating new experiences from the failed ones. + This module is a wrapper that includes the following modules: - SubGoalSampler: Creates new trajectories by sampling future subgoals from the same trajectory. - SubGoalAssigner: Assigns the subgoal to the trajectory according to a given subgoal index. - RewardTransform: Assigns the reward to the trajectory according to the new subgoal. 
+ Args: - SubGoalSampler (Transform): - SubGoalAssigner (Transform): - RewardTransform (Transform): - """ + SubGoalSampler (Transform): + SubGoalAssigner (Transform): + RewardTransform (Transform): + """ + def __init__( self, SubGoalSampler: Transform = HERSubGoalSampler(), @@ -9395,33 +9427,43 @@ def her_augmentation(self, trajectories: TensorDictBase): if len(trajectories.shape) == 1: trajectories = trajectories.unsqueeze(0) batch_size, trajectory_length = trajectories.shape - + new_trajectories = trajectories.clone(True) - + # Sample subgoal indices subgoal_idxs = self.SubGoalSampler(new_trajectories) - # Create new trajectories + # Create new trajectories augmented_trajectories = [] list_idxs = [] for i in range(batch_size): idxs = subgoal_idxs[i][self.SubGoalSampler.subgoal_idx_key] - + if "masks" in subgoal_idxs.keys(): - idxs = idxs[subgoal_idxs[i]["masks", self.SubGoalSampler.subgoal_idx_key]] - + idxs = idxs[ + subgoal_idxs[i]["masks", self.SubGoalSampler.subgoal_idx_key] + ] + list_idxs.append(idxs.unsqueeze(-1)) - new_traj = new_trajectories[i].expand((idxs.numel(),trajectory_length)).clone(True) - + new_traj = ( + new_trajectories[i] + .expand((idxs.numel(), trajectory_length)) + .clone(True) + ) + if self.assign_subgoal_idxs: - new_traj[self.SubGoalSampler.subgoal_idx_key] = idxs.unsqueeze(-1).repeat(1, trajectory_length) - + new_traj[self.SubGoalSampler.subgoal_idx_key] = idxs.unsqueeze( + -1 + ).repeat(1, trajectory_length) + augmented_trajectories.append(new_traj) augmented_trajectories = torch.cat(augmented_trajectories, dim=0) associated_idxs = torch.cat(list_idxs, dim=0) # Assign subgoals to the new trajectories - augmented_trajectories = self.SubGoalAssigner.forward(augmented_trajectories, associated_idxs) + augmented_trajectories = self.SubGoalAssigner.forward( + augmented_trajectories, associated_idxs + ) # Adjust the rewards based on the new subgoals augmented_trajectories = self.RewardTransform.forward(augmented_trajectories) From fb2c7f83d95daa093de38a087a70c87aed5f7c06 Mon Sep 17 00:00:00 2001 From: Dimitrios Tsaras Date: Wed, 15 Jan 2025 19:31:22 +0800 Subject: [PATCH 3/3] Address the comments I missed --- test/test_transforms.py | 10 ++-- torchrl/envs/__init__.py | 2 +- torchrl/envs/transforms/__init__.py | 2 +- torchrl/envs/transforms/transforms.py | 79 +++++++++++++++++++-------- 4 files changed, 63 insertions(+), 30 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index 52581656f42..6ca690638df 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -116,7 +116,7 @@ FrameSkipTransform, GrayScale, gSDENoise, - HERRewardTransform, + HERRewardAssigner, HERSubGoalAssigner, HERSubGoalSampler, HindsightExperienceReplayTransform, @@ -12381,7 +12381,7 @@ def test_transform_inverse(self): class TestHERTransform(TransformBase): - @pytest.mark.parametrize("strategy", ["last", "future"]) + @pytest.mark.parametrize("strategy", ["final", "future"]) @pytest.mark.parametrize("device", get_default_devices()) def test_transform_inverse(self, strategy, device): batch = 10 @@ -12402,15 +12402,15 @@ def test_transform_inverse(self, strategy, device): ) her = HindsightExperienceReplayTransform( - SubGoalSampler=HERSubGoalSampler( + subgoal_sampler=HERSubGoalSampler( num_samples=4, strategy=strategy, ), - SubGoalAssigner=HERSubGoalAssigner( + subgoal_assigner=HERSubGoalAssigner( achieved_goal_key=("next", "pos"), desired_goal_key="original_goal", ), - RewardTransform=HERRewardTransform(), + reward_assigner=HERRewardAssigner(), ) done = 
torch.zeros(*batch_size, 1, dtype=torch.bool, device=device) diff --git a/torchrl/envs/__init__.py b/torchrl/envs/__init__.py index d096a618e3c..2a9b8df4be2 100644 --- a/torchrl/envs/__init__.py +++ b/torchrl/envs/__init__.py @@ -67,7 +67,7 @@ FrameSkipTransform, GrayScale, gSDENoise, - HERRewardTransform, + HERRewardAssigner, HERSubGoalAssigner, HERSubGoalSampler, HindsightExperienceReplayTransform, diff --git a/torchrl/envs/transforms/__init__.py b/torchrl/envs/transforms/__init__.py index d7075d6e30f..1ed61db813a 100644 --- a/torchrl/envs/transforms/__init__.py +++ b/torchrl/envs/transforms/__init__.py @@ -31,7 +31,7 @@ FrameSkipTransform, GrayScale, gSDENoise, - HERRewardTransform, + HERRewardAssigner, HERSubGoalAssigner, HERSubGoalSampler, HindsightExperienceReplayTransform, diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index f81562d7f8a..9d4ff56df2b 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -9269,17 +9269,22 @@ def transform_observation_spec(self, observation_spec: Composite) -> Composite: class HERSubGoalSampler(Transform): - """Returns a TensorDict with a key `subgoal_idx` of shape [batch_size, num_samples] representing the subgoal index. Available strategies are: `last` and `future`. The `last` strategy assigns the last state as the subgoal. The `future` strategy samples up to `num_samples` subgoal from the future states. + """Returns a TensorDict with a key `subgoal_idx` of shape [batch_size, num_samples] representing the subgoal index. + + Available strategies are: `final` and `future`. The `final` strategy assigns the last state of the trajectory as the subgoal. The `future` strategy samples up to `num_samples` subgoals from all intermediate states within the same trajectory. Args: num_samples (int): Number of subgoals to sample from each trajectory. Defaults to 4. - out_keys (str): The key to store the subgoal index. Defaults to "subgoal_idx". + subgoal_idx_key (NestedKey): The key to store the subgoal index. Defaults to "subgoal_idx". + strategy (str): Specifies the subgoal sampling strategy `"final"` | `"future"`. Defaults to `"future"`. + + seealso:: `HindsightExperienceReplayTransform`, `HERSubGoalSampler`, `HERSubGoalAssigner`, `HERRewardAssigner`.
""" def __init__( self, num_samples: int = 4, - subgoal_idx_key: str = "subgoal_idx", + subgoal_idx_key: NestedKey = "subgoal_idx", strategy: str = "future", ): super().__init__( @@ -9292,14 +9297,25 @@ def __init__( self.strategy = strategy def forward(self, trajectories: TensorDictBase) -> TensorDictBase: + assert len(trajectories.shape) in [1, 2] + assert self.strategy in ["final", "future"] + if len(trajectories.shape) == 1: trajectories = trajectories.unsqueeze(0) batch_size, trajectory_len = trajectories.shape - if self.strategy == "last": + if self.strategy == "final": return TensorDict( - {"subgoal_idx": torch.full((batch_size, 1), -2)}, batch_size=batch_size + { + self.subgoal_idx_key: torch.full( + (batch_size, 1), + -2, + dtype=torch.int64, + device=trajectories.device, + ) + }, + batch_size=batch_size, ) else: @@ -9308,9 +9324,14 @@ def forward(self, trajectories: TensorDictBase) -> TensorDictBase: subgoal_idxs.append( TensorDict( { - "subgoal_idx": (torch.randperm(trajectory_len - 2) + 1)[ - : self.num_samples - ] + self.subgoal_idx_key: ( + torch.randperm( + trajectory_len - 2, + dtype=torch.int64, + device=trajectories.device, + ) + + 1 + )[: self.num_samples] }, batch_size=torch.Size(), ) @@ -9324,12 +9345,14 @@ class HERSubGoalAssigner(Transform): Args: subgoal_idx_name (str): The key to the subgoal index. Defaults to "subgoal_idx". subgoal_name (str): The key to assign the observation of the subgoal to the goal. Defaults to "goal". + + seealso:: `HindsightExperienceReplayTransform`, `HERSubgoalSampler`, `HERRewardAssigner`. """ def __init__( self, - achieved_goal_key: str | tuple = "achieved_goal", - desired_goal_key: str | tuple = "desired_goal", + achieved_goal_key: NestedKey = "achieved_goal", + desired_goal_key: NestedKey = "desired_goal", ): self.achieved_goal_key = achieved_goal_key self.desired_goal_key = desired_goal_key @@ -9359,11 +9382,13 @@ def forward( return trajectories -class HERRewardTransform(Transform): +class HERRewardAssigner(Transform): """This module assigns a reward of `reward_value` where the new trajectory `(next, done)` is `True`. Args: reward_value (float): The reward to be assigned to the newly generated trajectories. Defaults to "1.0". + + seealso:: `HindsightExperienceReplayTransform`, `HERSubgoalSampler`, `HERSubGoalAssigner`. """ def __init__( @@ -9391,23 +9416,31 @@ class HindsightExperienceReplayTransform(Transform): SubGoalSampler (Transform): SubGoalAssigner (Transform): RewardTransform (Transform): + + seealso:: `HERSubgoalSampler`, `HERSubGoalAssigner`, `HERRewardAssigner`. 
""" def __init__( self, - SubGoalSampler: Transform = HERSubGoalSampler(), - SubGoalAssigner: Transform = HERSubGoalAssigner(), - RewardTransform: Transform = HERRewardTransform(), + subgoal_sampler: Transform | None = None, + subgoal_assigner: Transform | None = None, + reward_assigner: Transform | None = None, assign_subgoal_idxs: bool = False, ): + if subgoal_sampler is None: + subgoal_sampler = HERSubGoalSampler() + if subgoal_assigner is None: + subgoal_assigner = HERSubGoalAssigner() + if reward_assigner is None: + reward_assigner = HERRewardAssigner() super().__init__( in_keys=None, in_keys_inv=None, out_keys_inv=None, ) - self.SubGoalSampler = SubGoalSampler - self.SubGoalAssigner = SubGoalAssigner - self.RewardTransform = RewardTransform + self.subgoal_sampler = subgoal_sampler + self.subgoal_assigner = subgoal_assigner + self.reward_assigner = reward_assigner self.assign_subgoal_idxs = assign_subgoal_idxs def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: @@ -9431,17 +9464,17 @@ def her_augmentation(self, trajectories: TensorDictBase): new_trajectories = trajectories.clone(True) # Sample subgoal indices - subgoal_idxs = self.SubGoalSampler(new_trajectories) + subgoal_idxs = self.subgoal_sampler(new_trajectories) # Create new trajectories augmented_trajectories = [] list_idxs = [] for i in range(batch_size): - idxs = subgoal_idxs[i][self.SubGoalSampler.subgoal_idx_key] + idxs = subgoal_idxs[i][self.subgoal_sampler.subgoal_idx_key] if "masks" in subgoal_idxs.keys(): idxs = idxs[ - subgoal_idxs[i]["masks", self.SubGoalSampler.subgoal_idx_key] + subgoal_idxs[i]["masks", self.subgoal_sampler.subgoal_idx_key] ] list_idxs.append(idxs.unsqueeze(-1)) @@ -9452,7 +9485,7 @@ def her_augmentation(self, trajectories: TensorDictBase): ) if self.assign_subgoal_idxs: - new_traj[self.SubGoalSampler.subgoal_idx_key] = idxs.unsqueeze( + new_traj[self.subgoal_sampler.subgoal_idx_key] = idxs.unsqueeze( -1 ).repeat(1, trajectory_length) @@ -9461,11 +9494,11 @@ def her_augmentation(self, trajectories: TensorDictBase): associated_idxs = torch.cat(list_idxs, dim=0) # Assign subgoals to the new trajectories - augmented_trajectories = self.SubGoalAssigner.forward( + augmented_trajectories = self.subgoal_assigner.forward( augmented_trajectories, associated_idxs ) # Adjust the rewards based on the new subgoals - augmented_trajectories = self.RewardTransform.forward(augmented_trajectories) + augmented_trajectories = self.reward_assigner.forward(augmented_trajectories) return augmented_trajectories