From 94e93febe3447498591924d9128b7cc2f04932a3 Mon Sep 17 00:00:00 2001 From: Virgile Andreani Date: Sat, 7 Sep 2024 10:31:56 -0400 Subject: [PATCH] Add DataTree.match_names to match node names --- xarray/core/datatree.py | 45 +++++++++++++++++++++++++++++++++++ xarray/tests/test_datatree.py | 21 ++++++++++++++++ 2 files changed, 66 insertions(+) diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index 59984c5afa3..c95e46ede5d 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -1349,6 +1349,51 @@ def match(self, pattern: str) -> DataTree: } return DataTree.from_dict(matching_nodes, name=self.root.name) + def match_names(self, names: Iterable[str]) -> DataTree: + """ + Filter nodes by name. + + Parameters + ---------- + names: Iterable[str] + The list of node names to retain. + + Returns + ------- + DataTree + + See Also + -------- + match + filter + pipe + map_over_subtree + + Examples + -------- + >>> dt = DataTree.from_dict( + ... { + ... "/a/A": None, + ... "/a/B": None, + ... "/a/C": None, + ... "/C/D": None, + ... "/E/F": None, + ... } + ... ) + >>> dt.match_names(["A", "C"]) + + Group: / + ├── Group: /C + └── Group: /a + ├── Group: /a/A + └── Group: /a/C + """ + names = set(names) + matching_nodes = { + node.path: node.ds for node in self.subtree if node.name in names + } + return DataTree.from_dict(matching_nodes, name=self.root.name) + def map_over_subtree( self, func: Callable, diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index 9a15376a1f8..61f8775b102 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -1025,6 +1025,27 @@ def test_match(self): ) assert_identical(result, expected) + def test_match_names(self): + # TODO is this example going to cause problems with case sensitivity? + dt: DataTree = DataTree.from_dict( + { + "/a/A": None, + "/a/B": None, + "/a/C": None, + "/C/D": None, + "/E/F": None, + } + ) + result = dt.match_names(["A", "C"]) + expected = DataTree.from_dict( + { + "/a/A": None, + "/a/C": None, + "/C": None, + } + ) + assert_identical(result, expected) + def test_filter(self): simpsons: DataTree = DataTree.from_dict( d={