diff --git a/examples/04_additional/06_conf.py b/examples/04_additional/06_conf.py index 7d7ce518..05d6ca27 100644 --- a/examples/04_additional/06_conf.py +++ b/examples/04_additional/06_conf.py @@ -7,6 +7,7 @@ Usage: `python ./06_conf.py --help` +`python ./06_conf.py 5 --boolean True` """ import dataclasses @@ -34,15 +35,25 @@ class CommitArgs: @dataclasses.dataclass class Args: - # A boolean field with flag conversion turned off. - boolean: tyro.conf.FlagConversionOff[bool] = False - # A numeric field parsed as a positional argument. positional: tyro.conf.Positional[int] = 3 + # A boolean field with flag conversion turned off. + boolean: tyro.conf.FlagConversionOff[bool] = False + # A numeric field that can't be changed via the CLI. fixed: tyro.conf.Fixed[int] = 5 + # A field with manually overridden properties. + manual: Annotated[ + str, + tyro.conf.arg( + name="renamed", + metavar="STRING", + help="A field with manually overridden properties!", + ), + ] = "Hello" + # A union over nested structures, but without subcommand generation. When a default # is provided, the type is simply fixed to that default. union_without_subcommand: tyro.conf.AvoidSubcommands[ diff --git a/pyproject.toml b/pyproject.toml index abe431b0..7343a987 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "tyro" -version = "0.3.32" +version = "0.3.33" description = "Strongly typed, zero-effort CLI interfaces" authors = ["brentyi "] include = ["./tyro/**/*"] diff --git a/tests/test_conf.py b/tests/test_conf.py index 1a8b7d55..fd0eb0cd 100644 --- a/tests/test_conf.py +++ b/tests/test_conf.py @@ -527,3 +527,19 @@ def main(x: Any = Struct()) -> int: assert tyro.cli(main, args=[]) == 5 assert tyro.cli(main, args=["--x.nice", "3"]) == 3 + + +def test_positional(): + def main(x: tyro.conf.Positional[int], y: int) -> int: + return x + y + + assert tyro.cli(main, args="5 --y 3".split(" ")) == 8 + assert tyro.cli(main, args="--y 3 5".split(" ")) == 8 + + +def test_positional_order_swap(): + def main(x: int, y: tyro.conf.Positional[int]) -> int: + return x + y + + assert tyro.cli(main, args="5 --x 3".split(" ")) == 8 + assert tyro.cli(main, args="--x 3 5".split(" ")) == 8 diff --git a/tyro/_arguments.py b/tyro/_arguments.py index a8284894..f381f302 100644 --- a/tyro/_arguments.py +++ b/tyro/_arguments.py @@ -65,13 +65,8 @@ def add_argument( # Apply overrides in our arg configuration object. # Note that the `name` field is applied when the field object is instantiated! - kwargs.update( - { - k: v - for k, v in vars(self.field.argconf).items() - if v is not None and k != "name" - } - ) + if self.field.argconf.metavar is not None: + kwargs["metavar"] = self.field.argconf.metavar # Add argument! Note that the name must be passed in as a position argument. arg = parser.add_argument(name_or_flag, **kwargs) @@ -314,13 +309,13 @@ def _rule_generate_helptext( help_parts = [] - docstring_help = arg.field.helptext + primary_help = arg.field.helptext - if docstring_help is not None and docstring_help != "": + if primary_help is not None and primary_help != "": # Note that the percent symbol needs some extra handling in argparse. # https://stackoverflow.com/questions/21168120/python-argparse-errors-with-in-help-string - docstring_help = docstring_help.replace("%", "%%") - help_parts.append(_rich_tag_if_enabled(docstring_help, "helptext")) + primary_help = primary_help.replace("%", "%%") + help_parts.append(_rich_tag_if_enabled(primary_help, "helptext")) default = lowered.default if lowered.is_fixed(): diff --git a/tyro/_calling.py b/tyro/_calling.py index a71fff45..36b66024 100644 --- a/tyro/_calling.py +++ b/tyro/_calling.py @@ -163,7 +163,7 @@ def get_value_from_arg(prefixed_field_name: str) -> Any: consumed_keywords |= consumed_keywords_child if value is not _fields.EXCLUDE_FROM_CALL: - if field.is_positional(): + if field.is_positional_call(): args.append(value) else: kwargs[field.call_argname] = value diff --git a/tyro/_fields.py b/tyro/_fields.py index 3af08078..f96d7367 100644 --- a/tyro/_fields.py +++ b/tyro/_fields.py @@ -83,6 +83,7 @@ def make( else: assert len(argconfs) == 1 (argconf,) = argconfs + helptext = argconf.help typ, inferred_markers = _resolver.unwrap_annotated(typ, _markers.Marker) return FieldDefinition( @@ -102,6 +103,7 @@ def add_markers(self, markers: Tuple[_markers.Marker, ...]) -> FieldDefinition: ) def is_positional(self) -> bool: + """Returns True if the argument should be positional in the commandline.""" return ( # Explicit positionals. _markers.Positional in self.markers @@ -109,6 +111,15 @@ def is_positional(self) -> bool: or self.name == _strings.dummy_field_name ) + def is_positional_call(self) -> bool: + """Returns True if the argument should be positional in underlying Python call.""" + return ( + # Explicit positionals. + _markers._PositionalCall in self.markers + # Dummy dataclasses should have a single positional field. + or self.name == _strings.dummy_field_name + ) + class PropagatingMissingType(_singleton.Singleton): pass @@ -646,7 +657,7 @@ def _field_list_from_params( typ=hints[param.name], default=default, helptext=helptext, - markers=(_markers.Positional,) + markers=(_markers.Positional, _markers._PositionalCall) if param.kind is inspect.Parameter.POSITIONAL_ONLY else (), ) diff --git a/tyro/conf/_markers.py b/tyro/conf/_markers.py index 9bd6c649..d267ee43 100644 --- a/tyro/conf/_markers.py +++ b/tyro/conf/_markers.py @@ -18,6 +18,10 @@ """A type `T` can be annotated as `Positional[T]` if we want to parse it as a positional argument.""" +# Private marker. For when an argument is not only positional in the CLI, but also in +# the callable. +_PositionalCall = Annotated[T, None] + # TODO: the verb tenses here are inconsistent, naming could be revisited. # Perhaps Suppress should be Suppressed? But SuppressedFixed would be weird.