diff --git a/tests/test_conf.py b/tests/test_conf.py index 99f34853..7c91a688 100644 --- a/tests/test_conf.py +++ b/tests/test_conf.py @@ -530,6 +530,35 @@ def main(x: Any = Struct()) -> int: assert tyro.cli(main, args=["--x.nice", "3"]) == 3 +def test_argconf_no_prefix_help() -> None: + @dataclasses.dataclass + class Struct: + a: Annotated[ + int, + tyro.conf.arg( + name="nice", help="Hello world", metavar="NUMBER", prefix_name=False + ), + ] = 5 + b: tyro.conf.Suppress[str] = "7" + + def main(x: Any = Struct()) -> int: + return x.a + + helptext = get_helptext(main) + assert "Hello world" in helptext + assert "INT" not in helptext + assert "NUMBER" in helptext + assert "--x.a" not in helptext + assert "--x.nice" not in helptext + assert "--nice" in helptext + assert "--x.b" not in helptext + + assert tyro.cli(main, args=[]) == 5 + with pytest.raises(SystemExit): + assert tyro.cli(main, args=["--x.nice", "3"]) == 3 + assert tyro.cli(main, args=["--nice", "3"]) == 3 + + def test_positional() -> None: def main(x: tyro.conf.Positional[int], y: int) -> int: return x + y @@ -822,3 +851,19 @@ class A: x={"0": {"1": 2, "3": 4}, "4": {"5": 6}} ) assert tyro.cli(A, args=[]) == A(x={}) + + +def test_duplicated_arg() -> None: + # Loosely inspired by: https://github.com/brentyi/tyro/issues/49 + @dataclasses.dataclass + class ModelConfig: + num_slots: Annotated[int, tyro.conf.arg(name="num_slots", prefix_name=False)] + + @dataclasses.dataclass + class TrainConfig: + num_slots: int + model: ModelConfig + + assert tyro.cli(TrainConfig, args="--num-slots 3".split(" ")) == TrainConfig( + num_slots=3, model=ModelConfig(num_slots=3) + ) diff --git a/tyro/_argparse_formatter.py b/tyro/_argparse_formatter.py index 874947a0..eb9f1978 100644 --- a/tyro/_argparse_formatter.py +++ b/tyro/_argparse_formatter.py @@ -56,9 +56,7 @@ def set_accent_color(accent_color: Optional[str]) -> None: THEME.helptext = Style(dim=True) THEME.helptext_required = Style(color="bright_red", bold=True) THEME.helptext_default = Style( - color="cyan" - if accent_color != "cyan" - else "magenta" + color="cyan" if accent_color != "cyan" else "magenta" # Another option: make default color match accent color. This is maybe more # visually consistent, but harder to read. # color=accent_color if accent_color is not None else "cyan", diff --git a/tyro/_arguments.py b/tyro/_arguments.py index e0b47f6a..ee3ab753 100644 --- a/tyro/_arguments.py +++ b/tyro/_arguments.py @@ -136,7 +136,10 @@ def add_argument( kwargs["metavar"] = self.field.argconf.metavar # Add argument! Note that the name must be passed in as a position argument. - arg = parser.add_argument(name_or_flag, **kwargs) + try: + arg = parser.add_argument(name_or_flag, **kwargs) + except argparse.ArgumentError: + return # Do our best to tab complete paths. # There will be false positives here, but if choices is unset they should be @@ -452,14 +455,15 @@ def _rule_set_name_or_flag_and_dest( arg: ArgumentDefinition, lowered: LoweredArgumentDefinition, ) -> LoweredArgumentDefinition: - # Positional arguments: no -- prefix. - if arg.field.is_positional(): - name_or_flag = _strings.make_field_name([arg.name_prefix, arg.field.name]) + name_or_flag = _strings.make_field_name( + [arg.name_prefix, arg.field.name] + if arg.field.argconf.prefix_name + else [arg.field.name] + ) + # Prefix keyword arguments with --. - else: - name_or_flag = "--" + _strings.make_field_name( - [arg.name_prefix, arg.field.name] - ) + if not arg.field.is_positional(): + name_or_flag = "--" + name_or_flag # Strip. if name_or_flag.startswith("--") and arg.subcommand_prefix != "": @@ -472,7 +476,11 @@ def _rule_set_name_or_flag_and_dest( return dataclasses.replace( lowered, name_or_flag=name_or_flag, - dest=_strings.make_field_name([arg.dest_prefix, arg.field.name]), + dest=( + _strings.make_field_name([arg.dest_prefix, arg.field.name]) + if arg.field.argconf.prefix_name + else arg.field.name + ), ) diff --git a/tyro/_calling.py b/tyro/_calling.py index e38dc386..ddb2dd7e 100644 --- a/tyro/_calling.py +++ b/tyro/_calling.py @@ -42,7 +42,9 @@ def call_from_args( def get_value_from_arg(prefixed_field_name: str) -> Any: """Helper for getting values from `value_from_arg` + doing some extra asserts.""" - assert prefixed_field_name in value_from_prefixed_field_name + assert ( + prefixed_field_name in value_from_prefixed_field_name + ), f"{prefixed_field_name} not in {value_from_prefixed_field_name}" return value_from_prefixed_field_name[prefixed_field_name] arg_from_prefixed_field_name: Dict[str, _arguments.ArgumentDefinition] = {} @@ -63,9 +65,14 @@ def get_value_from_arg(prefixed_field_name: str) -> Any: # Standard arguments. arg = arg_from_prefixed_field_name[prefixed_field_name] - consumed_keywords.add(prefixed_field_name) + name_maybe_prefixed = ( + prefixed_field_name + if field.argconf.prefix_name + else _strings.make_field_name([field.name]) + ) + consumed_keywords.add(name_maybe_prefixed) if not arg.lowered.is_fixed(): - value = get_value_from_arg(prefixed_field_name) + value = get_value_from_arg(name_maybe_prefixed) if value in _fields.MISSING_SINGLETONS: value = arg.field.default diff --git a/tyro/_fields.py b/tyro/_fields.py index f4df8f1e..198a02c9 100644 --- a/tyro/_fields.py +++ b/tyro/_fields.py @@ -80,7 +80,7 @@ def make( # Try to extract argconf overrides from type. _, argconfs = _resolver.unwrap_annotated(typ, _confstruct._ArgConfiguration) if len(argconfs) == 0: - argconf = _confstruct._ArgConfiguration(None, None, None) + argconf = _confstruct._ArgConfiguration(None, None, None, True) else: assert len(argconfs) == 1 (argconf,) = argconfs diff --git a/tyro/_instantiators.py b/tyro/_instantiators.py index 65989f23..2fddb623 100644 --- a/tyro/_instantiators.py +++ b/tyro/_instantiators.py @@ -486,7 +486,8 @@ def union_instantiator(strings: List[str]) -> Any: ) raise ValueError( f"no type in {options} could be instantiated from" - f" {strings}.\n\nGot errors: \n- " + "\n- ".join(errors) + f" {strings}.\n\nGot errors: \n- " + + "\n- ".join(errors) ) return union_instantiator, InstantiatorMetadata( diff --git a/tyro/_parsers.py b/tyro/_parsers.py index b56f4d93..727add96 100644 --- a/tyro/_parsers.py +++ b/tyro/_parsers.py @@ -126,9 +126,9 @@ def from_callable_or_type( # Don't make a subparser. field = dataclasses.replace(field, typ=type(field.default)) else: - subparsers_from_prefix[ - subparsers_attempt.prefix - ] = subparsers_attempt + subparsers_from_prefix[subparsers_attempt.prefix] = ( + subparsers_attempt + ) subparsers = add_subparsers_to_leaves( subparsers, subparsers_attempt ) @@ -345,9 +345,9 @@ def from_field( return None # Get subcommand configurations from `tyro.conf.subcommand()`. - subcommand_config_from_name: Dict[ - str, _confstruct._SubcommandConfiguration - ] = {} + subcommand_config_from_name: Dict[str, _confstruct._SubcommandConfiguration] = ( + {} + ) subcommand_type_from_name: Dict[str, type] = {} for option in options_no_none: subcommand_name = _strings.subparser_name_from_type(prefix, option) diff --git a/tyro/conf/_confstruct.py b/tyro/conf/_confstruct.py index 887587de..1cf4f045 100644 --- a/tyro/conf/_confstruct.py +++ b/tyro/conf/_confstruct.py @@ -60,7 +60,7 @@ class _ArgConfiguration: name: Optional[str] metavar: Optional[str] help: Optional[str] - # TODO - add prefix_name: bool + prefix_name: bool def arg( @@ -68,6 +68,7 @@ def arg( name: Optional[str] = None, metavar: Optional[str] = None, help: Optional[str] = None, + prefix_name: bool = True, ) -> Any: """Returns a metadata object for configuring arguments with `typing.Annotated`. Useful for aesthetics. @@ -77,4 +78,9 @@ def arg( x: Annotated[int, tyro.conf.arg(...)] ``` """ - return _ArgConfiguration(name=name, metavar=metavar, help=help) + return _ArgConfiguration( + name=name, + metavar=metavar, + help=help, + prefix_name=prefix_name, + )