Skip to content

Commit

Permalink
Minor pydantic helptext improvements + tests
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed Jan 26, 2023
1 parent ff146ed commit 4ee8be0
Show file tree
Hide file tree
Showing 7 changed files with 83 additions and 10 deletions.
7 changes: 6 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "tyro"
version = "0.3.37"
version = "0.3.38"
description = "Strongly typed, zero-effort CLI interfaces"
authors = ["brentyi <[email protected]>"]
include = ["./tyro/**/*"]
Expand Down Expand Up @@ -76,3 +76,8 @@ exclude_lines = [
# or anything that's deprecated
"deprecated",
]

[tool.ruff]
ignore = [
"E501", # Ignore line length errors.
]
6 changes: 4 additions & 2 deletions tests/test_conf.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import dataclasses
from typing import Any, Callable, Generic, TypeVar, Union
from typing import Any, Generic, TypeVar, Union

import pytest
from helptext_utils import get_helptext
Expand Down Expand Up @@ -496,7 +496,9 @@ def test_suppress_auto_fixed() -> None:
@dataclasses.dataclass
class Struct:
a: int = 5
b: Callable = lambda x: 5

def b(x):
return 5

def main(x: tyro.conf.SuppressFixed[Any] = Struct()):
pass
Expand Down
49 changes: 49 additions & 0 deletions tests/test_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,52 @@ class Helptext(BaseModel):
assert "Documentation 1" in helptext
assert "Documentation 2" in helptext
assert "Documentation 3" in helptext


def test_pydantic_suppress_base_model_helptext() -> None:
class Helptext(BaseModel):
x: int = Field(description="Documentation 1")

y: int = Field(description="Documentation 2")

z: int = Field(description="Documentation 3")

f = io.StringIO()
with pytest.raises(SystemExit):
with contextlib.redirect_stdout(f):
tyro.cli(Helptext, args=["--help"])
helptext = f.getvalue()

assert "Create a new model by parsing and validating" not in helptext
assert "Documentation 1" in helptext
assert "Documentation 2" in helptext
assert "Documentation 3" in helptext


class HelptextWithFieldDocstring(BaseModel):
"""This docstring should be printed as a description."""

x: int
"""Documentation 1"""

y: int = Field(description="Documentation 2")

z: int = Field(description="Documentation 3")


def test_pydantic_field_helptext_from_docstring() -> None:
f = io.StringIO()
with pytest.raises(SystemExit):
with contextlib.redirect_stdout(f):
tyro.cli(HelptextWithFieldDocstring, args=["--help"])
helptext = f.getvalue()
assert (
tyro._strings.strip_ansi_sequences(
cast(str, HelptextWithFieldDocstring.__doc__)
)
in helptext
)

assert "Documentation 1" in helptext
assert "Documentation 2" in helptext
assert "Documentation 3" in helptext
6 changes: 4 additions & 2 deletions tyro/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,10 @@ def _cli_impl(
if deprecated_kwargs.get("avoid_subparsers", False):
f = conf.AvoidSubcommands[f] # type: ignore
warnings.warn(
"`avoid_subparsers=` is deprecated! use `tyro.conf.AvoidSubparsers[]`"
" instead.",
(
"`avoid_subparsers=` is deprecated! use `tyro.conf.AvoidSubparsers[]`"
" instead."
),
stacklevel=2,
)

Expand Down
8 changes: 8 additions & 0 deletions tyro/_docstrings.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,14 +297,22 @@ def get_callable_description(f: Callable) -> str:
if isinstance(f, functools.partial):
f = f.func

try:
import pydantic
except ImportError:
pydantic = None # type: ignore

# Note inspect.getdoc() causes some corner cases with TypedDicts.
docstring = f.__doc__
if (
docstring is None
and isinstance(f, type)
# Ignore TypedDict's __init__ docstring, because it will just be `dict`
and not is_typeddict(f)
# Ignore NamedTuple __init__ docstring.
and not _resolver.is_namedtuple(f)
# Ignore pydantic base model constructor docstring.
and not (pydantic is not None and f.__init__ is pydantic.BaseModel.__init__) # type: ignore
):
docstring = f.__init__.__doc__ # type: ignore
if docstring is None:
Expand Down
14 changes: 10 additions & 4 deletions tyro/_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,14 +415,18 @@ def _field_list_from_pydantic(
# Handle pydantic models.
field_list = []
for pd_field in cls.__fields__.values(): # type: ignore
helptext = pd_field.field_info.description
if helptext is None:
helptext = _docstrings.get_field_docstring(cls, pd_field.name)

field_list.append(
FieldDefinition.make(
name=pd_field.name,
typ=pd_field.outer_type_,
default=MISSING_NONPROP
if pd_field.required
else pd_field.get_default(),
helptext=pd_field.field_info.description,
helptext=helptext,
)
)
return field_list
Expand Down Expand Up @@ -747,9 +751,11 @@ def _get_dataclass_field_default(
return getattr(parent_default_instance, field.name)
else:
warnings.warn(
f"Could not find field {field.name} in default instance"
f" {parent_default_instance}, which has"
f" type {type(parent_default_instance)},",
(
f"Could not find field {field.name} in default instance"
f" {parent_default_instance}, which has"
f" type {type(parent_default_instance)},"
),
stacklevel=2,
)

Expand Down
3 changes: 2 additions & 1 deletion tyro/_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@


def unwrap_origin_strip_extras(typ: TypeOrCallable) -> TypeOrCallable:
"""Returns the origin, ignoring typing.Annotated, of typ if it exists. Otherwise, returns typ."""
"""Returns the origin, ignoring typing.Annotated, of typ if it exists. Otherwise,
returns typ."""
# TODO: Annotated[] handling should be revisited...
typ, _ = unwrap_annotated(typ)
origin = get_origin(typ)
Expand Down

0 comments on commit 4ee8be0

Please sign in to comment.