Skip to content

Commit

Permalink
Add skip_recursion config
Browse files Browse the repository at this point in the history
  • Loading branch information
gaogaotiantian committed Nov 24, 2021
1 parent c56188d commit 58a3de8
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 10 deletions.
4 changes: 3 additions & 1 deletion src/objprint/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ def __str__(self):
return json.dumps(_objprint.objjson(self), **kwargs)
else:
def __str__(self):
return _objprint._get_custom_object_str(self, indent_level=0, cfg=_objprint._configs.overwrite(**kwargs))
cfg = _objprint._configs.overwrite(**kwargs)
memo = set() if cfg.skip_recursion else None
return _objprint._get_custom_object_str(self, memo, indent_level=0, cfg=cfg)

if orig_class is None:
def wrapper(cls):
Expand Down
26 changes: 17 additions & 9 deletions src/objprint/objprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@

class _PrintConfig:
indent = 2
depth = 6
depth = 100
width = 80
color = True
label = []
elements = -1
exclude = []
include = []
skip_recursion = True
honor_existing = True

def __init__(self, **kwargs):
Expand Down Expand Up @@ -71,9 +72,11 @@ def objstr(self, obj, **kwargs):
# If no color option is specified, don't use color
if "color" not in kwargs:
kwargs["color"] = False
return self._objstr(obj, indent_level=0, cfg=self._configs.overwrite(**kwargs))
cfg = self._configs.overwrite(**kwargs)
memo = set() if cfg.skip_recursion else None
return self._objstr(obj, memo, indent_level=0, cfg=cfg)

def _objstr(self, obj, indent_level, cfg):
def _objstr(self, obj, memo, indent_level, cfg):
# If it's builtin type, return it directly
if isinstance(obj, str):
return f"'{obj}'"
Expand All @@ -85,14 +88,19 @@ def _objstr(self, obj, indent_level, cfg):
return f"<function {obj.__name__}>"

# Otherwise we may need to unpack it. Figure out if we should do that first
if indent_level >= cfg.depth:
if (memo is not None and id(obj) in memo) or \
(cfg.depth is not None and indent_level >= cfg.depth):
return self._get_ellipsis(obj, cfg)

if memo is not None:
memo = memo.copy()
memo.add(id(obj))

if isinstance(obj, list) or isinstance(obj, tuple) or isinstance(obj, set):
elems = (f"{self._objstr(val, indent_level + 1, cfg)}" for val in obj)
elems = (f"{self._objstr(val, memo, indent_level + 1, cfg)}" for val in obj)
elif isinstance(obj, dict):
elems = (
f"{self._objstr(key, indent_level + 1, cfg)}: {self._objstr(val, indent_level + 1, cfg)}"
f"{self._objstr(key, None, indent_level + 1, cfg)}: {self._objstr(val, memo, indent_level + 1, cfg)}"
for key, val in sorted(obj.items())
)
else:
Expand All @@ -106,7 +114,7 @@ def _objstr(self, obj, indent_level, cfg):
lines = s.split("\n")
lines[1:] = [self.add_indent(line, indent_level, cfg) for line in lines[1:]]
return "\n".join(lines)
return self._get_custom_object_str(obj, indent_level, cfg)
return self._get_custom_object_str(obj, memo, indent_level, cfg)

return self._get_pack_str(elems, obj, indent_level, cfg)

Expand Down Expand Up @@ -140,10 +148,10 @@ def _objjson(self, obj, memo):

return ret

def _get_custom_object_str(self, obj, indent_level, cfg):
def _get_custom_object_str(self, obj, memo, indent_level, cfg):

def _get_line(key):
val = self._objstr(obj.__dict__[key], indent_level + 1, cfg)
val = self._objstr(obj.__dict__[key], memo, indent_level + 1, cfg)
if cfg.label and any(re.fullmatch(pattern, key) is not None for pattern in cfg.label):
return set_color(f".{key} = {val}", COLOR.YELLOW)
elif cfg.color:
Expand Down
10 changes: 10 additions & 0 deletions tests/test_objstr.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,3 +147,13 @@ def test_color(self):
s = objstr(ObjTest({"a": 1}), color=True)
self.assertIn("\033", s)
config(color=False)

def test_recursion(self):
t1 = ObjTest({})
t2 = ObjTest({"t1": t1})
t2.t1.t2 = t2
s = objstr(t2)
self.assertIn("...", s)
self.assertEqual(s.count("t2"), 1)
s = objstr(t2, skip_recursion=False, depth=6)
self.assertEqual(s.count("t2"), 3)

0 comments on commit 58a3de8

Please sign in to comment.