diff --git a/src/monty/collections.py b/src/monty/collections.py index 812bff83..154a9c86 100644 --- a/src/monty/collections.py +++ b/src/monty/collections.py @@ -1,17 +1,27 @@ """ -Useful collection classes, e.g., tree, frozendict, etc. +Useful collection classes: + - tree: A recursive `defaultdict` for creating nested dictionaries + with default values. + - ControlledDict: A base dict class with configurable mutability. + - frozendict: An immutable dictionary. + - Namespace: A dict doesn't allow changing values, but could + add new keys, + - AttrDict: A dict whose values could be access as `dct.key`. + - FrozenAttrDict: An immutable version of `AttrDict`. + - MongoDict: A dict-like object whose values are nested dicts + could be accessed as attributes. """ from __future__ import annotations import collections +import warnings +from abc import ABC from typing import TYPE_CHECKING if TYPE_CHECKING: from typing import Any, Iterable - from typing_extensions import Self - def tree() -> collections.defaultdict: """ @@ -20,7 +30,7 @@ def tree() -> collections.defaultdict: Usage: x = tree() - x['a']['b']['c'] = 1 + x["a"]["b"]["c"] = 1 Returns: A tree. @@ -28,118 +38,194 @@ def tree() -> collections.defaultdict: return collections.defaultdict(tree) -class frozendict(dict): +class ControlledDict(collections.UserDict, ABC): """ - A dictionary that does not permit changes. The naming - violates PEP8 to be consistent with standard Python's "frozenset" naming. + A base dictionary class with configurable mutability. + + Attributes: + _allow_add (bool): Whether new keys can be added. + _allow_del (bool): Whether existing keys can be deleted. + _allow_update (bool): Whether existing keys can be updated. + + Configurable Operations: + This class allows controlling the following dict operations (refer to + https://docs.python.org/3.13/library/stdtypes.html#mapping-types-dict for details): + + - Adding or updating items: + - setter method: `__setitem__` + - `setdefault` + - `update` + + - Deleting items: + - `del dict[key]` + - `pop(key)` + - `popitem` + - `clear()` """ - def __init__(self, *args, **kwargs) -> None: - """ - Args: - args: Passthrough arguments for standard dict. - kwargs: Passthrough keyword arguments for standard dict. - """ - dict.__init__(self, *args, **kwargs) + _allow_add: bool = True + _allow_del: bool = True + _allow_update: bool = True - def __setitem__(self, key: Any, val: Any) -> None: - raise KeyError(f"Cannot overwrite existing key: {str(key)}") + def __init__(self, *args, **kwargs) -> None: + """Temporarily allow add during initialization.""" + original_allow_add = self._allow_add - def update(self, *args, **kwargs) -> None: - """ - Args: - args: Passthrough arguments for standard dict. - kwargs: Passthrough keyword arguments for standard dict. - """ - raise KeyError(f"Cannot update a {self.__class__.__name__}") + try: + self._allow_add = True + super().__init__(*args, **kwargs) + finally: + self._allow_add = original_allow_add + # Override add/update operations + def __setitem__(self, key, value) -> None: + """Forbid adding or updating keys based on _allow_add and _allow_update.""" + if key not in self.data and not self._allow_add: + raise TypeError(f"Cannot add new key {key!r}, because add is disabled.") + elif key in self.data and not self._allow_update: + raise TypeError(f"Cannot update key {key!r}, because update is disabled.") -class Namespace(dict): - """A dictionary that does not permit to redefine its keys.""" + super().__setitem__(key, value) - def __init__(self, *args, **kwargs) -> None: - """ - Args: - args: Passthrough arguments for standard dict. - kwargs: Passthrough keyword arguments for standard dict. + def update(self, *args, **kwargs) -> None: + """Forbid adding or updating keys based on _allow_add and _allow_update.""" + for key in dict(*args, **kwargs): + if key not in self.data and not self._allow_add: + raise TypeError( + f"Cannot add new key {key!r} using update, because add is disabled." + ) + elif key in self.data and not self._allow_update: + raise TypeError( + f"Cannot update key {key!r} using update, because update is disabled." + ) + + super().update(*args, **kwargs) + + def setdefault(self, key, default=None) -> Any: + """Forbid adding or updating keys based on _allow_add and _allow_update. + + Note: if not _allow_update, this method would NOT check whether the + new default value is the same as current value for efficiency. """ - self.update(*args, **kwargs) + if key not in self.data: + if not self._allow_add: + raise TypeError( + f"Cannot add new key using setdefault: {key!r}, because add is disabled." + ) + elif not self._allow_update: + raise TypeError( + f"Cannot update key using setdefault: {key!r}, because update is disabled." + ) + + return super().setdefault(key, default) + + # Override delete operations + def __delitem__(self, key) -> None: + """Forbid deleting keys when self._allow_del is False.""" + if not self._allow_del: + raise TypeError(f"Cannot delete key {key!r}, because delete is disabled.") + super().__delitem__(key) + + def pop(self, key, *args): + """Forbid popping keys when self._allow_del is False.""" + if not self._allow_del: + raise TypeError(f"Cannot pop key {key!r}, because delete is disabled.") + return super().pop(key, *args) + + def popitem(self): + """Forbid popping the last item when self._allow_del is False.""" + if not self._allow_del: + raise TypeError("Cannot pop item, because delete is disabled.") + return super().popitem() + + def clear(self) -> None: + """Forbid clearing the dictionary when self._allow_del is False.""" + if not self._allow_del: + raise TypeError("Cannot clear dictionary, because delete is disabled.") + super().clear() + + +class frozendict(ControlledDict): + """ + A dictionary that does not permit changes. The naming + violates PEP 8 to be consistent with the built-in `frozenset` naming. + """ - def __setitem__(self, key: Any, val: Any) -> None: - if key in self: - raise KeyError(f"Cannot overwrite existent key: {key!s}") + _allow_add: bool = False + _allow_del: bool = False + _allow_update: bool = False - dict.__setitem__(self, key, val) - def update(self, *args, **kwargs) -> None: - """ - Args: - args: Passthrough arguments for standard dict. - kwargs: Passthrough keyword arguments for standard dict. - """ - for k, v in dict(*args, **kwargs).items(): - self[k] = v +class Namespace(ControlledDict): + """A dictionary that does not permit update/delete its values (but allows add).""" + + _allow_add: bool = True + _allow_del: bool = False + _allow_update: bool = False class AttrDict(dict): """ - Allows to access dict keys as obj.foo in addition - to the traditional way obj['foo']" + Allow accessing values as `dct.key` in addition to the traditional way `dct["key"]`. Examples: - >>> d = AttrDict(foo=1, bar=2) - >>> assert d["foo"] == d.foo - >>> d.bar = "hello" - >>> assert d.bar == "hello" + >>> dct = AttrDict(foo=1, bar=2) + >>> assert dct["foo"] is dct.foo + >>> dct.bar = "hello" + + Warnings: + When shadowing dict methods, e.g.: + >>> dct = AttrDict(update="value") + >>> dct.update() # TypeError (the `update` method is overwritten) + + References: + https://stackoverflow.com/a/14620633/24021108 """ def __init__(self, *args, **kwargs) -> None: - """ - Args: - args: Passthrough arguments for standard dict. - kwargs: Passthrough keyword arguments for standard dict. - """ - super().__init__(*args, **kwargs) + super(AttrDict, self).__init__(*args, **kwargs) self.__dict__ = self - def copy(self) -> Self: - """ - Returns: - Copy of AttrDict - """ - newd = super().copy() - return self.__class__(**newd) + def __setitem__(self, key, value) -> None: + """Check if the key shadows dict method.""" + if key in dir(dict): + warnings.warn( + f"'{key=}' shadows dict method. This may lead to unexpected behavior.", + UserWarning, + stacklevel=2, + ) + super().__setitem__(key, value) class FrozenAttrDict(frozendict): """ A dictionary that: - * does not permit changes. - * Allows to access dict keys as obj.foo in addition - to the traditional way obj['foo'] + - Does not permit changes (add/update/delete). + - Allows accessing values as `dct.key` in addition to + the traditional way `dct["key"]`. """ def __init__(self, *args, **kwargs) -> None: - """ - Args: - args: Passthrough arguments for standard dict. - kwargs: Passthrough keyword arguments for standard dict. - """ + """Allow add during init, as __setattr__ is called unlike frozendict.""" + self._allow_add = True super().__init__(*args, **kwargs) + self._allow_add = False def __getattribute__(self, name: str) -> Any: try: - return super().__getattribute__(name) + return object.__getattribute__(self, name) except AttributeError: - try: - return self[name] - except KeyError as exc: - raise AttributeError(str(exc)) + return self[name] def __setattr__(self, name: str, value: Any) -> None: - raise KeyError( - f"You cannot modify attribute {name} of {self.__class__.__name__}" - ) + if not self._allow_add and name != "_allow_add": + raise TypeError( + f"{self.__class__.__name__} does not support item assignment." + ) + super().__setattr__(name, value) + + def __delattr__(self, name: str) -> None: + raise TypeError(f"{self.__class__.__name__} does not support item deletion.") class MongoDict: @@ -151,11 +237,11 @@ class MongoDict: a nested dict interactively (e.g. documents extracted from a MongoDB database). - >>> m = MongoDict({'a': {'b': 1}, 'x': 2}) - >>> assert m.a.b == 1 and m.x == 2 - >>> assert "a" in m and "b" in m.a - >>> m["a"] - {'b': 1} + >>> m_dct = MongoDict({"a": {"b": 1}, "x": 2}) + >>> assert m_dct.a.b == 1 and m_dct.x == 2 + >>> assert "a" in m_dct and "b" in m_dct.a + >>> m_dct["a"] + {"b": 1} Notes: Cannot inherit from ABC collections.Mapping because otherwise @@ -186,7 +272,6 @@ def __getattribute__(self, name: str) -> Any: try: return super().__getattribute__(name) except AttributeError: - # raise try: a = self._mongo_dict_[name] if isinstance(a, collections.abc.Mapping): @@ -214,26 +299,25 @@ def __dir__(self) -> list: def dict2namedtuple(*args, **kwargs) -> tuple: """ - Helper function to create a class `namedtuple` from a dictionary. + Helper function to create a `namedtuple` from a dictionary. Examples: - >>> t = dict2namedtuple(foo=1, bar="hello") - >>> assert t.foo == 1 and t.bar == "hello" + >>> tpl = dict2namedtuple(foo=1, bar="hello") + >>> assert tpl.foo == 1 and tpl.bar == "hello" - >>> t = dict2namedtuple([("foo", 1), ("bar", "hello")]) - >>> assert t[0] == t.foo and t[1] == t.bar + >>> tpl = dict2namedtuple([("foo", 1), ("bar", "hello")]) + >>> assert tpl[0] is tpl.foo and tpl[1] is tpl.bar Warnings: - The order of the items in the namedtuple is not deterministic if - kwargs are used. - namedtuples, however, should always be accessed by attribute hence - this behaviour should not represent a serious problem. + `kwargs` are used. namedtuples, however, should always be accessed + by attribute hence this behaviour should not be a serious problem. - - Don't use this function in code in which memory and performance are - crucial since a dict is needed to instantiate the tuple! + - Don't use this function in code where memory and performance are + crucial, since a dict is needed to instantiate the tuple! """ - d = collections.OrderedDict(*args) - d.update(**kwargs) + dct = collections.OrderedDict(*args) + dct.update(**kwargs) return collections.namedtuple( - typename="dict2namedtuple", field_names=list(d.keys()) - )(**d) + typename="dict2namedtuple", field_names=list(dct.keys()) + )(**dct) diff --git a/src/monty/design_patterns.py b/src/monty/design_patterns.py index a99a8a6b..074505c9 100644 --- a/src/monty/design_patterns.py +++ b/src/monty/design_patterns.py @@ -7,9 +7,12 @@ import inspect import os from functools import wraps -from typing import Any, Dict, Hashable, Tuple, TypeVar +from typing import TYPE_CHECKING, TypeVar from weakref import WeakValueDictionary +if TYPE_CHECKING: + from typing import Any + def singleton(cls): """ @@ -95,7 +98,7 @@ def new_init(self: Any, *args: Any, **kwargs: Any) -> None: orig_init(self, *args, **kwargs) self._initialized = True - def reduce(self: Any) -> Tuple[type, Tuple, Dict[str, Any]]: + def reduce(self: Any) -> tuple[type, tuple, dict[str, Any]]: for key, value in cache.items(): if value is self: cls, args = key diff --git a/src/monty/fractions.py b/src/monty/fractions.py index d29b7960..f0650e5a 100644 --- a/src/monty/fractions.py +++ b/src/monty/fractions.py @@ -5,7 +5,10 @@ from __future__ import annotations import math -from typing import Sequence +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing import Sequence def gcd(*numbers: int) -> int: diff --git a/src/monty/functools.py b/src/monty/functools.py index 32e55517..969f0d50 100644 --- a/src/monty/functools.py +++ b/src/monty/functools.py @@ -9,11 +9,11 @@ import signal import sys import tempfile -from collections import namedtuple from functools import partial, wraps -from typing import Any, Callable, Union +from typing import TYPE_CHECKING -_CacheInfo = namedtuple("_CacheInfo", ["hits", "misses", "maxsize", "currsize"]) +if TYPE_CHECKING: + from typing import Any, Callable, Union class _HashedSeq(list): # pylint: disable=C0205 diff --git a/tests/test_collections.py b/tests/test_collections.py index 58d2af9e..118a08a8 100644 --- a/tests/test_collections.py +++ b/tests/test_collections.py @@ -1,55 +1,262 @@ from __future__ import annotations -import os +from collections import UserDict import pytest -from monty.collections import AttrDict, FrozenAttrDict, Namespace, frozendict, tree - -TEST_DIR = os.path.join(os.path.dirname(__file__), "test_files") - - -class TestFrozenDict: - def test_frozen_dict(self): - d = frozendict({"hello": "world"}) - with pytest.raises(KeyError): - d["k"] == "v" - assert d["hello"] == "world" - - def test_namespace_dict(self): - d = Namespace(foo="bar") - d["hello"] = "world" - assert d["foo"] == "bar" - with pytest.raises(KeyError): - d.update({"foo": "spam"}) - - def test_attr_dict(self): - d = AttrDict(foo=1, bar=2) - assert d.bar == 2 - assert d["foo"] == d.foo - d.bar = "hello" - assert d["bar"] == "hello" - - def test_frozen_attrdict(self): - d = FrozenAttrDict({"hello": "world", 1: 2}) - assert d["hello"] == "world" - assert d.hello == "world" - with pytest.raises(KeyError): - d["updating"] == 2 - - with pytest.raises(KeyError): - d["foo"] = "bar" - with pytest.raises(KeyError): - d.foo = "bar" - with pytest.raises(KeyError): - d.hello = "new" - - -class TestTree: - def test_tree(self): - x = tree() - x["a"]["b"]["c"]["d"] = 1 - assert "b" in x["a"] - assert "c" not in x["a"] - assert "c" in x["a"]["b"] - assert x["a"]["b"]["c"]["d"] == 1 +from monty.collections import ( + AttrDict, + ControlledDict, + FrozenAttrDict, + MongoDict, + Namespace, + dict2namedtuple, + frozendict, + tree, +) + + +def test_tree(): + x = tree() + x["a"]["b"]["c"]["d"] = 1 + assert "b" in x["a"] + assert "c" not in x["a"] + assert "c" in x["a"]["b"] + assert x["a"]["b"]["c"]["d"] == 1 + + +class TestControlledDict: + def test_add_allowed(self): + dct = ControlledDict(a=1) + dct._allow_add = True + + dct["b"] = 2 + assert dct["b"] == 2 + + dct.update(d=3) + assert dct["d"] == 3 + + dct.setdefault("e", 4) + assert dct["e"] == 4 + + def test_add_disabled(self): + dct = ControlledDict(a=1) + dct._allow_add = False + + with pytest.raises(TypeError, match="add is disabled"): + dct["b"] = 2 + + with pytest.raises(TypeError, match="add is disabled"): + dct.update(b=2) + + with pytest.raises(TypeError, match="add is disabled"): + dct.setdefault("c", 2) + + def test_update_allowed(self): + dct = ControlledDict(a=1) + dct._allow_update = True + + dct["a"] = 2 + assert dct["a"] == 2 + + dct.update({"a": 3}) + assert dct["a"] == 3 + + dct.setdefault("a", 4) # existing key + assert dct["a"] == 3 + + def test_update_disabled(self): + dct = ControlledDict(a=1) + dct._allow_update = False + + with pytest.raises(TypeError, match="update is disabled"): + dct["a"] = 2 + + with pytest.raises(TypeError, match="update is disabled"): + dct.update({"a": 3}) + + with pytest.raises(TypeError, match="update is disabled"): + dct.setdefault("a", 4) + + def test_del_allowed(self): + dct = ControlledDict(a=1, b=2, c=3, d=4) + dct._allow_del = True + + del dct["a"] + assert "a" not in dct + + val = dct.pop("b") + assert val == 2 and "b" not in dct + + val = dct.popitem() + assert val == ("c", 3) and "c" not in dct + + dct.clear() + assert dct == {} + + def test_del_disabled(self): + dct = ControlledDict(a=1) + dct._allow_del = False + + with pytest.raises(TypeError, match="delete is disabled"): + del dct["a"] + + with pytest.raises(TypeError, match="delete is disabled"): + dct.pop("a") + + with pytest.raises(TypeError, match="delete is disabled"): + dct.popitem() + + with pytest.raises(TypeError, match="delete is disabled"): + dct.clear() + + def test_frozen_like(self): + """Make sure add and update are allowed at init time.""" + ControlledDict._allow_add = False + ControlledDict._allow_update = False + + dct = ControlledDict({"hello": "world"}) + assert isinstance(dct, UserDict) + assert dct["hello"] == "world" + + assert not dct._allow_add + assert not dct._allow_update + + +def test_frozendict(): + dct = frozendict({"hello": "world"}) + assert isinstance(dct, UserDict) + assert dct["hello"] == "world" + + assert not dct._allow_add + assert not dct._allow_update + assert not dct._allow_del + + # Test setter + with pytest.raises(TypeError, match="add is disabled"): + dct["key"] = "val" + + # Test update + with pytest.raises(TypeError, match="add is disabled"): + dct.update(key="val") + + # Test pop + with pytest.raises(TypeError, match="delete is disabled"): + dct.pop("key") + + # Test delete + with pytest.raises(TypeError, match="delete is disabled"): + del dct["key"] + + +def test_namespace_dict(): + dct = Namespace(key="val") + assert isinstance(dct, UserDict) + + # Test setter + dct["hello"] = "world" + assert dct["key"] == "val" + + # Test update (not allowed) + with pytest.raises(TypeError, match="update is disabled"): + dct["key"] = "val" + + with pytest.raises(TypeError, match="update is disabled"): + dct.update({"key": "val"}) + + # Test delete (not allowed) + with pytest.raises(TypeError, match="delete is disabled"): + del dct["key"] + + +def test_attr_dict(): + dct = AttrDict(foo=1, bar=2) + + # Test get attribute + assert dct.bar == 2 + assert dct["foo"] is dct.foo + + # Test key not found error + with pytest.raises(KeyError, match="no-such-key"): + dct["no-such-key"] + + # Test setter + dct.bar = "hello" + assert dct["bar"] == "hello" + + # Test delete + del dct.bar + assert "bar" not in dct + + # Test builtin dict method shadowing + with pytest.warns(UserWarning, match="shadows dict method"): + dct["update"] = "value" + + +def test_frozen_attrdict(): + dct = FrozenAttrDict({"hello": "world", 1: 2}) + assert isinstance(dct, UserDict) + + # Test attribute-like operations + with pytest.raises(TypeError, match="does not support item assignment"): + dct.foo = "bar" + + with pytest.raises(TypeError, match="does not support item assignment"): + dct.hello = "new" + + with pytest.raises(TypeError, match="does not support item deletion"): + del dct.hello + + # Test get value + assert dct["hello"] == "world" + assert dct.hello == "world" + assert dct["hello"] is dct.hello # identity check + + # Test adding item + with pytest.raises(TypeError, match="add is disabled"): + dct["foo"] = "bar" + + # Test modifying existing item + with pytest.raises(TypeError, match="update is disabled"): + dct["hello"] = "new" + + # Test update + with pytest.raises(TypeError, match="update is disabled"): + dct.update({"hello": "world"}) + + # Test pop + with pytest.raises(TypeError, match="delete is disabled"): + dct.pop("hello") + + with pytest.raises(TypeError, match="delete is disabled"): + dct.popitem() + + # Test delete + with pytest.raises(TypeError, match="delete is disabled"): + del dct["hello"] + + with pytest.raises(TypeError, match="delete is disabled"): + dct.clear() + + +def test_mongo_dict(): + m_dct = MongoDict({"a": {"b": 1}, "x": 2}) + assert m_dct.a.b == 1 + assert m_dct.x == 2 + assert "a" in m_dct + assert "b" in m_dct.a + assert m_dct["a"] == {"b": 1} + + +def test_dict2namedtuple(): + # Init from dict + tpl = dict2namedtuple(foo=1, bar="hello") + assert isinstance(tpl, tuple) + assert tpl.foo == 1 and tpl.bar == "hello" + + # Init from list of tuples + tpl = dict2namedtuple([("foo", 1), ("bar", "hello")]) + assert isinstance(tpl, tuple) + assert tpl[0] == 1 + assert tpl[1] == "hello" + assert tpl[0] is tpl.foo and tpl[1] is tpl.bar