diff --git a/test/test_storage_map.py b/test/test_storage_map.py index 9866d893ff9..f56743c69c2 100644 --- a/test/test_storage_map.py +++ b/test/test_storage_map.py @@ -684,6 +684,93 @@ def test_forest_check_obs_match(self, intersect): ).all() prev_tree = subtree + def test_to_string(self): + forest = MCTSForest() + + td_root = TensorDict( + { + "observation": 0, + } + ) + + rollouts_data = [ + # [(action, obs), ...] + [(3, 123), (1, 456)], + [(2, 359), (2, 3094)], + [(3, 123), (9, 392), (6, 989), (20, 809), (21, 847)], + [(1, 75)], + [(3, 123), (0, 948)], + [(2, 359), (2, 3094), (10, 68)], + [(2, 359), (2, 3094), (11, 9045)], + ] + + default_string_check = "\n".join( + [ + "(0,) {'observation': tensor(123)}", + " (0, 0) {'observation': tensor(456)}", + " (0, 1) {'observation': tensor(847)}", + " (0, 2) {'observation': tensor(948)}", + "(1,) {'observation': tensor(3094)}", + " (1, 0) {'observation': tensor(68)}", + " (1, 1) {'observation': tensor(9045)}", + "(2,) {'observation': tensor(75)}", + ] + ) + + obs_string_check = "\n".join( + [ + "(0,) [123]", + " (0, 0) [456]", + " (0, 1) [392, 989, 809, 847]", + " (0, 2) [948]", + "(1,) [359, 3094]", + " (1, 0) [68]", + " (1, 1) [9045]", + "(2,) [75]", + ] + ) + + action_string_check = "\n".join( + [ + "(0,) [3]", + " (0, 0) [1]", + " (0, 1) [9, 6, 20, 21]", + " (0, 2) [0]", + "(1,) [2, 2]", + " (1, 0) [10]", + " (1, 1) [11]", + "(2,) [1]", + ] + ) + + for rollout_data in rollouts_data: + td = td_root.clone().unsqueeze(0) + for action, obs in rollout_data: + td = td.update( + TensorDict( + { + "action": [action], + "next": TensorDict({"observation": [obs]}, [1]), + }, + [1], + ) + ) + forest.extend(td) + td = td["next"].clone() + + default_string = forest.to_string(td_root) + assert default_string == default_string_check + + obs_string = forest.to_string( + td_root, lambda tree: tree.rollout["next", "observation"].tolist() + ) + assert obs_string == obs_string_check + + action_string = forest.to_string( + td_root, lambda tree: tree.rollout["action"].tolist() + ) + assert action_string == action_string_check + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() diff --git a/torchrl/data/map/tree.py b/torchrl/data/map/tree.py index d7fd72869dd..2e42cd7864d 100644 --- a/torchrl/data/map/tree.py +++ b/torchrl/data/map/tree.py @@ -604,6 +604,79 @@ def plot( f"Unknown plotting backend {backend} with figure {figure}." ) + def to_string(self, node_format_fn=lambda tree: tree.node_data.to_dict()): + """Generates a string representation of the tree. + + This function can pull out information from each of the nodes in a tree, + so it can be useful for debugging. The nodes are listed line-by-line. + Each line contains the path to the node, followed by the string + representation of that node generated with :arg:`node_format_fn`. Each + line is indented according to number of steps in the path required to + get to the corresponding node. + + Args: + node_format_fn (Callable, optional): User-defined function to + generate a string for each node of the tree. The signature must + be ``(Tree) -> Any``, and the output must be convertible to a + string. If this argument is not given, the generated string is + the node's :attr:`Tree.node_data` attribute converted to a dict. + + Examples: + >>> from torchrl.data import MCTSForest + >>> from tensordict import TensorDict + >>> forest = MCTSForest() + >>> td_root = TensorDict({"observation": 0,}) + >>> rollouts_data = [ + ... # [(action, obs), ...] + ... [(3, 123), (1, 456)], + ... [(2, 359), (2, 3094)], + ... [(3, 123), (9, 392), (6, 989), (20, 809), (21, 847)], + ... [(1, 75)], + ... [(3, 123), (0, 948)], + ... [(2, 359), (2, 3094), (10, 68)], + ... [(2, 359), (2, 3094), (11, 9045)], + ... ] + >>> for rollout_data in rollouts_data: + ... td = td_root.clone().unsqueeze(0) + ... for action, obs in rollout_data: + ... td = td.update(TensorDict({ + ... "action": [action], + ... "next": TensorDict({"observation": [obs]}, [1]), + ... }, [1])) + ... forest.extend(td) + ... td = td["next"].clone() + ... + >>> tree = forest.get_tree(td_root) + >>> print(tree.to_string()) + (0,) {'observation': tensor(123)} + (0, 0) {'observation': tensor(456)} + (0, 1) {'observation': tensor(847)} + (0, 2) {'observation': tensor(948)} + (1,) {'observation': tensor(3094)} + (1, 0) {'observation': tensor(68)} + (1, 1) {'observation': tensor(9045)} + (2,) {'observation': tensor(75)} + """ + queue = [ + # tree, path + (self, ()), + ] + + strings = [] + + while len(queue) > 0: + self, path = queue.pop() + if self.subtree is not None: + for subtree_idx, subtree in reversed(list(enumerate(self.subtree))): + queue.append((subtree, path + (subtree_idx,))) + + if self.rollout is not None: + level = len(path) + string = node_format_fn(self) + strings.append(f"{' ' * (level - 1)}{path} {string}") + + return "\n".join(strings) + class MCTSForest: """A collection of MCTS trees. @@ -1164,6 +1237,63 @@ def valid_paths(cls, tree: Tree): def __len__(self): return len(self.data_map) + def to_string(self, td_root, node_format_fn=lambda tree: tree.node_data.to_dict()): + """Generates a string representation of a tree in the forest. + + This function can pull out information from each of the nodes in a tree, + so it can be useful for debugging. The nodes are listed line-by-line. + Each line contains the path to the node, followed by the string + representation of that node generated with :arg:`node_format_fn`. Each + line is indented according to number of steps in the path required to + get to the corresponding node. + + Args: + td_root (TensorDict): Root of the tree. + + node_format_fn (Callable, optional): User-defined function to + generate a string for each node of the tree. The signature must + be ``(Tree) -> Any``, and the output must be convertible to a + string. If this argument is not given, the generated string is + the node's :attr:`Tree.node_data` attribute converted to a dict. + + Examples: + >>> from torchrl.data import MCTSForest + >>> from tensordict import TensorDict + >>> forest = MCTSForest() + >>> td_root = TensorDict({"observation": 0,}) + >>> rollouts_data = [ + ... # [(action, obs), ...] + ... [(3, 123), (1, 456)], + ... [(2, 359), (2, 3094)], + ... [(3, 123), (9, 392), (6, 989), (20, 809), (21, 847)], + ... [(1, 75)], + ... [(3, 123), (0, 948)], + ... [(2, 359), (2, 3094), (10, 68)], + ... [(2, 359), (2, 3094), (11, 9045)], + ... ] + >>> for rollout_data in rollouts_data: + ... td = td_root.clone().unsqueeze(0) + ... for action, obs in rollout_data: + ... td = td.update(TensorDict({ + ... "action": [action], + ... "next": TensorDict({"observation": [obs]}, [1]), + ... }, [1])) + ... forest.extend(td) + ... td = td["next"].clone() + ... + >>> print(forest.to_string(td_root)) + (0,) {'observation': tensor(123)} + (0, 0) {'observation': tensor(456)} + (0, 1) {'observation': tensor(847)} + (0, 2) {'observation': tensor(948)} + (1,) {'observation': tensor(3094)} + (1, 0) {'observation': tensor(68)} + (1, 1) {'observation': tensor(9045)} + (2,) {'observation': tensor(75)} + """ + tree = self.get_tree(td_root) + return tree.to_string(node_format_fn) + def _make_list_of_nestedkeys(obj: Any, attr: str) -> List[NestedKey]: if obj is None: