-
Notifications
You must be signed in to change notification settings - Fork 97
/
Copy pathObjDict.py
132 lines (117 loc) · 4.85 KB
/
ObjDict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
# -*- coding: utf-8 -*-
from __future__ import annotations
from typing import Dict, Any, Optional
from copy import deepcopy
class ObjDict(dict):
    """dict subclass whose keys can also be read and written as attributes."""

    @property
    def NotExist(self):  # for default value
        # Read through the class, `ObjDict.NotExist` is this property object
        # itself; it is the sentinel meaning "no default: raise KeyError".
        return ObjDict.NotExist

    def __init__(self, d: dict = None, recursive=True, default=NotExist, *, antiloop_map=None):
        '''
        ## ObjDict is a subclass of dict that allows for object-like access
        #### Preserved:
        these preserved names resolve to attributes under dot access;
        read a key with the same name using `['name']` or `get`
        * `NotExist`: sentinel default meaning "raise KeyError for missing keys"
        * `update`: just like dict.update(), but recursively converts nested dicts
        * `copy`: returns a shallow copy
        * Any attribute of the dict class
        * Any name starts with `_`
        Dot-assigning `NotExist`, `update`, `copy` or a name starting with `_`
        raises AttributeError. Dot-assigning another dict attribute name
        (e.g. `obj.items = x`) stores a key that only `['items']` can read back.
        Dot-reading a name starting with `_` returns an existing key but never
        falls back to `default`.
        #### Precedence:
        * `.` : Attribute > Key > Default (no Default for names starting with `_`)
        * `[]` : Key > Default
        * `get` : plain dict.get, i.e. Key > the fallback passed to `get`
        #### Params:
        * `d`: dict, or anything `dict()` accepts (e.g. key/value pairs);
          None means empty
        * `default`: default value to return if key is not found,
          reset to ObjDict.NotExist to raise KeyError
        * `recursive`: recursively try to convert all sub-objects in `d`
          (also applies to later `update()` calls and `default` changes)
        * `antiloop_map`: internal; maps id() of source dicts to the ObjDict
          built for them during one conversion pass, so shared or
          self-referencing dicts are converted once and cycles terminate
        '''
        # write straight into __dict__: __setattr__ rejects names starting with `_`
        self.__dict__["_antiloop_map"] = {} if antiloop_map is None else antiloop_map
        self.__dict__["_default"] = default
        self.__dict__["_recursive"] = recursive
        if not d and not isinstance(d, dict):
            # None and other falsy non-mappings mean "empty". A real empty dict
            # is kept so its id() is registered below and several references to
            # it resolve to the same ObjDict.
            d = {}
        self._antiloop_map[id(d)] = self
        self.update(d)

    def update(self, d=(), **kw):
        """
        Like dict.update(), but values go through `_convert` first
        (nested dicts become ObjDicts when `recursive` is on).

        * `d`: mapping or iterable of key/value pairs; optional, as in dict.update
        * `**kw`: extra keys, merged on top of `d`
        """
        try:
            if kw or not isinstance(d, dict):
                d = dict(d, **kw)
            elif not isinstance(d, ObjDict):
                # register `d` in the loop-detection map (building a throwaway
                # ObjDict if it is not there yet) so values that reference `d`
                # cannot recurse forever. An ObjDict argument needs no dummy,
                # and converting it would overwrite the argument's `default`.
                self._convert(d)
            for k, v in d.items():
                self[k] = self._convert(v)
        finally:
            # the map only lives for one conversion pass; dropping it means
            # id()s of objects freed afterwards can never match later on
            self.__dict__["_antiloop_map"] = {}

    def _convert(self, v: Any, recursive: Optional[bool] = None) -> Any:
        """
        Return `v` with nested plain dicts turned into ObjDicts that share
        this object's `default`.

        * dicts already registered in `_antiloop_map` (by id) map to the
          ObjDict built for them, so shared references and cycles terminate
        * ObjDicts are reused as-is, after adopting this object's default
        * lists, tuples (namedtuples keep their type) and sets are rebuilt
          with converted elements; anything else is returned unchanged
        * `recursive`: overrides `self._recursive` for this whole call,
          including nested containers
        """
        if recursive is None:
            recursive = self._recursive
        if not recursive:
            return v
        if isinstance(v, dict):
            if id(v) in self._antiloop_map:
                return self._antiloop_map[id(v)]
            if isinstance(v, ObjDict):
                if v.default is not self.default:
                    v.default = self.default
                return v
            return ObjDict(v, default=self.default, antiloop_map=self._antiloop_map)
        if isinstance(v, list):
            return [self._convert(i, recursive) for i in v]
        if isinstance(v, tuple):
            items = (self._convert(i, recursive) for i in v)
            if hasattr(v, "_make"):
                # namedtuple: rebuild with its own type so field access keeps working
                return type(v)._make(items)
            return tuple(items)
        if isinstance(v, set):
            return {self._convert(i, recursive) for i in v}
        return v

    @property
    def default(self):
        """
        ### default property
        NOTICE: setting it also sets the default of all sub-dicts
        (only when `recursive` is on)
        * value returned (and stored) when a key is not found
        * set to `ObjDict.NotExist` to raise KeyError when key is not found
        * the value is deepcopied each time it fills a missing key,
          so a mutable default is never shared between keys
        """
        return self.__dict__["_default"]

    @default.setter
    def default(self, value):
        self.__dict__["_default"] = value
        # re-run conversion over our own values: nested ObjDicts adopt the new default
        self.update(self)

    def copy(self) -> ObjDict:
        """### returns a shallow copy with the same `default` and `recursive` settings"""
        # build with recursive=False so the values are taken by reference...
        duplicate = ObjDict(self, recursive=False, default=self.default)
        # ...then restore our own setting, so later update() calls and default
        # changes behave on the copy exactly as they do on the original
        duplicate.__dict__["_recursive"] = self.__dict__["_recursive"]
        return duplicate

    def __getattr__(self, name: str) -> Any:
        # Only called after normal attribute lookup has failed.
        if name.startswith("_"):
            # Private and dunder names are probed by Python machinery (pickle,
            # copy, IPython display hooks, hasattr checks, ...). Serve an
            # existing key, but never create one from `default`; otherwise a
            # mere hasattr() probe would insert a key into the mapping.
            if name in self:
                return self.get(name)
            # cheap message on purpose: probes hit this path often, and
            # formatting the whole mapping every time would be wasteful
            raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}")
        try:
            return self[name]
        except KeyError:
            raise AttributeError(f"{name} not found in {self}") from None

    def __setattr__(self, name: str, value):
        if name in {"NotExist", "update", "copy"} or name.startswith("_"):
            raise AttributeError(
                f"set '{name}' with dot access is not allowed, consider using ['{name}']")
        # cannot just call setattr(self, name, value), recursion error
        if name in self.__dict__:
            self.__dict__[name] = value
        elif hasattr(getattr(type(self), name, None), "__set__"):
            # data descriptors on the class, e.g. the `default` property
            getattr(type(self), name).__set__(self, value)
        else:
            # everything else becomes a key (dict method names such as `items`
            # are stored too, but dot-reading them still returns the method)
            self[name] = value

    def __getitem__(self, name: str):
        if name in self:
            return self.get(name)
        if self.default is ObjDict.NotExist:
            raise KeyError(f"{name} not found in {self}")
        # fill the key with a private copy so a mutable default is never shared
        self[name] = deepcopy(self.default)
        return self.get(name)

    def __deepcopy__(self, memo: Dict[int, Any]):
        duplicate = ObjDict({}, recursive=self.__dict__["_recursive"], default=self.default)
        # register first so self-references inside the values resolve to the copy
        memo[id(self)] = duplicate
        contents = deepcopy(dict(self), memo)
        # map the temporary plain dict to `duplicate`, so update() adopts it
        # directly instead of building and discarding an intermediate ObjDict
        duplicate._antiloop_map[id(contents)] = duplicate
        duplicate.update(contents)
        return duplicate

    def __getstate__(self) -> Dict[str, Any]:
        # pickle / copy.copy support. The NotExist sentinel is a property object,
        # which cannot be pickled, so "no default" is encoded by omitting "_default".
        attrs = self.__dict__
        state = {"_recursive": attrs.get("_recursive", True)}
        default = attrs.get("_default", ObjDict.NotExist)
        if default is not ObjDict.NotExist:
            state["_default"] = default
        return state

    def __setstate__(self, state: Dict[str, Any]) -> None:
        # also accepts the raw __dict__ that pickles made without __getstate__ stored
        self.__dict__["_antiloop_map"] = {}
        self.__dict__["_recursive"] = state.get("_recursive", True)
        self.__dict__["_default"] = state.get("_default", ObjDict.NotExist)