From 12b280dd58e8ff5ddfa2d3c29972e876626d7ef0 Mon Sep 17 00:00:00 2001 From: Virgile Andreani Date: Sat, 7 Sep 2024 10:16:04 -0400 Subject: [PATCH] Add match_names function to filter specific names --- datatree/datatree.py | 46 +++++++++++++++++++++++++++++++++ datatree/tests/test_datatree.py | 21 +++++++++++++++ docs/source/api.rst | 1 + docs/source/whats-new.rst | 2 ++ 4 files changed, 70 insertions(+) diff --git a/datatree/datatree.py b/datatree/datatree.py index c86c2e2e..49825e8f 100644 --- a/datatree/datatree.py +++ b/datatree/datatree.py @@ -1307,6 +1307,52 @@ def match(self, pattern: str) -> DataTree: } return DataTree.from_dict(matching_nodes, name=self.root.name) + def match_names(self, names: Iterable[str]) -> DataTree: + """ + Filter nodes by name. + + Parameters + ---------- + names: Iterable[str] + The list of node names to retain. + + Returns + ------- + DataTree + + See Also + -------- + match + filter + pipe + map_over_subtree + + Examples + -------- + >>> dt = DataTree.from_dict( + ... { + ... "/a/A": None, + ... "/a/B": None, + ... "/a/C": None, + ... "/C/D": None, + ... "/E/F": None, + ... } + ... ) + >>> dt.match_names(["A", "C"]) + DataTree('None', parent=None) + ├── DataTree('a') + │ └── DataTree('A') + │ └── DataTree('C') + └── DataTree('C') + """ + names = set(names) + matching_nodes = { + node.path: node.ds + for node in self.subtree + if node.name in names + } + return DataTree.from_dict(matching_nodes, name=self.root.name) + def map_over_subtree( self, func: Callable, diff --git a/datatree/tests/test_datatree.py b/datatree/tests/test_datatree.py index e9f373d7..01fbf24f 100644 --- a/datatree/tests/test_datatree.py +++ b/datatree/tests/test_datatree.py @@ -707,6 +707,27 @@ def test_match(self): ) dtt.assert_identical(result, expected) + def test_match_names(self): + # TODO is this example going to cause problems with case sensitivity? + dt = DataTree.from_dict( + { + "/a/A": None, + "/a/B": None, + "/a/C": None, + "/C/D": None, + "/E/F": None, + } + ) + result = dt.match_names(["A", "C"]) + expected = DataTree.from_dict( + { + "/a/A": None, + "/a/C": None, + "/C": None, + } + ) + dtt.assert_identical(result, expected) + def test_filter(self): simpsons = DataTree.from_dict( d={ diff --git a/docs/source/api.rst b/docs/source/api.rst index d325d24f..24920540 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -103,6 +103,7 @@ For manipulating, traversing, navigating, or mapping over the tree structure. map_over_subtree DataTree.pipe DataTree.match + DataTree.match_names DataTree.filter Pathlib-like Interface diff --git a/docs/source/whats-new.rst b/docs/source/whats-new.rst index 2f6e4f88..92ef525c 100644 --- a/docs/source/whats-new.rst +++ b/docs/source/whats-new.rst @@ -23,6 +23,8 @@ v0.0.14 (unreleased) New Features ~~~~~~~~~~~~ +- Added `DataTree.match_names` method to filter a list of specific node names. + Breaking changes ~~~~~~~~~~~~~~~~