diff --git a/argparse_dataclass.py b/argparse_dataclass.py index b93ac8e..13df728 100644 --- a/argparse_dataclass.py +++ b/argparse_dataclass.py @@ -316,13 +316,27 @@ def parse_known_args( return options_class(**kwargs), others +def _fields(options_class: typing.Type[OptionsType]) -> typing.Tuple[Field, ...]: + """Get tuple of Field for dataclass.""" + type_hints = typing.get_type_hints(options_class) + + def _ensure_type(_f): + # When importing __future__.annotations, `Field.type` becomes `str` + # Ref: https://github.com/mivade/argparse_dataclass/issues/47 + if isinstance(_f.type, str): + _f.type = type_hints[_f.name] + return _f + + return tuple(_ensure_type(_f) for _f in fields(options_class)) + + def _add_dataclass_options( options_class: typing.Type[OptionsType], parser: argparse.ArgumentParser ) -> None: if not is_dataclass(options_class): raise TypeError("cls must be a dataclass") - for field in fields(options_class): + for field in _fields(options_class): args = field.metadata.get("args", [f"--{field.name.replace('_', '-')}"]) positional = not args[0].startswith("-") kwargs = { diff --git a/tests/test_annotations.py b/tests/test_annotations.py new file mode 100644 index 0000000..c514174 --- /dev/null +++ b/tests/test_annotations.py @@ -0,0 +1,23 @@ +from __future__ import annotations +import unittest +from argparse_dataclass import dataclass + + +@dataclass +class Opt: + x: int = 42 + y: bool = False + + +class ArgParseTests(unittest.TestCase): + def test_basic(self): + params = Opt.parse_args([]) + self.assertEqual(42, params.x) + self.assertEqual(False, params.y) + params = Opt.parse_args(["--x=10", "--y"]) + self.assertEqual(10, params.x) + self.assertEqual(True, params.y) + + +if __name__ == "__main__": + unittest.main()