Skip to content

Commit

Permalink
Merge pull request #164 from jacanchaplais/feature/how-leaves-163
Browse files Browse the repository at this point in the history
Flatten MaskGroup to leaves #163
  • Loading branch information
jacanchaplais authored Oct 11, 2023
2 parents b72b459 + d0c43c9 commit b364c2a
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 22 deletions.
48 changes: 29 additions & 19 deletions graphicle/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -914,26 +914,31 @@ def recursive_drop(
return mask_group

def flatten(
self, how: ty.Literal["rise", "agg"] = "rise"
self, how: ty.Literal["rise", "agg", "leaves"] = "rise"
) -> "MaskGroup[MaskArray]":
"""Removes nesting such that the ``MaskGroup`` contains only
``MaskArray`` instances, and no other ``MaskGroup``.
.. versionadded:: 0.1.11
.. versionchanged:: 0.2.6
Added ``'how'`` parameter.
Added ``how`` parameter.
.. versionchanged:: 0.3.7
Added ``leaves`` option for ``how`` parameter.
Parameters
----------
how : {'rise', 'agg'}
Method used to convert into flat ``MaskGroup``. ``'rise'``
how : {'rise', 'agg', 'leaves'}
Method used to convert into flat ``MaskGroup``. ``rise``
recurses through nested levels, raising all contained
``MaskArray`` instances to the top level. ``'agg'`` loops
over the top level of ``MaskBase`` objects, leaving
top-level ``MaskArray`` objects as-is, but calling the
aggregation operation over any ``MaskGroup``. Default is
``'rise'``.
``MaskArray`` instances to the top level.
``agg`` loops over the top level of ``MaskBase`` objects,
leaving top-level ``MaskArray`` objects as-is, but calling
the aggregation operation over any ``MaskGroup``.
``leaves`` brings the innermosted nested ``MaskArray``
instances to the top level, discarding the rest.
Default is ``rise``.
Returns
-------
Expand All @@ -942,26 +947,31 @@ def flatten(
at the top level.
"""

def leaves(
if how == "leaves":
from graphicle.select import leaf_masks

return leaf_masks(self)
if how == "agg":
return self.__class__(
cl.OrderedDict(
zip(self.keys(), map(op.attrgetter("data"), self.values()))
),
"or",
)

def visit(
mask_group: "MaskGroup",
) -> ty.Iterator[ty.Tuple[str, base.MaskLike]]:
for key, val in mask_group.items():
if key == "latent":
continue
if isinstance(val, type(self)):
yield key, val.data
yield from leaves(val)
yield from visit(val)
else:
yield key, val

if how == "rise":
return self.__class__(cl.OrderedDict(leaves(self)), "or") # type: ignore
return self.__class__(
cl.OrderedDict(
zip(self.keys(), map(op.attrgetter("data"), self.values()))
),
"or",
)
return self.__class__(cl.OrderedDict(visit(self)), self.agg_op)

def serialize(self) -> ty.Dict[str, ty.Any]:
"""Returns serialized data as a dictionary.
Expand Down
8 changes: 5 additions & 3 deletions graphicle/select.py
Original file line number Diff line number Diff line change
Expand Up @@ -870,8 +870,7 @@ def _leaf_mask_iter(
for name, mask in branch.items():
if exclude_latent and name == "latent":
continue
# TODO: look into contravariant type for this
yield from _leaf_mask_iter(name, mask, exclude_latent) # type: ignore
yield from _leaf_mask_iter(name, mask, exclude_latent)


def leaf_masks(mask_tree: gcl.MaskGroup) -> gcl.MaskGroup[gcl.MaskArray]:
Expand All @@ -882,6 +881,9 @@ def leaf_masks(mask_tree: gcl.MaskGroup) -> gcl.MaskGroup[gcl.MaskArray]:
.. versionadded:: 0.1.11
.. versionchanged 0.3.7
Output ``MaskGroup`` matches agg_op of ``mask_tree``.
Parameters
----------
mask_tree : MaskGroup
Expand All @@ -893,7 +895,7 @@ def leaf_masks(mask_tree: gcl.MaskGroup) -> gcl.MaskGroup[gcl.MaskArray]:
MaskGroup
Flat ``MaskGroup`` of only the leaves of ``mask_tree``.
"""
mask_group = gcl.MaskGroup(agg_op="or") # type: ignore
mask_group = gcl.MaskGroup(agg_op=mask_tree.agg_op)
for name, branch in mask_tree.items():
mask_group.update(dict(_leaf_mask_iter(name, branch))) # type: ignore
return mask_group
Expand Down

0 comments on commit b364c2a

Please sign in to comment.