diff --git a/mypy/stubdoc.py b/mypy/stubdoc.py index 0b5b21e81a0fc..4fd6597930201 100644 --- a/mypy/stubdoc.py +++ b/mypy/stubdoc.py @@ -203,6 +203,36 @@ def args_kwargs(signature: FunctionSig) -> bool: return list(sorted(self.signatures, key=lambda x: 1 if args_kwargs(x) else 0)) +def _infer_escaped_sigs_from_docstring(docstr: str, name: str) -> List[FunctionSig]: + """Parse escaped function signatures at the start of the docstring. + + This function is one half of infer_sig_from_docstring(). + This function finds all the signatures at the start of the docstring + that are separated with a backslash, as is conventional in C extensions. + All other signatures in the docstring, including a single signature at + the start of the docstring that is not escaped, are handled by + _infer_sig_from_docstring(). + """ + sigs = [] + + lines = docstr.splitlines() + # We only want to handle escaped signatures. + # If there isn't one, let _infer_sig_from_docstring handle everything. + if not lines or not lines[0].endswith("\\"): + return sigs + + for line in lines: + escaped = line.endswith("\\") + if escaped: + line = line[:-1].strip() + line_sigs = _infer_sig_from_docstring(line, name) or [] + sigs.extend(line_sigs) + if not escaped: + break + + return sigs + + def infer_sig_from_docstring(docstr: Optional[str], name: str) -> Optional[List[FunctionSig]]: """Convert function signature to list of TypedFunctionSig @@ -221,6 +251,16 @@ def infer_sig_from_docstring(docstr: Optional[str], name: str) -> Optional[List[ if not docstr: return None + other_sigs = _infer_sig_from_docstring(docstr, name) + if other_sigs is None: + return None + + sigs = _infer_escaped_sigs_from_docstring(docstr, name) + sigs.extend(other_sigs) + return sigs + + +def _infer_sig_from_docstring(docstr: str, name: str) -> Optional[List[FunctionSig]]: state = DocStringParser(name) # Return all found signatures, even if there is a parse error after some are found. with contextlib.suppress(tokenize.TokenError): diff --git a/mypy/test/teststubgen.py b/mypy/test/teststubgen.py index 5d62a1af521c3..b6a3f0d0832e2 100644 --- a/mypy/test/teststubgen.py +++ b/mypy/test/teststubgen.py @@ -917,6 +917,26 @@ def __init__(self, arg0: str) -> None: 'def __init__(*args, **kwargs) -> Any: ...']) assert_equal(set(imports), {'from typing import overload'}) + def test_generate_autodoc_c_type_with_overload(self) -> None: + class TestClass: + def __init__(self, arg0: str) -> None: + """__init__(self: TestClass, arg0: str) -> None \\ + __init__(self: TestClass, arg0: str, arg1: str) -> None + Overloaded function. + """ + pass + output = [] # type: List[str] + imports = [] # type: List[str] + mod = ModuleType(TestClass.__module__, '') + generate_c_function_stub(mod, '__init__', TestClass.__init__, output, imports, + self_var='self', class_name='TestClass') + assert_equal(output, [ + '@overload', + 'def __init__(self, arg0: str) -> None: ...', + '@overload', + 'def __init__(self, arg0: str, arg1: str) -> None: ...']) + assert_equal(set(imports), {'from typing import overload'}) + class ArgSigSuite(unittest.TestCase): def test_repr(self) -> None: