Skip to content

Commit

Permalink
Add prefix_name field to tyro.conf.arg
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed May 5, 2023
1 parent a7de7e3 commit f4b9a70
Show file tree
Hide file tree
Showing 8 changed files with 90 additions and 25 deletions.
45 changes: 45 additions & 0 deletions tests/test_conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,35 @@ def main(x: Any = Struct()) -> int:
assert tyro.cli(main, args=["--x.nice", "3"]) == 3


def test_argconf_no_prefix_help() -> None:
@dataclasses.dataclass
class Struct:
a: Annotated[
int,
tyro.conf.arg(
name="nice", help="Hello world", metavar="NUMBER", prefix_name=False
),
] = 5
b: tyro.conf.Suppress[str] = "7"

def main(x: Any = Struct()) -> int:
return x.a

helptext = get_helptext(main)
assert "Hello world" in helptext
assert "INT" not in helptext
assert "NUMBER" in helptext
assert "--x.a" not in helptext
assert "--x.nice" not in helptext
assert "--nice" in helptext
assert "--x.b" not in helptext

assert tyro.cli(main, args=[]) == 5
with pytest.raises(SystemExit):
assert tyro.cli(main, args=["--x.nice", "3"]) == 3
assert tyro.cli(main, args=["--nice", "3"]) == 3


def test_positional() -> None:
def main(x: tyro.conf.Positional[int], y: int) -> int:
return x + y
Expand Down Expand Up @@ -822,3 +851,19 @@ class A:
x={"0": {"1": 2, "3": 4}, "4": {"5": 6}}
)
assert tyro.cli(A, args=[]) == A(x={})


def test_duplicated_arg() -> None:
# Loosely inspired by: https://github.com/brentyi/tyro/issues/49
@dataclasses.dataclass
class ModelConfig:
num_slots: Annotated[int, tyro.conf.arg(name="num_slots", prefix_name=False)]

@dataclasses.dataclass
class TrainConfig:
num_slots: int
model: ModelConfig

assert tyro.cli(TrainConfig, args="--num-slots 3".split(" ")) == TrainConfig(
num_slots=3, model=ModelConfig(num_slots=3)
)
4 changes: 1 addition & 3 deletions tyro/_argparse_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,7 @@ def set_accent_color(accent_color: Optional[str]) -> None:
THEME.helptext = Style(dim=True)
THEME.helptext_required = Style(color="bright_red", bold=True)
THEME.helptext_default = Style(
color="cyan"
if accent_color != "cyan"
else "magenta"
color="cyan" if accent_color != "cyan" else "magenta"
# Another option: make default color match accent color. This is maybe more
# visually consistent, but harder to read.
# color=accent_color if accent_color is not None else "cyan",
Expand Down
26 changes: 17 additions & 9 deletions tyro/_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,10 @@ def add_argument(
kwargs["metavar"] = self.field.argconf.metavar

# Add argument! Note that the name must be passed in as a position argument.
arg = parser.add_argument(name_or_flag, **kwargs)
try:
arg = parser.add_argument(name_or_flag, **kwargs)
except argparse.ArgumentError:
return

# Do our best to tab complete paths.
# There will be false positives here, but if choices is unset they should be
Expand Down Expand Up @@ -452,14 +455,15 @@ def _rule_set_name_or_flag_and_dest(
arg: ArgumentDefinition,
lowered: LoweredArgumentDefinition,
) -> LoweredArgumentDefinition:
# Positional arguments: no -- prefix.
if arg.field.is_positional():
name_or_flag = _strings.make_field_name([arg.name_prefix, arg.field.name])
name_or_flag = _strings.make_field_name(
[arg.name_prefix, arg.field.name]
if arg.field.argconf.prefix_name
else [arg.field.name]
)

# Prefix keyword arguments with --.
else:
name_or_flag = "--" + _strings.make_field_name(
[arg.name_prefix, arg.field.name]
)
if not arg.field.is_positional():
name_or_flag = "--" + name_or_flag

# Strip.
if name_or_flag.startswith("--") and arg.subcommand_prefix != "":
Expand All @@ -472,7 +476,11 @@ def _rule_set_name_or_flag_and_dest(
return dataclasses.replace(
lowered,
name_or_flag=name_or_flag,
dest=_strings.make_field_name([arg.dest_prefix, arg.field.name]),
dest=(
_strings.make_field_name([arg.dest_prefix, arg.field.name])
if arg.field.argconf.prefix_name
else arg.field.name
),
)


Expand Down
13 changes: 10 additions & 3 deletions tyro/_calling.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ def call_from_args(
def get_value_from_arg(prefixed_field_name: str) -> Any:
"""Helper for getting values from `value_from_arg` + doing some extra
asserts."""
assert prefixed_field_name in value_from_prefixed_field_name
assert (
prefixed_field_name in value_from_prefixed_field_name
), f"{prefixed_field_name} not in {value_from_prefixed_field_name}"
return value_from_prefixed_field_name[prefixed_field_name]

arg_from_prefixed_field_name: Dict[str, _arguments.ArgumentDefinition] = {}
Expand All @@ -63,9 +65,14 @@ def get_value_from_arg(prefixed_field_name: str) -> Any:

# Standard arguments.
arg = arg_from_prefixed_field_name[prefixed_field_name]
consumed_keywords.add(prefixed_field_name)
name_maybe_prefixed = (
prefixed_field_name
if field.argconf.prefix_name
else _strings.make_field_name([field.name])
)
consumed_keywords.add(name_maybe_prefixed)
if not arg.lowered.is_fixed():
value = get_value_from_arg(prefixed_field_name)
value = get_value_from_arg(name_maybe_prefixed)

if value in _fields.MISSING_SINGLETONS:
value = arg.field.default
Expand Down
2 changes: 1 addition & 1 deletion tyro/_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def make(
# Try to extract argconf overrides from type.
_, argconfs = _resolver.unwrap_annotated(typ, _confstruct._ArgConfiguration)
if len(argconfs) == 0:
argconf = _confstruct._ArgConfiguration(None, None, None)
argconf = _confstruct._ArgConfiguration(None, None, None, True)
else:
assert len(argconfs) == 1
(argconf,) = argconfs
Expand Down
3 changes: 2 additions & 1 deletion tyro/_instantiators.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,8 @@ def union_instantiator(strings: List[str]) -> Any:
)
raise ValueError(
f"no type in {options} could be instantiated from"
f" {strings}.\n\nGot errors: \n- " + "\n- ".join(errors)
f" {strings}.\n\nGot errors: \n- "
+ "\n- ".join(errors)
)

return union_instantiator, InstantiatorMetadata(
Expand Down
12 changes: 6 additions & 6 deletions tyro/_parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,9 @@ def from_callable_or_type(
# Don't make a subparser.
field = dataclasses.replace(field, typ=type(field.default))
else:
subparsers_from_prefix[
subparsers_attempt.prefix
] = subparsers_attempt
subparsers_from_prefix[subparsers_attempt.prefix] = (
subparsers_attempt
)
subparsers = add_subparsers_to_leaves(
subparsers, subparsers_attempt
)
Expand Down Expand Up @@ -345,9 +345,9 @@ def from_field(
return None

# Get subcommand configurations from `tyro.conf.subcommand()`.
subcommand_config_from_name: Dict[
str, _confstruct._SubcommandConfiguration
] = {}
subcommand_config_from_name: Dict[str, _confstruct._SubcommandConfiguration] = (
{}
)
subcommand_type_from_name: Dict[str, type] = {}
for option in options_no_none:
subcommand_name = _strings.subparser_name_from_type(prefix, option)
Expand Down
10 changes: 8 additions & 2 deletions tyro/conf/_confstruct.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,15 @@ class _ArgConfiguration:
name: Optional[str]
metavar: Optional[str]
help: Optional[str]
# TODO - add prefix_name: bool
prefix_name: bool


def arg(
*,
name: Optional[str] = None,
metavar: Optional[str] = None,
help: Optional[str] = None,
prefix_name: bool = True,
) -> Any:
"""Returns a metadata object for configuring arguments with `typing.Annotated`.
Useful for aesthetics.
Expand All @@ -77,4 +78,9 @@ def arg(
x: Annotated[int, tyro.conf.arg(...)]
```
"""
return _ArgConfiguration(name=name, metavar=metavar, help=help)
return _ArgConfiguration(
name=name,
metavar=metavar,
help=help,
prefix_name=prefix_name,
)

0 comments on commit f4b9a70

Please sign in to comment.