Skip to content

Commit

Permalink
Add tyro.conf.arg(), support "help" metadata
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed Nov 11, 2022
1 parent f024f92 commit cd1589a
Show file tree
Hide file tree
Showing 11 changed files with 145 additions and 69 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "tyro"
version = "0.3.31"
version = "0.3.32"
description = "Strongly typed, zero-effort CLI interfaces"
authors = ["brentyi <[email protected]>"]
include = ["./tyro/**/*"]
Expand Down
67 changes: 66 additions & 1 deletion tests/test_conf.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import dataclasses
from typing import Generic, TypeVar, Union
from typing import Any, Callable, Generic, TypeVar, Union

import pytest
from helptext_utils import get_helptext
Expand Down Expand Up @@ -462,3 +462,68 @@ def main(
return value + inner.a + inner.b

assert tyro.cli(main, args=["--value", "5"]) == 8


def test_suppressed():
@dataclasses.dataclass
class Struct:
a: int = 5
b: tyro.conf.Suppress[str] = "7"

def main(x: Any = Struct()):
pass

helptext = get_helptext(main)
assert "--x.a" in helptext
assert "--x.b" not in helptext


def test_suppress_manual_fixed():
@dataclasses.dataclass
class Struct:
a: int = 5
b: tyro.conf.SuppressFixed[tyro.conf.Fixed[str]] = "7"

def main(x: Any = Struct()):
pass

helptext = get_helptext(main)
assert "--x.a" in helptext
assert "--x.b" not in helptext


def test_suppress_auto_fixed():
@dataclasses.dataclass
class Struct:
a: int = 5
b: Callable = lambda x: 5

def main(x: tyro.conf.SuppressFixed[Any] = Struct()):
pass

helptext = get_helptext(main)
assert "--x.a" in helptext
assert "--x.b" not in helptext


def test_argconf_help():
@dataclasses.dataclass
class Struct:
a: Annotated[
int, tyro.conf.arg(name="nice", help="Hello world", metavar="NUMBER")
] = 5
b: tyro.conf.Suppress[str] = "7"

def main(x: Any = Struct()) -> int:
return x.a

helptext = get_helptext(main)
assert "Hello world" in helptext
assert "INT" not in helptext
assert "NUMBER" in helptext
assert "--x.a" not in helptext
assert "--x.nice" in helptext
assert "--x.b" not in helptext

assert tyro.cli(main, args=[]) == 5
assert tyro.cli(main, args=["--x.nice", "3"]) == 3
44 changes: 0 additions & 44 deletions tests/test_helptext.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
from helptext_utils import get_helptext
from typing_extensions import Annotated, Literal

import tyro


def test_helptext():
@dataclasses.dataclass
Expand Down Expand Up @@ -566,45 +564,3 @@ def main2(x: Callable = nn.ReLU):
assert "--x {fixed}" in helptext
assert "(fixed to:" in helptext
assert "torch" in helptext


def test_suppressed():
@dataclasses.dataclass
class Struct:
a: int = 5
b: tyro.conf.Suppress[str] = "7"

def main(x: Any = Struct()):
pass

helptext = get_helptext(main)
assert "--x.a" in helptext
assert "--x.b" not in helptext


def test_suppress_manual_fixed():
@dataclasses.dataclass
class Struct:
a: int = 5
b: tyro.conf.SuppressFixed[tyro.conf.Fixed[str]] = "7"

def main(x: Any = Struct()):
pass

helptext = get_helptext(main)
assert "--x.a" in helptext
assert "--x.b" not in helptext


def test_suppress_auto_fixed():
@dataclasses.dataclass
class Struct:
a: int = 5
b: Callable = lambda x: 5

def main(x: tyro.conf.SuppressFixed[Any] = Struct()):
pass

helptext = get_helptext(main)
assert "--x.a" in helptext
assert "--x.b" not in helptext
10 changes: 10 additions & 0 deletions tyro/_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,16 @@ def add_argument(
# the field default to a string format, then back to the desired type.
kwargs["default"] = _fields.MISSING_NONPROP

# Apply overrides in our arg configuration object.
# Note that the `name` field is applied when the field object is instantiated!
kwargs.update(
{
k: v
for k, v in vars(self.field.argconf).items()
if v is not None and k != "name"
}
)

# Add argument! Note that the name must be passed in as a position argument.
arg = parser.add_argument(name_or_flag, **kwargs)

Expand Down
4 changes: 1 addition & 3 deletions tyro/_calling.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,9 +166,7 @@ def get_value_from_arg(prefixed_field_name: str) -> Any:
if field.is_positional():
args.append(value)
else:
kwargs[
field.name if field.name_override is None else field.name_override
] = value
kwargs[field.call_argname] = value

# Note: we unwrap types both before and after narrowing. This is because narrowing
# sometimes produces types like `Tuple[T1, T2, ...]`, where we actually want just
Expand Down
38 changes: 30 additions & 8 deletions tyro/_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

from . import conf # Avoid circular import.
from . import _docstrings, _instantiators, _resolver, _singleton, _strings
from .conf import _markers
from .conf import _confstruct, _markers

# Support attrs and pydantic if they're installed.
try:
Expand All @@ -52,10 +52,12 @@ class FieldDefinition:
helptext: Optional[str]
markers: FrozenSet[_markers.Marker]

argconf: _confstruct._ArgConfiguration

# Override the name in our kwargs. Currently only used for dictionary types when
# the key values aren't strings, but in the future could be used whenever the
# user-facing argument name doesn't match the keyword expected by our callable.
name_override: Optional[Any]
call_argname: Any

def __post_init__(self):
if (
Expand All @@ -71,18 +73,27 @@ def make(
typ: Type,
default: Any,
helptext: Optional[str],
call_argname_override: Optional[Any] = None,
*,
markers: Tuple[_markers.Marker, ...] = (),
name_override: Optional[Any] = None,
):
# Try to extract argconf overrides from type.
_, argconfs = _resolver.unwrap_annotated(typ, _confstruct._ArgConfiguration)
if len(argconfs) == 0:
argconf = _confstruct._ArgConfiguration(None, None, None)
else:
assert len(argconfs) == 1
(argconf,) = argconfs

typ, inferred_markers = _resolver.unwrap_annotated(typ, _markers.Marker)
return FieldDefinition(
name,
name if argconf.name is None else argconf.name,
typ,
default,
helptext,
frozenset(inferred_markers).union(markers),
name_override,
argconf,
call_argname_override if call_argname_override is not None else name,
)

def add_markers(self, markers: Tuple[_markers.Marker, ...]) -> FieldDefinition:
Expand Down Expand Up @@ -201,7 +212,7 @@ def _try_field_list_from_callable(
default_instance: _DefaultInstance,
) -> Union[List[FieldDefinition], UnsupportedNestedTypeMessage]:
f, found_subcommand_configs = _resolver.unwrap_annotated(
f, conf._subcommands._SubcommandConfiguration
f, conf._confstruct._SubcommandConfiguration
)
if len(found_subcommand_configs) > 0:
default_instance = found_subcommand_configs[0].default
Expand Down Expand Up @@ -362,12 +373,23 @@ def _try_field_list_from_dataclass(
field_list = []
for dc_field in filter(lambda field: field.init, _resolver.resolved_fields(cls)):
default = _get_dataclass_field_default(dc_field, default_instance)

# Try to get helptext from field metadata. This is also intended to be
# compatible with HuggingFace-style config objects.
helptext = dc_field.metadata.get("help", None)
assert isinstance(helptext, (str, type(None)))

# Try to get helptext from docstrings. Note that this can't be generated
# dynamically.
if helptext is None:
helptext = _docstrings.get_field_docstring(cls, dc_field.name)

field_list.append(
FieldDefinition.make(
name=dc_field.name,
typ=dc_field.type,
default=default,
helptext=_docstrings.get_field_docstring(cls, dc_field.name),
helptext=helptext,
)
)
return field_list
Expand Down Expand Up @@ -542,7 +564,7 @@ def _try_field_list_from_dict(
default=v,
helptext=None,
# Dictionary specific key:
name_override=k,
call_argname_override=k,
)
)
return field_list
Expand Down
8 changes: 4 additions & 4 deletions tyro/_parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
_resolver,
_strings,
)
from .conf import _markers, _subcommands
from .conf import _confstruct, _markers

T = TypeVar("T")

Expand Down Expand Up @@ -294,14 +294,14 @@ def from_field(

# Get subcommand configurations from `tyro.conf.subcommand()`.
subcommand_config_from_name: Dict[
str, _subcommands._SubcommandConfiguration
str, _confstruct._SubcommandConfiguration
] = {}
subcommand_name_from_default_hash: Dict[int, str] = {}
subcommand_name_from_type: Dict[Type, str] = {} # Used for default matching.
for option in options_no_none:
subcommand_name = _strings.subparser_name_from_type(prefix, option)
option, found_subcommand_configs = _resolver.unwrap_annotated(
option, _subcommands._SubcommandConfiguration
option, _confstruct._SubcommandConfiguration
)
default_hash = None
if len(found_subcommand_configs) != 0:
Expand Down Expand Up @@ -378,7 +378,7 @@ def from_field(
if subcommand_name in subcommand_config_from_name:
subcommand_config = subcommand_config_from_name[subcommand_name]
else:
subcommand_config = _subcommands._SubcommandConfiguration(
subcommand_config = _confstruct._SubcommandConfiguration(
"unused",
description=None,
default=_fields.MISSING_NONPROP,
Expand Down
4 changes: 2 additions & 2 deletions tyro/_strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,11 @@ def hyphen_separated_from_camel_case(name: str) -> str:


def _subparser_name_from_type(cls: Type) -> Tuple[str, bool]:
from .conf import _subcommands # Prevent circular imports
from .conf import _confstruct # Prevent circular imports

cls, type_from_typevar = _resolver.resolve_generic_types(cls)
cls, found_subcommand_configs = _resolver.unwrap_annotated(
cls, _subcommands._SubcommandConfiguration
cls, _confstruct._SubcommandConfiguration
)

# Subparser name from `tyro.metadata.subcommand()`.
Expand Down
5 changes: 3 additions & 2 deletions tyro/conf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
Features here are supported, but generally unnecessary and should be used sparingly.
"""

from ._confstruct import arg, subcommand
from ._markers import (
AvoidSubcommands,
Fixed,
Expand All @@ -17,15 +18,15 @@
Suppress,
SuppressFixed,
)
from ._subcommands import subcommand

__all__ = [
"arg",
"subcommand",
"AvoidSubcommands",
"Fixed",
"FlagConversionOff",
"OmitSubcommandPrefixes",
"Positional",
"Suppress",
"SuppressFixed",
"subcommand",
]
26 changes: 25 additions & 1 deletion tyro/conf/_subcommands.py → tyro/conf/_confstruct.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def subcommand(
prefix_name: bool = True,
) -> Any:
"""Returns a metadata object for configuring subcommands with `typing.Annotated`.
This is useful but can make code harder to read, so usage is discouraged.
Useful for aesthetics.
Consider the standard approach for creating subcommands:
Expand Down Expand Up @@ -53,3 +53,27 @@ def subcommand(
```
"""
return _SubcommandConfiguration(name, default, description, prefix_name)


@dataclasses.dataclass(frozen=True)
class _ArgConfiguration:
name: Optional[str]
metavar: Optional[str]
help: Optional[str]


def arg(
*,
name: Optional[str] = None,
metavar: Optional[str] = None,
help: Optional[str] = None,
) -> Any:
"""Returns a metadata object for configuring arguments with `typing.Annotated`.
Useful for aesthetics.
Usage:
```python
x: Annotated[int, tyro.conf.arg(...)]
```
"""
return _ArgConfiguration(name=name, metavar=metavar, help=help)
6 changes: 3 additions & 3 deletions tyro/conf/_markers.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,10 @@
If we have a structure with the field:
cmd: Union[Commit, Checkout]
cmd: Union[NestedTypeA, NestedTypeB]
By default, --cmd.branch may be generated as a flag for each dataclass in the union.
If subcommand prefixes are omitted, we would instead simply have --branch.
By default, `--cmd.arg` may be generated as a flag for each dataclass in the union.
If subcommand prefixes are omitted, we would instead simply have `--arg`.
"""


Expand Down

0 comments on commit cd1589a

Please sign in to comment.