Skip to content

Commit

Permalink
Some TLC for helptext generation
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed Apr 22, 2022
1 parent 6d81978 commit 38f8e88
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 32 deletions.
9 changes: 4 additions & 5 deletions dcargs/_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class ArgumentDefinition:
default: Optional[Any]

# Fields that will be handled by argument transformations.
required: Optional[bool] = None
required: bool = False
action: Optional[str] = None
nargs: Optional[Union[int, str]] = None
choices: Optional[Set[Any]] = None
Expand Down Expand Up @@ -103,11 +103,10 @@ def make_from_field(
def _transform_required_if_default_set(arg: ArgumentDefinition) -> ArgumentDefinition:
"""Set `required=True` if a default value is set."""

# Don't set if default is set, or if required flag is already set.
if arg.default is not None:
return dataclasses.replace(arg, required=False)
else:
# Mark arg as required if a default is set.
if arg.default is None:
return dataclasses.replace(arg, required=True)
return arg


def _transform_handle_boolean_flags(arg: ArgumentDefinition) -> ArgumentDefinition:
Expand Down
5 changes: 4 additions & 1 deletion dcargs/_construction.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,10 @@ def get_value_from_arg(arg: str) -> Any:
raise FieldActionValueError(
f"Parsing error for {arg.get_flag()}: {e.args[0]}"
)
elif prefixed_field_name in parser_definition.nested_dataclass_field_names:
elif (
prefixed_field_name
in parser_definition.helptext_from_nested_dataclass_field_name
):
# Nested dataclasses.
value, consumed_keywords_child = construct_dataclass(
field_type,
Expand Down
81 changes: 65 additions & 16 deletions dcargs/_parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@

import argparse
import dataclasses
import shutil
import warnings
from typing import Any, Dict, Generic, List, Optional, Set, Tuple, Type, TypeVar, Union

import termcolor
from typing_extensions import get_args, get_origin

from . import _arguments, _docstrings, _instantiators, _resolver, _strings
Expand Down Expand Up @@ -56,22 +58,64 @@ class ParserSpecification:

cls: Type
args: List[_arguments.ArgumentDefinition]
nested_dataclass_field_names: List[str]
helptext_from_nested_dataclass_field_name: Dict[str, Optional[str]]
subparsers: Optional["SubparsersSpecification"]

def apply(self, parser: argparse.ArgumentParser) -> None:
"""Create defined arguments and subparsers."""

# Put required group at start of group list.
required_group = parser.add_argument_group("required arguments")
def format_group_name(nested_field_name: str, required: bool) -> str:
if required:
prefix = termcolor.colored("required", attrs=["bold"])
else:
prefix = termcolor.colored("optional", attrs=["bold", "dark"])
suffix = " arguments"
if nested_field_name == "":
suffix = suffix[1:]

return (
prefix
+ " "
+ nested_field_name.replace("_", " ").replace(".", " • ")
+ suffix
)

optional_group_from_prefix: Dict[str, argparse._ArgumentGroup] = {
"": parser._action_groups[1],
}
required_group_from_prefix: Dict[str, argparse._ArgumentGroup] = {
"": parser.add_argument_group(format_group_name("", required=True)),
}

# Break some API boundaries to rename the optional group, and
parser._action_groups[1].title = format_group_name("", required=False)
parser._action_groups = parser._action_groups[::-1]

# Add each argument.
for arg in self.args:
if arg.required:
arg.add_argument(required_group)
target_groups, other_groups = (
required_group_from_prefix,
optional_group_from_prefix,
)
else:
arg.add_argument(parser)
target_groups, other_groups = (
optional_group_from_prefix,
required_group_from_prefix,
)

if arg.prefix not in target_groups:
nested_field_name = arg.prefix[:-1]
target_groups[arg.prefix] = parser.add_argument_group(
format_group_name(nested_field_name, required=arg.required),
# Add a description, but only to the first group for a field.
description=self.helptext_from_nested_dataclass_field_name[
nested_field_name
]
if arg.prefix not in other_groups
else None,
)
arg.add_argument(target_groups[arg.prefix])

# Add subparsers.
if self.subparsers is not None:
Expand Down Expand Up @@ -123,7 +167,7 @@ def from_dataclass(
parent_dataclasses = parent_dataclasses | {cls}

args = []
nested_dataclass_field_names = []
helptext_from_nested_dataclass_field_name = {}
subparsers = None
for field in _resolver.resolved_fields(cls): # type: ignore

Expand All @@ -145,7 +189,8 @@ def from_dataclass(
if subparsers_out is not None:
if subparsers is not None:
raise _instantiators.UnsupportedTypeAnnotationError(
"Only one subparser (union over dataclasses) is allowed per class."
"Only one subparser (union over dataclasses) is allowed per"
" class."
)

subparsers = subparsers_out
Expand All @@ -156,8 +201,12 @@ def from_dataclass(
if nested_out is not None:
child_args, child_nested_field_names = nested_out
args.extend(child_args)
nested_dataclass_field_names.extend(child_nested_field_names)
nested_dataclass_field_names.append(field.name)
helptext_from_nested_dataclass_field_name.update(
child_nested_field_names
)
helptext_from_nested_dataclass_field_name[
field.name
] = _docstrings.get_field_docstring(cls, field.name)
continue

# Handle simple fields!
Expand All @@ -180,7 +229,7 @@ def from_dataclass(
return ParserSpecification(
cls=cls,
args=args,
nested_dataclass_field_names=nested_dataclass_field_names,
helptext_from_nested_dataclass_field_name=helptext_from_nested_dataclass_field_name,
subparsers=subparsers,
)

Expand Down Expand Up @@ -253,7 +302,7 @@ def handle_unions_over_dataclasses(

def handle_nested_dataclasses(
self,
) -> Optional[Tuple[List[_arguments.ArgumentDefinition], List[str]]]:
) -> Optional[Tuple[List[_arguments.ArgumentDefinition], Dict[str, Optional[str]]]]:
"""Handle nested dataclasses. Returns `None` if not applicable."""
# Resolve field type
field_type = (
Expand Down Expand Up @@ -281,9 +330,9 @@ def handle_nested_dataclasses(
+ arg.prefix,
)

nested_dataclass_field_names = [
self.field.name + _strings.NESTED_DATACLASS_DELIMETER + x
for x in child_definition.nested_dataclass_field_names
]
helptext_from_nested_dataclass_field_name = {
self.field.name + _strings.NESTED_DATACLASS_DELIMETER + x: y
for x, y in child_definition.helptext_from_nested_dataclass_field_name.items()
}

return child_args, nested_dataclass_field_names
return child_args, helptext_from_nested_dataclass_field_name
17 changes: 10 additions & 7 deletions examples/example.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""An argument parsing example.
Note that there are multiple possible ways to document dataclass attributes, all
of which are supported by the automatic helptext generator.
Note that multiple possible documentation styles are supported by the field helptext
generator; we could also use docstring-style triple quote comments, or #-style comments
on the same line.
"""

import dataclasses
Expand All @@ -17,8 +18,8 @@ class OptimizerType(enum.Enum):

@dataclasses.dataclass(frozen=True)
class OptimizerConfig:
# Variant of SGD to use.
type: OptimizerType
# Gradient-based optimizer to use.
algorithm: OptimizerType = OptimizerType.ADAM

# Learning rate to use.
learning_rate: float = 3e-4
Expand All @@ -33,13 +34,15 @@ class ExperimentConfig:
pulled from this docstring by default, but can also be overrided with
`dcargs.parse`'s `description=` flag."""

experiment_name: str # Experiment name to use.
# Experiment name to use.
experiment_name: str

# Various configurable options for our optimizer.
optimizer: OptimizerConfig

# Random seed. This is helpful for making sure that our experiments are all
# reproducible!
seed: int = 0
"""Random seed. This is helpful for making sure that our experiments are
all reproducible!"""


if __name__ == "__main__":
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setup(
name="dcargs",
version="0.0.19",
version="0.0.20",
description="Portable, reusable, strongly typed CLIs from dataclass definitions",
long_description=long_description,
long_description_content_type="text/markdown",
Expand All @@ -16,7 +16,7 @@
packages=find_packages(),
package_data={"dcargs": ["py.typed"]},
python_requires=">=3.7",
install_requires=["typing_extensions>=4.0.0", "pyyaml"],
install_requires=["typing_extensions>=4.0.0", "pyyaml", "termcolor"],
extras_require={
"testing": [
"pytest",
Expand Down
2 changes: 1 addition & 1 deletion tests/test_helptext.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class Helptext:
dcargs.parse(Helptext, args=["--help"])
helptext = f.getvalue()
assert Helptext.__doc__ in helptext
assert "required arguments:\n --x INT Documentation 1\n" in helptext
assert ":\n --x INT Documentation 1\n" in helptext
assert "--y INT Documentation 2\n" in helptext
assert "--z INT Documentation 3 (default: 3)\n" in helptext

Expand Down

0 comments on commit 38f8e88

Please sign in to comment.