From 311a66baaa0e4ccde16c4466096097eaf5d86dd6 Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Fri, 26 Nov 2021 23:57:22 -0800 Subject: [PATCH] Docstring fix for inherited fields + test --- dcargs/_docstrings.py | 74 +++++++++++++++++++++++----------------- dcargs/_serialization.py | 4 ++- tests/test_docstrings.py | 55 +++++++++++++++++++++++++++++ 3 files changed, 101 insertions(+), 32 deletions(-) diff --git a/dcargs/_docstrings.py b/dcargs/_docstrings.py index 16e30217..7757991a 100644 --- a/dcargs/_docstrings.py +++ b/dcargs/_docstrings.py @@ -25,14 +25,14 @@ class _FieldData: @dataclasses.dataclass -class _Tokenization: +class _ClassTokenization: tokens: List[_Token] tokens_from_line: Dict[int, List[_Token]] field_data_from_name: Dict[str, _FieldData] @staticmethod - @functools.lru_cache(maxsize=4) - def make(cls) -> "_Tokenization": + @functools.lru_cache(maxsize=8) + def make(cls) -> "_ClassTokenization": """Parse the source code of a class, and cache some tokenization information.""" readline = io.BytesIO(inspect.getsource(cls).encode("utf-8")).readline @@ -66,34 +66,57 @@ def make(cls) -> "_Tokenization": ) prev_field_line_number = token.line_number - return _Tokenization( + return _ClassTokenization( tokens=tokens, tokens_from_line=tokens_from_line, field_data_from_name=field_data_from_name, ) +def get_class_tokenization_with_field( + cls: Type, field_name: str +) -> Optional[_ClassTokenization]: + # Search for token in this class + all parents. + found_field: bool = False + classes_to_search = cls.mro() + for search_cls in classes_to_search: + # Unwrap generics. + origin_cls = get_origin(search_cls) + if origin_cls is not None: + search_cls = origin_cls + + # Skip parent classes that aren't dataclasses. + if not dataclasses.is_dataclass(search_cls): + continue + + try: + tokenization = _ClassTokenization.make(search_cls) # type: ignore + except OSError as e: + # Dynamic dataclasses will result in an OSError -- this is fine, we just assume + # there's no docstring. + assert "could not find class definition" in e.args[0] + return None + + # Grab field-specific tokenization data. + if field_name in tokenization.field_data_from_name: + found_field = True + break + + assert ( + found_field + ), "Docstring parsing error -- this usually means that there are multiple \ + dataclasses in the same file with the same name but different scopes." + + return tokenization + + def get_field_docstring(cls: Type, field_name: str) -> Optional[str]: """Get docstring for a field in a class.""" - origin_cls = get_origin(cls) - if origin_cls is not None: - cls = origin_cls - - assert dataclasses.is_dataclass(cls) - try: - tokenization = _Tokenization.make(cls) # type: ignore - except OSError as e: - # Dynamic dataclasses will result in an OSError -- this is fine, we just assume - # there's no docstring. - assert "could not find class definition" in e.args[0] + tokenization = get_class_tokenization_with_field(cls, field_name) + if tokenization is None: # Currently only happens for dynamic dataclasses. return None - # Grab field-specific tokenization data. - assert ( - field_name in tokenization.field_data_from_name - ), "Docstring parsing error -- this usually means that there are multiple \ - dataclasses in the same file with the same name but different scopes." field_data = tokenization.field_data_from_name[field_name] # Check for docstring-style comment. @@ -126,17 +149,6 @@ def get_field_docstring(cls: Type, field_name: str) -> Optional[str]: break line_number += 1 - # if ( - # field_data.line_number + 1 in tokenization.tokens_from_line - # and len(tokenization.tokens_from_line[field_data.line_number + 1]) > 0 - # ): - # first_token_on_next_line = tokenization.tokens_from_line[ - # field_data.line_number + 1 - # ][0] - # if first_token_on_next_line.token_type == tokenize.STRING: - # docstring = first_token_on_next_line.token.strip() - # assert docstring.endswith('"""') and docstring.startswith('"""') - # return _strings.dedent(docstring[3:-3]) # Check for comment on the same line as the field. final_token_on_line = tokenization.tokens_from_line[field_data.line_number][-1] diff --git a/dcargs/_serialization.py b/dcargs/_serialization.py index 44d6824e..f91cd771 100644 --- a/dcargs/_serialization.py +++ b/dcargs/_serialization.py @@ -120,9 +120,11 @@ class DataclassDumper(yaml.Dumper): contained_types = list(_get_contained_special_types_from_instance(instance)) contained_type_names = list(map(lambda cls: cls.__name__, contained_types)) + + # Note: this is currently a stricter than necessary assert. assert len(set(contained_type_names)) == len( contained_type_names - ), f"Contained dataclass type names must all be unique, but got {contained_type_names}" + ), f"Contained dataclass/enum names must all be unique, but got {contained_type_names}" dumper: yaml.Dumper data: Any diff --git a/tests/test_docstrings.py b/tests/test_docstrings.py index b61a29c2..3741ad21 100644 --- a/tests/test_docstrings.py +++ b/tests/test_docstrings.py @@ -94,6 +94,61 @@ class HelptextHardString: ) +def test_helptext_with_inheritance(): + @dataclasses.dataclass + class Parent: + # fmt: off + x: str = ( + "This docstring may be tougher to parse!" + ) + """Helptext.""" + # fmt: on + + @dataclasses.dataclass + class Child(Parent): + pass + + f = io.StringIO() + with pytest.raises(SystemExit): + with contextlib.redirect_stdout(f): + dcargs.parse(Child, args=["--help"]) + helptext = f.getvalue() + assert ( + "--x STR Helptext. (default: This docstring may be tougher to parse!)\n" + in helptext + ) + + +def test_helptext_with_inheritance_overriden(): + @dataclasses.dataclass + class Parent2: + # fmt: off + x: str = ( + "This docstring may be tougher to parse!" + ) + """Helptext.""" + # fmt: on + + @dataclasses.dataclass + class Child2(Parent2): + # fmt: off + x: str = ( + "This docstring may be tougher to parse?" + ) + """Helptext.""" + # fmt: on + + f = io.StringIO() + with pytest.raises(SystemExit): + with contextlib.redirect_stdout(f): + dcargs.parse(Child2, args=["--help"]) + helptext = f.getvalue() + assert ( + "--x STR Helptext. (default: This docstring may be tougher to parse?)\n" + in helptext + ) + + def test_tuple_helptext(): @dataclasses.dataclass class TupleHelptext: