From 6d81978e4e06a6aeff91f0ca51f0d86ab7542e11 Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Thu, 21 Apr 2022 10:00:59 -0700 Subject: [PATCH] More robust redesign for argument instantiation --- dcargs/_arguments.py | 442 +++++++++----------------------------- dcargs/_construction.py | 7 +- dcargs/_docstrings.py | 2 + dcargs/_instantiators.py | 245 +++++++++++++++++++++ dcargs/_parsers.py | 38 ++-- dcargs/_resolver.py | 2 + dcargs/_serialization.py | 2 + dcargs/_strings.py | 47 +++- examples/subparsers.py | 3 +- setup.py | 2 +- tests/test_collections.py | 52 ++++- tests/test_dcargs.py | 4 +- 12 files changed, 481 insertions(+), 365 deletions(-) create mode 100644 dcargs/_instantiators.py diff --git a/dcargs/_arguments.py b/dcargs/_arguments.py index bb478557..c51ffc69 100644 --- a/dcargs/_arguments.py +++ b/dcargs/_arguments.py @@ -1,50 +1,9 @@ import argparse -import collections.abc import dataclasses import enum -from typing import ( - Any, - Callable, - Dict, - List, - Optional, - Set, - Tuple, - Type, - TypeVar, - Union, - cast, -) - -from typing_extensions import Final, Literal, _AnnotatedAlias, get_args, get_origin - -from . import _docstrings, _strings - -T = TypeVar("T") - - -def instance_from_string(typ: Type, arg: str) -> T: - """Given a type and and a string from the command-line, reconstruct an object. Not - intended to deal with containers; these are handled in the argument - transformations. - - This is intended to replace all calls to `type(string)`, which can cause unexpected - behavior. As an example, note that the following argparse code will always print - `True`, because `bool("True") == bool("False") == bool("0") == True`. - ``` - import argparse - - parser = argparse.ArgumentParser() - parser.add_argument("--flag", type=bool) - - print(parser.parse_args().flag) - ``` - """ - assert len(get_args(typ)) == 0, f"Type {typ} cannot be instantiated." - if typ is bool: - return _strings.bool_from_string(arg) # type: ignore - else: - return typ(arg) # type: ignore +from typing import Any, Dict, Optional, Set, Tuple, Type, TypeVar, Union + +from . import _docstrings, _instantiators @dataclasses.dataclass(frozen=True) @@ -61,18 +20,7 @@ class ArgumentDefinition: # Action that is called on parsed arguments. This handles conversions from strings # to our desired types. - # - # There are 3 options: - field_action: Union[ - # Most standard fields: these are converted from strings from the CLI. - Callable[[str], Any], - # Sequence fields! This should be used whenever argparse's `nargs` field is set. - Callable[[List[str]], Any], - # Special case: the only time that argparse doesn't give us a string is when the - # argument action is set to `store_true` or `store_false`. In this case, we get - # a bool directly, and the field action can be a no-op. - Callable[[bool], bool], - ] + instantiator: Optional[_instantiators.Instantiator] # Fields that will be populated initially. # Important: from here on out, all fields correspond 1:1 to inputs to argparse's @@ -113,7 +61,7 @@ def add_argument( kwargs.pop("field") kwargs.pop("parent_class") kwargs.pop("prefix") - kwargs.pop("field_action") + kwargs.pop("instantiator") kwargs.pop("name") # Note that the name must be passed in as a position argument. @@ -135,302 +83,126 @@ def make_from_field( assert field.init, "Field must be in class constructor" - # The default field action: this converts a string from argparse to the desired - # type of the argument. - def default_field_action(x: str) -> Any: - return instance_from_string(cast(Type, arg.type), x) - - # Create initial argument. arg = ArgumentDefinition( prefix="", field=field, parent_class=parent_class, - field_action=default_field_action, + instantiator=None, name=field.name, type=field.type, default=default, ) - - # Propagate argument through transforms until stable. - prev_arg = arg - argument_transforms = _get_argument_transforms(type_from_typevar) - while True: - for transform in argument_transforms: # type: ignore - # Apply transform. - arg = transform(arg) - - # Stability check. - if arg == prev_arg: - break - prev_arg = arg - + arg = _transform_required_if_default_set(arg) + arg = _transform_handle_boolean_flags(arg) + arg = _transform_recursive_instantiator_from_type(arg, type_from_typevar) + arg = _transform_generate_helptext(arg) + arg = _transform_convert_defaults_to_strings(arg) return arg -def _get_argument_transforms( - type_from_typevar: Dict[TypeVar, Type] -) -> List[Callable[[ArgumentDefinition], ArgumentDefinition]]: - """Get a list of argument transformations.""" +def _transform_required_if_default_set(arg: ArgumentDefinition) -> ArgumentDefinition: + """Set `required=True` if a default value is set.""" - def resolve_typevars(typ: Union[Type, TypeVar]) -> Type: - return type_from_typevar.get(cast(TypeVar, typ), cast(Type, typ)) + # Don't set if default is set, or if required flag is already set. + if arg.default is not None: + return dataclasses.replace(arg, required=False) + else: + return dataclasses.replace(arg, required=True) - # All transforms should start with `transform_`. - def transform_resolve_arg_typevars(arg: ArgumentDefinition) -> ArgumentDefinition: - if arg.type is not None: - return dataclasses.replace( - arg, - type=resolve_typevars(arg.type), - ) +def _transform_handle_boolean_flags(arg: ArgumentDefinition) -> ArgumentDefinition: + """""" + if arg.type is not bool: return arg - def transform_unwrap_final(arg: ArgumentDefinition) -> ArgumentDefinition: - """Treat Final[T] as just T.""" - if get_origin(arg.type) is Final: - (typ,) = get_args(arg.type) - return dataclasses.replace( - arg, - type=typ, - ) - else: - return arg - - def transform_unwrap_annotated(arg: ArgumentDefinition) -> ArgumentDefinition: - """Treat Annotated[T, annotation] as just T.""" - if hasattr(arg.type, "__class__") and arg.type.__class__ == _AnnotatedAlias: - typ = get_origin(arg.type) - return dataclasses.replace( - arg, - type=typ, - ) - else: - return arg - - def transform_handle_optionals(arg: ArgumentDefinition) -> ArgumentDefinition: - """Transform for handling Optional[T] types. Sets default to None and marks arg as - not required.""" - if get_origin(arg.type) is Union: - options = set(get_args(arg.type)) - assert ( - len(options) == 2 and type(None) in options - ), "Union must be either over dataclasses (for subparsers) or Optional" - (typ,) = options - {type(None)} - required = False - return dataclasses.replace( - arg, - type=typ, - required=required, - ) - else: - return arg - - def transform_required(arg: ArgumentDefinition) -> ArgumentDefinition: - """Set `required=True` if a default value is set.""" + if arg.default is None: + # If no default is passed in, we treat bools as a normal parameter. + return arg + elif arg.default is False: + # Default `False` => --flag passed in flips to `True`. + return dataclasses.replace( + arg, + action="store_true", + type=None, + instantiator=lambda x: x, # argparse will directly give us a bool! + ) + elif arg.default is True: + # Default `True` => --no-flag passed in flips to `False`. + return dataclasses.replace( + arg, + dest=arg.name, + name="no_" + arg.name, + action="store_false", + type=None, + instantiator=lambda x: x, # argparse will directly give us a bool! + ) + else: + assert False, "Invalid default" - # Don't set if default is set, or if required flag is already set. - if arg.default is not None or arg.required is not None: - return arg - return dataclasses.replace(arg, required=True) - def transform_booleans(arg: ArgumentDefinition) -> ArgumentDefinition: - """Set choices or actions for booleans.""" - if arg.type != bool or arg.choices is not None: - return arg - - if arg.default is None: - # If no default is passed in, the user must explicitly choose between `True` - # and `False`. - return dataclasses.replace( - arg, - choices=(True, False), - ) - elif arg.default is False: - # Default `False` => --flag passed in flips to `True`. - return dataclasses.replace( - arg, - action="store_true", - type=None, - field_action=lambda x: x, # argparse will directly give us a bool! - ) - elif arg.default is True: - # Default `True` => --no-flag passed in flips to `False`. - return dataclasses.replace( - arg, - dest=arg.name, - name="no_" + arg.name, - action="store_false", - type=None, - field_action=lambda x: x, # argparse will directly give us a bool! - ) - else: - assert False, "Invalid default" - - def transform_nargs_from_sequences_lists_and_sets( - arg: ArgumentDefinition, - ) -> ArgumentDefinition: - """Transform for handling Sequence[T] and list types.""" - if get_origin(arg.type) in ( - collections.abc.Sequence, # different from typing.Sequence! - list, # different from typing.List! - set, # different from typing.Set! - ): - assert arg.nargs is None, "Sequence types cannot be nested." - (typ,) = map(resolve_typevars, get_args(arg.type)) - container_type = get_origin(arg.type) - if container_type is collections.abc.Sequence: - container_type = list - - return dataclasses.replace( - arg, - type=typ, - # `*` is >=0 values, `+` is >=1 values - # We're going to require at least 1 value; if a user wants to accept no - # input, they can use Optional[Tuple[...]] - nargs="+", - field_action=lambda str_list: container_type( # type: ignore - instance_from_string(typ, x) for x in str_list - ), - ) - else: - return arg - - def transform_nargs_from_tuples(arg: ArgumentDefinition) -> ArgumentDefinition: - """Transform for handling Tuple[T, T, ...] types.""" - - if arg.nargs is None and get_origin(arg.type) is tuple: - assert arg.nargs is None, "Sequence types cannot be nested." - types = tuple(map(resolve_typevars, get_args(arg.type))) - typeset = set(types) # Note that sets are unordered. - typeset_no_ellipsis = typeset - {Ellipsis} # type: ignore - - if typeset_no_ellipsis != typeset: - # Ellipsis: variable argument counts. - assert ( - len(typeset_no_ellipsis) == 1 - ), "If ellipsis is used, tuples must contain only one type." - (typ,) = typeset_no_ellipsis - - return dataclasses.replace( - arg, - # `*` is >=0 values, `+` is >=1 values. - # We're going to require at least 1 value; if a user wants to accept no - # input, they can use Optional[Tuple[...]]. - nargs="+", - type=typ, - field_action=lambda str_list: tuple( - instance_from_string(typ, x) for x in str_list - ), - ) - else: - # Tuples with more than one type. - assert arg.metavar is None - - return dataclasses.replace( - arg, - nargs=len(types), - type=str, # Types are converted in the field action. - metavar=tuple( - t.__name__.upper() if hasattr(t, "__name__") else "X" - for t in types - ), - # Field action: convert lists of strings to tuples of the correct types. - field_action=lambda str_list: tuple( - instance_from_string(typ, x) for typ, x in zip(types, str_list) - ), - ) +def _transform_recursive_instantiator_from_type( + arg: ArgumentDefinition, + type_from_typevar: Dict[TypeVar, Type], +) -> ArgumentDefinition: + """The bulkiest bit: recursively analyze the type annotation and use it to determine how""" + if arg.instantiator is not None: + return arg + instantiator, metadata = _instantiators.instantiator_from_type( + arg.type, # type: ignore + type_from_typevar, + ) + return dataclasses.replace( + arg, + instantiator=instantiator, + choices=metadata.choices, + nargs=metadata.nargs, + required=(not metadata.is_optional) and arg.required, + # Ignore metavar if choices is set. + metavar=metadata.metavar if metadata.choices is None else None, + ) + + +def _transform_generate_helptext(arg: ArgumentDefinition) -> ArgumentDefinition: + """Generate helptext from docstring and argument name.""" + help_parts = [] + docstring_help = _docstrings.get_field_docstring(arg.parent_class, arg.field.name) + if docstring_help is not None: + # Note that the percent symbol needs some extra handling in argparse. + # https://stackoverflow.com/questions/21168120/python-argparse-errors-with-in-help-string + docstring_help = docstring_help.replace("%", "%%") + help_parts.append(docstring_help) + + if arg.action is not None: + # Don't show defaults for boolean flags. + assert arg.action in ("store_true", "store_false") + elif arg.default is not None and isinstance(arg.default, enum.Enum): + # Special case for enums. + help_parts.append(f"(default: {arg.default.name})") + elif not arg.required: + # General case. We intentionally don't use the % template, because the default + # will be casted to a string and that can make unnecessary quotation marks + # appear in the helptext. + help_parts.append(f"(default: {arg.default})") + + return dataclasses.replace(arg, help=" ".join(help_parts)) + + +def _transform_convert_defaults_to_strings( + arg: ArgumentDefinition, +) -> ArgumentDefinition: + """Sets all default values to strings, as required as input to our instantiator + functions. Special-cased for enums.""" + + def as_str(x: Any) -> str: + if isinstance(x, enum.Enum): + return x.name else: - return arg - - def transform_choices_from_literals(arg: ArgumentDefinition) -> ArgumentDefinition: - """For literal types, set choices.""" - if get_origin(arg.type) is Literal: - choices = get_args(arg.type) - typ = type(next(iter(choices))) - - assert typ not in ( - list, - tuple, - set, - ), "Containers not supported in literals." - assert all( - map(lambda c: type(c) == typ, choices) - ), "All choices in literal must have the same type!" - - return dataclasses.replace( - arg, - type=typ, - choices=choices, - ) - else: - return arg - - def transform_enums_as_strings(arg: ArgumentDefinition) -> ArgumentDefinition: - """For enums, use string representations.""" - if isinstance(arg.type, type) and issubclass(arg.type, enum.Enum): - if arg.choices is None: - # We use a list and not a set to preserve ordering. - choices = list(x.name for x in arg.type) - else: - # `arg.choices` is set; this occurs when we have enums in a literal - # type. - choices = list(x.name for x in arg.choices) - assert len(choices) == len(set(choices)) - - return dataclasses.replace( - arg, - choices=choices, - type=str, - default=None if arg.default is None else arg.default.name, - field_action=lambda enum_name: arg.type[enum_name], # type: ignore - ) - else: - return arg - - def transform_generate_helptext(arg: ArgumentDefinition) -> ArgumentDefinition: - """Generate helptext from docstring and argument name.""" - if arg.help is None: - help_parts = [] - docstring_help = _docstrings.get_field_docstring( - arg.parent_class, arg.field.name - ) - if docstring_help is not None: - # Note that the percent symbol needs some extra handling in argparse. - # https://stackoverflow.com/questions/21168120/python-argparse-errors-with-in-help-string - docstring_help = docstring_help.replace("%", "%%") - help_parts.append(docstring_help) - - if arg.action is not None: - # Don't show defaults for boolean flags. - assert arg.action in ("store_true", "store_false") - elif arg.default is not None and isinstance(arg.default, enum.Enum): - # Special case for enums. - help_parts.append(f"(default: {arg.default.name})") - elif not arg.required: - # General case. - help_parts.append("(default: %(default)s)") - - return dataclasses.replace(arg, help=" ".join(help_parts)) - else: - return arg - - def transform_use_type_as_metavar(arg: ArgumentDefinition) -> ArgumentDefinition: - """Communicate the argument type using the metavar.""" - if ( - hasattr(arg.type, "__name__") - # Don't generate metavar if target is still wrapping something, eg - # Optional[int] will have 1 arg. - and len(get_args(arg.type)) == 0 - # If choices is set, they'll be used by default. - and arg.choices is None - # Don't generate metavar if one already exists. - and arg.metavar is None - ): - return dataclasses.replace( - arg, metavar=arg.type.__name__.upper() # type: ignore - ) # type: ignore - else: - return arg + return str(x) - return [v for k, v in locals().items() if k.startswith("transform_")] + if arg.default is None or arg.action is not None: + return arg + elif arg.nargs is not None: + return dataclasses.replace(arg, default=tuple(map(as_str, arg.default))) + else: + return dataclasses.replace(arg, default=as_str(arg.default)) diff --git a/dcargs/_construction.py b/dcargs/_construction.py index 6f6278b8..4b01a6da 100644 --- a/dcargs/_construction.py +++ b/dcargs/_construction.py @@ -1,3 +1,5 @@ +"""Core functionality for instantiating dataclasses from argparse namespaces.""" + from typing import TYPE_CHECKING, Any, Dict, Set, Tuple, Type, TypeVar from typing_extensions import get_args @@ -54,7 +56,7 @@ def get_value_from_arg(arg: str) -> Any: value: Any prefixed_field_name = field_name_prefix + field.name - # Resolve field type + # Resolve field type. field_type = ( type_from_typevar[field.type] # type: ignore if field.type in type_from_typevar @@ -67,7 +69,8 @@ def get_value_from_arg(arg: str) -> Any: value = get_value_from_arg(prefixed_field_name) if value is not None: try: - value = arg.field_action(value) + assert arg.instantiator is not None + value = arg.instantiator(value) except ValueError as e: raise FieldActionValueError( f"Parsing error for {arg.get_flag()}: {e.args[0]}" diff --git a/dcargs/_docstrings.py b/dcargs/_docstrings.py index c33b81ea..8606cd5f 100644 --- a/dcargs/_docstrings.py +++ b/dcargs/_docstrings.py @@ -1,3 +1,5 @@ +"""Helpers for parsing dataclass docstrings. Used for helptext generation.""" + import dataclasses import functools import inspect diff --git a/dcargs/_instantiators.py b/dcargs/_instantiators.py new file mode 100644 index 00000000..532676cc --- /dev/null +++ b/dcargs/_instantiators.py @@ -0,0 +1,245 @@ +"""Helper for using type annotations to recursively generate instantiator functions, +which map strings (or, in some cases, sequences of strings) to the annotated type. + +Some examples of type annotations and the desired instantiators: +``` + int + + lambda string: int(str) + + List[int] + + lambda strings: list( + [int(x) for x in strings] + ) + + List[Color], where Color is an enum + + lambda strings: list( + [Color[x] for x in strings] + ) + + Tuple[int, float] + + lambda strings: tuple( + [ + typ(x) + for typ, x in zip( + (int, float), + strings, + ) + ] + ) +``` +""" + +import collections +import dataclasses +import enum +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union + +from typing_extensions import Final, Literal, _AnnotatedAlias, get_args, get_origin + +from . import _strings + +Instantiator = Union[ + # Most standard fields: these are converted from strings from the CLI. + Callable[[str], Any], + # Sequence fields! This should be used whenever argparse's `nargs` field is set. + Callable[[List[str]], Any], + # Special case: the only time that argparse doesn't give us a string is when the + # argument action is set to `store_true` or `store_false`. In this case, we get + # a bool directly, and the field action can be a no-op. + Callable[[bool], bool], +] + + +@dataclasses.dataclass +class InstantiatorMetadata: + nargs: Optional[Union[str, int]] + metavar: Union[str, Tuple[str, ...]] + choices: Optional[Tuple[Any, ...]] + is_optional: bool + + +class UnsupportedTypeAnnotationError(Exception): + """Exception raised when field actions fail; this typically means that values from + the CLI are invalid.""" + + +def instantiator_from_type( + typ: Type, type_from_typevar: Dict[TypeVar, Type] +) -> Tuple[Instantiator, InstantiatorMetadata]: + """Recursive helper for parsing type annotations. + + Returns two things: + - An instantiator function, for instantiating the type from a string or list of + strings. The latter applies when argparse's `nargs` parameter is set. + - A metadata structure, which specifies parameters for argparse. + """ + + # Resolve typevars. + if typ in type_from_typevar: + return instantiator_from_type( + type_from_typevar[typ], # type: ignore + type_from_typevar, + ) + + # Address container types. If a matching container is found, this will recursively + # call instantiator_from_type(). + container_out = _instantiator_from_container_type(typ, type_from_typevar) + if container_out is not None: + return container_out + + # Construct instantiators for raw types. + auto_choices: Optional[Tuple[str, ...]] = None + if typ is bool: + auto_choices = ("True", "False") + elif issubclass(typ, enum.Enum): + auto_choices = tuple(x.name for x in typ) + return lambda arg: _strings.instance_from_string(typ, arg), InstantiatorMetadata( + nargs=None, + metavar=typ.__name__.upper(), + choices=auto_choices, + is_optional=False, + ) + + +def _instantiator_from_container_type( + typ: Type, type_from_typevar: Dict[TypeVar, Type] +) -> Optional[Tuple[Instantiator, InstantiatorMetadata]]: + """Attempt to create an instantiator from a container type. Returns `None` is no + container type is found.""" + + type_origin = get_origin(typ) + if type_origin is None: + return None + + # Unwrap Final types. + if type_origin is Final: + (contained_type,) = get_args(typ) + return instantiator_from_type(contained_type, type_from_typevar) + + # Unwrap Annotated types. + if hasattr(typ, "__class__") and typ.__class__ == _AnnotatedAlias: + return instantiator_from_type(type_origin, type_from_typevar) + + # List, tuples, and sequences. + if type_origin in ( + collections.abc.Sequence, # different from typing.Sequence! + list, # different from typing.List! + set, # different from typing.Set! + ): + (contained_type,) = get_args(typ) + container_type = type_origin + if container_type is collections.abc.Sequence: + container_type = list + + make, inner_meta = _instantiator_from_type_inner( + contained_type, type_from_typevar + ) + return lambda strings: container_type( + [make(x) for x in strings] + ), InstantiatorMetadata( + nargs="+", + metavar=inner_meta.metavar, + choices=inner_meta.choices, + is_optional=False, + ) + + # Tuples. + if type_origin is tuple: + types = get_args(typ) + typeset = set(types) # Note that sets are unordered. + typeset_no_ellipsis = typeset - {Ellipsis} # type: ignore + + if typeset_no_ellipsis != typeset: + # Ellipsis: variable argument counts. + if len(typeset_no_ellipsis) > 1: + raise UnsupportedTypeAnnotationError( + "When an ellipsis is used, tuples must contain only one type." + ) + (contained_type,) = typeset_no_ellipsis + + make, inner_meta = _instantiator_from_type_inner( + contained_type, type_from_typevar + ) + return lambda strings: tuple( + [make(x) for x in strings] + ), InstantiatorMetadata( + nargs="+", + metavar=inner_meta.metavar, + choices=inner_meta.choices, + is_optional=False, + ) + + else: + instantiators, metas = zip( + *map( + lambda t: _instantiator_from_type_inner(t, type_from_typevar), + types, + ) + ) + if len(set(m.choices for m in metas)) > 1: + raise UnsupportedTypeAnnotationError( + "Due to constraints in argparse, all choices in fixed-length tuples" + " must match. This restricts mixing enums & literals with other" + " types." + ) + return lambda strings: tuple( + make(x) for make, x in zip(instantiators, strings) + ), InstantiatorMetadata( + nargs=len(types), + metavar=tuple(m.metavar for m in metas), + choices=metas[0].choices, + is_optional=False, + ) + + # Optionals. + if type_origin is Union: + options = set(get_args(typ)) + assert ( + len(options) == 2 and type(None) in options + ), "Union must be either over dataclasses (for subparsers) or Optional" + (typ,) = options - {type(None)} + instantiator, metadata = _instantiator_from_type_inner( + typ, type_from_typevar, allow_sequences=True + ) + return instantiator, dataclasses.replace(metadata, is_optional=True) + + # Literals. + if type_origin is Literal: + choices = get_args(typ) + contained_type = type(next(iter(choices))) + assert all( + map(lambda c: type(c) == contained_type, choices) + ), "All choices in literal must have the same type!" + if issubclass(contained_type, enum.Enum): + choices = tuple(map(lambda x: x.name, choices)) + instantiator, metadata = _instantiator_from_type_inner( + contained_type, type_from_typevar + ) + assert ( + # Choices provided by the contained type + metadata.choices is None + or len(set(choices) - set(metadata.choices)) == 0 + ) + return instantiator, dataclasses.replace(metadata, choices=choices) + + return None + + +def _instantiator_from_type_inner( + typ: Type, + type_from_typevar: Dict[TypeVar, Type], + allow_sequences: bool = False, + allow_optional: bool = False, +) -> Tuple[Instantiator, InstantiatorMetadata]: + """Thin wrapper over instantiator_from_type, with some extra asserts for catching + errors.""" + out = instantiator_from_type(typ, type_from_typevar) + if not allow_sequences and out[1].nargs is not None: + raise UnsupportedTypeAnnotationError("Nested sequence types are not supported!") + if not allow_optional and out[1].is_optional: + raise UnsupportedTypeAnnotationError("Nested optional types are not supported!") + return out diff --git a/dcargs/_parsers.py b/dcargs/_parsers.py index e1cf88a5..af3e4382 100644 --- a/dcargs/_parsers.py +++ b/dcargs/_parsers.py @@ -1,3 +1,5 @@ +"""Abstractions for creating argparse parsers from a dataclass definition.""" + import argparse import dataclasses import warnings @@ -5,7 +7,7 @@ from typing_extensions import get_args, get_origin -from . import _arguments, _docstrings, _resolver, _strings +from . import _arguments, _docstrings, _instantiators, _resolver, _strings T = TypeVar("T") @@ -114,9 +116,10 @@ def from_dataclass( if typ in parent_type_from_typevar: type_from_typevar[typevar] = parent_type_from_typevar[typ] # type: ignore - assert ( - cls not in parent_dataclasses - ), f"Found a cyclic dataclass dependency with type {cls}" + if cls in parent_dataclasses: + raise _instantiators.UnsupportedTypeAnnotationError( + f"Found a cyclic dataclass dependency with type {cls}." + ) parent_dataclasses = parent_dataclasses | {cls} args = [] @@ -140,9 +143,10 @@ def from_dataclass( # Try to create subparsers from this field. subparsers_out = nested_handler.handle_unions_over_dataclasses() if subparsers_out is not None: - assert ( - subparsers is None - ), "Only one subparser (union over dataclasses) is allowed per class" + if subparsers is not None: + raise _instantiators.UnsupportedTypeAnnotationError( + "Only one subparser (union over dataclasses) is allowed per class." + ) subparsers = subparsers_out continue @@ -157,12 +161,20 @@ def from_dataclass( continue # Handle simple fields! - arg = _arguments.ArgumentDefinition.make_from_field( - cls, - field, - type_from_typevar, - default=field_default_instance, - ) + try: + arg = _arguments.ArgumentDefinition.make_from_field( + cls, + field, + type_from_typevar, + default=field_default_instance, + ) + except _instantiators.UnsupportedTypeAnnotationError as e: + # Catch unsupported annotation errors, and make the error message more + # informative. + raise _instantiators.UnsupportedTypeAnnotationError( + f"Error when parsing {cls.__name__}.{field.name} of type" + f" {field.type}: {e.args[0]}" + ) args.append(arg) return ParserSpecification( diff --git a/dcargs/_resolver.py b/dcargs/_resolver.py index 5adab5af..2d8d990d 100644 --- a/dcargs/_resolver.py +++ b/dcargs/_resolver.py @@ -1,3 +1,5 @@ +"""Utilities for resolving generic types and forward references.""" + import copy import dataclasses import functools diff --git a/dcargs/_serialization.py b/dcargs/_serialization.py index 20b5d313..924430ba 100644 --- a/dcargs/_serialization.py +++ b/dcargs/_serialization.py @@ -1,3 +1,5 @@ +"""Type-safe, human-readable serialization helpers.""" + import dataclasses import datetime import enum diff --git a/dcargs/_strings.py b/dcargs/_strings.py index e7223e14..7ee84879 100644 --- a/dcargs/_strings.py +++ b/dcargs/_strings.py @@ -1,7 +1,12 @@ +"""Utilities for working with strings.""" + +import enum import functools import re import textwrap -from typing import Type +from typing import Type, TypeVar + +from typing_extensions import get_args from . import _resolver @@ -40,10 +45,38 @@ def subparser_name_from_type(cls: Type) -> str: ) -def bool_from_string(text: str) -> bool: - if text == "True": - return True - elif text == "False": - return False +T = TypeVar("T") + + +def instance_from_string(typ: Type[T], arg: str) -> T: + """Given a type and and a string from the command-line, reconstruct an object. Not + intended to deal with containers. + + This is intended to replace all calls to `type(string)`, which can cause unexpected + behavior. As an example, note that the following argparse code will always print + `True`, because `bool("True") == bool("False") == bool("0") == True`. + ``` + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--flag", type=bool) + + print(parser.parse_args().flag) + ``` + """ + assert len(get_args(typ)) == 0, f"Type {typ} cannot be instantiated." + if typ is bool: + if arg == "True": + return True # type: ignore + elif arg == "False": + return False # type: ignore + else: + raise ValueError(f"Boolean (True/False) expected, but got {arg}.") + elif issubclass(typ, enum.Enum): + try: + return typ[arg] # type: ignore + except KeyError as e: + # Raise enum key errors as value errors. + raise ValueError(*e.args) else: - raise ValueError(f"Boolean (True/False) expected, but got {text}.") + return typ(arg) # type: ignore diff --git a/examples/subparsers.py b/examples/subparsers.py index f2d8dc6f..0e95f5d2 100644 --- a/examples/subparsers.py +++ b/examples/subparsers.py @@ -33,6 +33,5 @@ class Commit: if __name__ == "__main__": - # args = dcargs.parse(Args) - args = dcargs.parse(Args, default_instance=Args(command=Checkout(branch="main"))) + args = dcargs.parse(Args) print(args) diff --git a/setup.py b/setup.py index 11c96159..06f41785 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setup( name="dcargs", - version="0.0.18", + version="0.0.19", description="Portable, reusable, strongly typed CLIs from dataclass definitions", long_description=long_description, long_description_content_type="text/markdown", diff --git a/tests/test_collections.py b/tests/test_collections.py index bbcee0ff..55556af9 100644 --- a/tests/test_collections.py +++ b/tests/test_collections.py @@ -1,7 +1,9 @@ import dataclasses +import enum from typing import List, Optional, Sequence, Set, Tuple import pytest +from typing_extensions import Literal import dcargs @@ -116,13 +118,57 @@ class A: dcargs.parse(A, args=[]) -def test_lists_with_default(): +def test_list_with_literal(): @dataclasses.dataclass class A: - x: List[int] = dataclasses.field(default_factory=[0, 1, 2].copy) + x: List[Literal[1, 2, 3]] - assert dcargs.parse(A, args=[]) == A(x=[0, 1, 2]) assert dcargs.parse(A, args=["--x", "1", "2", "3"]) == A(x=[1, 2, 3]) + with pytest.raises(SystemExit): + dcargs.parse(A, args=["--x", "1", "2", "3", "4"]) + with pytest.raises(SystemExit): + dcargs.parse(A, args=["--x"]) + with pytest.raises(SystemExit): + dcargs.parse(A, args=[]) + + +def test_list_with_enums(): + class Color(enum.Enum): + RED = enum.auto() + GREEN = enum.auto() + BLUE = enum.auto() + + @dataclasses.dataclass + class A: + x: List[Color] + + assert dcargs.parse(A, args=["--x", "RED", "RED", "BLUE"]) == A( + x=[Color.RED, Color.RED, Color.BLUE] + ) + with pytest.raises(SystemExit): + dcargs.parse(A, args=["--x", "RED", "RED", "YELLOW"]) + with pytest.raises(SystemExit): + dcargs.parse(A, args=["--x"]) + with pytest.raises(SystemExit): + dcargs.parse(A, args=[]) + + +def test_lists_with_default(): + class Color(enum.Enum): + RED = enum.auto() + GREEN = enum.auto() + BLUE = enum.auto() + + @dataclasses.dataclass + class A: + x: List[Color] = dataclasses.field( + default_factory=[Color.RED, Color.GREEN].copy + ) + + assert dcargs.parse(A, args=[]) == A(x=[Color.RED, Color.GREEN]) + assert dcargs.parse(A, args=["--x", "RED", "GREEN", "BLUE"]) == A( + x=[Color.RED, Color.GREEN, Color.BLUE] + ) def test_lists_bool(): diff --git a/tests/test_dcargs.py b/tests/test_dcargs.py index 723f6693..b24a8cd8 100644 --- a/tests/test_dcargs.py +++ b/tests/test_dcargs.py @@ -142,7 +142,7 @@ class EnumClassB: assert dcargs.parse(EnumClassA, args=["--color", "RED"]) == EnumClassA( color=Color.RED ) - assert dcargs.parse(EnumClassB) == EnumClassB() + assert dcargs.parse(EnumClassB, args=[]) == EnumClassB() def test_literal(): @@ -277,4 +277,4 @@ def test_parse_empty_description(): class A: x: int = 0 - assert dcargs.parse(A, description=None) == A(x=0) + assert dcargs.parse(A, description=None, args=[]) == A(x=0)