From 5816aebaf8e290ea373a4be3f85a688c6c0a4674 Mon Sep 17 00:00:00 2001
From: Kurt Mohler
Date: Wed, 19 Feb 2025 13:36:22 -0800
Subject: [PATCH] Update

[ghstack-poisoned]
---
 test/test_storage_map.py | 71 ++++++++++++++++++++++++++++++++++++++++
 torchrl/data/map/tree.py | 57 ++++++++++++++++++++++++++++++++
 2 files changed, 128 insertions(+)

diff --git a/test/test_storage_map.py b/test/test_storage_map.py
index 9866d893ff9..3ab49d9acdc 100644
--- a/test/test_storage_map.py
+++ b/test/test_storage_map.py
@@ -684,6 +684,77 @@ def test_forest_check_obs_match(self, intersect):
                 ).all()
                 prev_tree = subtree
 
+    def test_to_string(self):
+        """Checks the text rendering of a small, hand-built forest."""
+        forest = MCTSForest()
+        root = TensorDict({"observation": 0})
+
+        # One entry per rollout; every step is an (action, obs) pair, where
+        # ``obs`` is the observation stored under "next" for that action.
+        rollouts_data = (
+            ((3, 123), (1, 456)),
+            ((2, 359), (2, 3094)),
+            ((3, 123), (9, 392), (6, 989), (20, 809), (21, 847)),
+            ((1, 75),),
+            ((3, 123), (0, 948)),
+            ((2, 359), (2, 3094), (10, 68)),
+            ((2, 359), (2, 3094), (11, 9045)),
+        )
+
+        # Expected renderings: one node per line as ``<path> <values>``, with
+        # children indented under their parent.
+        expected_obs_lines = (
+            "(0,) [123]",
+            " (0, 0) [456]",
+            " (0, 1) [392, 989, 809, 847]",
+            " (0, 2) [948]",
+            "(1,) [359, 3094]",
+            " (1, 0) [68]",
+            " (1, 1) [9045]",
+            "(2,) [75]",
+        )
+
+        expected_action_lines = (
+            "(0,) [3]",
+            " (0, 0) [1]",
+            " (0, 1) [9, 6, 20, 21]",
+            " (0, 2) [0]",
+            "(1,) [2, 2]",
+            " (1, 0) [10]",
+            " (1, 1) [11]",
+            "(2,) [1]",
+        )
+
+        # Replay every rollout into the forest, one transition at a time.
+        for steps in rollouts_data:
+            # Every rollout restarts from the root observation.
+            td = root.clone().unsqueeze(0)
+            for action, obs in steps:
+                transition = TensorDict(
+                    {
+                        "action": [action],
+                        "next": TensorDict({"observation": [obs]}, [1]),
+                    },
+                    [1],
+                )
+                td = td.update(transition)
+                forest.extend(td)
+                # The observation just reached is the input of the next step.
+                td = td["next"].clone()
+
+        # Formatters handed to ``to_string``; each returns a plain list.
+        def next_observations(tree):
+            return tree.rollout["next", "observation"].tolist()
+
+        def actions(tree):
+            return tree.rollout["action"].tolist()
+
+        obs_string = forest.to_string(root, next_observations)
+        assert obs_string == "\n".join(expected_obs_lines)
+
+        action_string = forest.to_string(root, actions)
+        assert action_string == "\n".join(expected_action_lines)
+
 
 if __name__ == "__main__":
     args,
unknown = argparse.ArgumentParser().parse_known_args()
diff --git a/torchrl/data/map/tree.py b/torchrl/data/map/tree.py
index d7fd72869dd..7e7567ff974 100644
--- a/torchrl/data/map/tree.py
+++ b/torchrl/data/map/tree.py
@@ -604,6 +604,42 @@ def plot(
                 f"Unknown plotting backend {backend} with figure {figure}."
             )
 
+    def to_string(self, node_format_fn):
+        """Generates a string representation of the tree.
+
+        This function can pull out information from each of the nodes in a tree,
+        so it can be useful for debugging. The nodes are listed line-by-line in
+        depth-first order, children in ascending index order. Each line contains
+        the path to the node, followed by the string representation of that node
+        generated with ``node_format_fn``. Each line is indented according to the
+        number of steps in the path required to get to the corresponding node.
+        Nodes whose ``rollout`` is ``None`` are traversed but not listed.
+
+        Args:
+            node_format_fn (Callable): User-defined function to generate a
+                string for each node of the tree. The signature must be
+                ``(Tree) -> Any``, and the output must be convertible to
+                a string.
+
+        Returns:
+            str: The formatted nodes, one per line, joined with newlines.
+        """
+        stack = [(self, ())]  # LIFO stack of (tree, path) pairs
+        strings = []
+
+        while stack:
+            node, path = stack.pop()
+            if node.subtree is not None:
+                # Push in reverse so children pop in ascending index order.
+                for subtree_idx, subtree in reversed(list(enumerate(node.subtree))):
+                    stack.append((subtree, path + (subtree_idx,)))
+
+            if node.rollout is not None:
+                indent = " " * max(len(path) - 1, 0)
+                strings.append(f"{indent}{path} {node_format_fn(node)}")
+
+        return "\n".join(strings)
+
 
 class MCTSForest:
     """A collection of MCTS trees.
@@ -1164,6 +1200,27 @@ def valid_paths(cls, tree: Tree):
     def __len__(self):
         return len(self.data_map)
 
+    def to_string(self, td_root, node_format_fn):
+        """Generates a string representation of a tree in the forest.
+
+        This function can pull out information from each of the nodes in a tree,
+        so it can be useful for debugging. The nodes are listed line-by-line.
+        Each line contains the path to the node, followed by the string
+        representation of that node generated with ``node_format_fn``, and is
+        indented according to the number of steps in the path to that node.
+
+        Args:
+            td_root (TensorDict): Root of the tree.
+            node_format_fn (Callable): User-defined function to generate a
+                string for each node of the tree. The signature must be
+                ``(Tree) -> Any``, and the output must be convertible to
+                a string.
+
+        Returns:
+            str: The result of ``to_string`` on the tree found for ``td_root``.
+        """
+        return self.get_tree(td_root).to_string(node_format_fn)
+
 
 def _make_list_of_nestedkeys(obj: Any, attr: str) -> List[NestedKey]:
     if obj is None: