diff --git a/pyproject.toml b/pyproject.toml index cacd6d9f..4eef4854 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "tyro" -version = "0.3.37" +version = "0.3.38" description = "Strongly typed, zero-effort CLI interfaces" authors = ["brentyi "] include = ["./tyro/**/*"] @@ -76,3 +76,8 @@ exclude_lines = [ # or anything that's deprecated "deprecated", ] + +[tool.ruff] +ignore = [ + "E501", # Ignore line length errors. +] diff --git a/tests/test_conf.py b/tests/test_conf.py index a34508e1..28518777 100644 --- a/tests/test_conf.py +++ b/tests/test_conf.py @@ -1,5 +1,5 @@ import dataclasses -from typing import Any, Callable, Generic, TypeVar, Union +from typing import Any, Generic, TypeVar, Union import pytest from helptext_utils import get_helptext @@ -496,7 +496,9 @@ def test_suppress_auto_fixed() -> None: @dataclasses.dataclass class Struct: a: int = 5 - b: Callable = lambda x: 5 + + def b(x): + return 5 def main(x: tyro.conf.SuppressFixed[Any] = Struct()): pass diff --git a/tests/test_pydantic.py b/tests/test_pydantic.py index da9621e7..a296ceb5 100644 --- a/tests/test_pydantic.py +++ b/tests/test_pydantic.py @@ -48,3 +48,52 @@ class Helptext(BaseModel): assert "Documentation 1" in helptext assert "Documentation 2" in helptext assert "Documentation 3" in helptext + + +def test_pydantic_suppress_base_model_helptext() -> None: + class Helptext(BaseModel): + x: int = Field(description="Documentation 1") + + y: int = Field(description="Documentation 2") + + z: int = Field(description="Documentation 3") + + f = io.StringIO() + with pytest.raises(SystemExit): + with contextlib.redirect_stdout(f): + tyro.cli(Helptext, args=["--help"]) + helptext = f.getvalue() + + assert "Create a new model by parsing and validating" not in helptext + assert "Documentation 1" in helptext + assert "Documentation 2" in helptext + assert "Documentation 3" in helptext + + +class HelptextWithFieldDocstring(BaseModel): + """This docstring should be printed as a description.""" + + x: int + """Documentation 1""" + + y: int = Field(description="Documentation 2") + + z: int = Field(description="Documentation 3") + + +def test_pydantic_field_helptext_from_docstring() -> None: + f = io.StringIO() + with pytest.raises(SystemExit): + with contextlib.redirect_stdout(f): + tyro.cli(HelptextWithFieldDocstring, args=["--help"]) + helptext = f.getvalue() + assert ( + tyro._strings.strip_ansi_sequences( + cast(str, HelptextWithFieldDocstring.__doc__) + ) + in helptext + ) + + assert "Documentation 1" in helptext + assert "Documentation 2" in helptext + assert "Documentation 3" in helptext diff --git a/tyro/_cli.py b/tyro/_cli.py index 346f9cdf..d76c1a9d 100644 --- a/tyro/_cli.py +++ b/tyro/_cli.py @@ -204,8 +204,10 @@ def _cli_impl( if deprecated_kwargs.get("avoid_subparsers", False): f = conf.AvoidSubcommands[f] # type: ignore warnings.warn( - "`avoid_subparsers=` is deprecated! use `tyro.conf.AvoidSubparsers[]`" - " instead.", + ( + "`avoid_subparsers=` is deprecated! use `tyro.conf.AvoidSubparsers[]`" + " instead." + ), stacklevel=2, ) diff --git a/tyro/_docstrings.py b/tyro/_docstrings.py index a01007af..32bafc6e 100644 --- a/tyro/_docstrings.py +++ b/tyro/_docstrings.py @@ -297,6 +297,11 @@ def get_callable_description(f: Callable) -> str: if isinstance(f, functools.partial): f = f.func + try: + import pydantic + except ImportError: + pydantic = None # type: ignore + # Note inspect.getdoc() causes some corner cases with TypedDicts. docstring = f.__doc__ if ( @@ -304,7 +309,10 @@ def get_callable_description(f: Callable) -> str: and isinstance(f, type) # Ignore TypedDict's __init__ docstring, because it will just be `dict` and not is_typeddict(f) + # Ignore NamedTuple __init__ docstring. and not _resolver.is_namedtuple(f) + # Ignore pydantic base model constructor docstring. + and not (pydantic is not None and f.__init__ is pydantic.BaseModel.__init__) # type: ignore ): docstring = f.__init__.__doc__ # type: ignore if docstring is None: diff --git a/tyro/_fields.py b/tyro/_fields.py index 2018a38a..0fe1517c 100644 --- a/tyro/_fields.py +++ b/tyro/_fields.py @@ -415,6 +415,10 @@ def _field_list_from_pydantic( # Handle pydantic models. field_list = [] for pd_field in cls.__fields__.values(): # type: ignore + helptext = pd_field.field_info.description + if helptext is None: + helptext = _docstrings.get_field_docstring(cls, pd_field.name) + field_list.append( FieldDefinition.make( name=pd_field.name, @@ -422,7 +426,7 @@ def _field_list_from_pydantic( default=MISSING_NONPROP if pd_field.required else pd_field.get_default(), - helptext=pd_field.field_info.description, + helptext=helptext, ) ) return field_list @@ -747,9 +751,11 @@ def _get_dataclass_field_default( return getattr(parent_default_instance, field.name) else: warnings.warn( - f"Could not find field {field.name} in default instance" - f" {parent_default_instance}, which has" - f" type {type(parent_default_instance)},", + ( + f"Could not find field {field.name} in default instance" + f" {parent_default_instance}, which has" + f" type {type(parent_default_instance)}," + ), stacklevel=2, ) diff --git a/tyro/_resolver.py b/tyro/_resolver.py index cf390ea6..506d36d4 100644 --- a/tyro/_resolver.py +++ b/tyro/_resolver.py @@ -28,7 +28,8 @@ def unwrap_origin_strip_extras(typ: TypeOrCallable) -> TypeOrCallable: - """Returns the origin, ignoring typing.Annotated, of typ if it exists. Otherwise, returns typ.""" + """Returns the origin, ignoring typing.Annotated, of typ if it exists. Otherwise, + returns typ.""" # TODO: Annotated[] handling should be revisited... typ, _ = unwrap_annotated(typ) origin = get_origin(typ)