From 60b286d937bd8f487c737c8cce21712ed0a47e34 Mon Sep 17 00:00:00 2001 From: Joko013 <30841710+joko013@users.noreply.github.com> Date: Sat, 14 Dec 2024 21:23:12 +0100 Subject: [PATCH] Add forward reference annotation detection --- tests/test_annotations_parser.py | 22 ++++++++++++ tests/test_imports.py | 60 ++++++++++++++++++++++++++++++++ vulture/annotation_parser.py | 36 +++++++++++++++++++ vulture/core.py | 16 +++++++++ 4 files changed, 134 insertions(+) create mode 100644 tests/test_annotations_parser.py create mode 100644 vulture/annotation_parser.py diff --git a/tests/test_annotations_parser.py b/tests/test_annotations_parser.py new file mode 100644 index 00000000..612c1bd2 --- /dev/null +++ b/tests/test_annotations_parser.py @@ -0,0 +1,22 @@ +import typing + +import pytest + +from vulture.annotation_parser import AnnotationParser + + +@pytest.mark.parametrize( + "input_annotation, expected_types", + [ + ("Foo", {"Foo"}), + ("foo.Foo", {"foo", "Foo"}), + ("List['Foo', 'Bar']", {"List", "Foo", "Bar"}), + ('List["Foo", "Bar"]', {"List", "Foo", "Bar"}), + ("List['foo.Foo', 'foo.Bar']", {"List", "foo", "Foo", "Bar"}), + ], +) +def test_different_nested_annotations( + input_annotation: str, expected_types: typing.Set[str] +): + parser = AnnotationParser(input_annotation) + assert parser.parse() == expected_types diff --git a/tests/test_imports.py b/tests/test_imports.py index af004d33..c18b3eae 100644 --- a/tests/test_imports.py +++ b/tests/test_imports.py @@ -309,3 +309,63 @@ def test_ignore_init_py_files(v): ) check(v.unused_imports, []) check(v.unused_vars, ["unused_var"]) + + +def test_plain_forward_reference_args(v): + v.scan( + """\ +from foo import Foo, Bar, Baz + +def bar(a: "Foo", /, b: "Bar", *, c: "Baz"): + ... +""" + ) + check(v.unused_imports, []) + + +def test_plain_forward_reference_return_type(v): + v.scan( + """\ +from foo import Foo + +def bar() -> "Foo": + ... +""" + ) + check(v.unused_imports, []) + + +def test_plain_forward_reference_with_module(v): + v.scan( + """\ +import foo + +def bar() -> "foo.Foo": + ... +""" + ) + check(v.unused_imports, []) + + +def test_nested_forward_reference_outer_double_quotes(v): + v.scan( + """\ +from foo import Foo + +def bar() -> "List['Foo']": + ... +""" + ) + check(v.unused_imports, []) + + +def test_nested_forward_reference_outer_single_quotes(v): + v.scan( + """\ +from foo import Foo + +def bar() -> 'List["Foo"]': + ... +""" + ) + check(v.unused_imports, []) diff --git a/vulture/annotation_parser.py b/vulture/annotation_parser.py new file mode 100644 index 00000000..7331c150 --- /dev/null +++ b/vulture/annotation_parser.py @@ -0,0 +1,36 @@ +import tokenize +from io import StringIO +from typing import List, Set + + +class AnnotationParser: + def __init__(self, annotation_string: str): + self._annotation = annotation_string + + def parse(self) -> Set[str]: + type_names = set() + token_generator = tokenize.generate_tokens( + StringIO(self._annotation).readline + ) + + for token in token_generator: + token_type = token.type + token_string = token.string + + if token_type == tokenize.NAME: + type_names.add(token_string) + elif token_type == tokenize.STRING: + for type_name in self._parse_string(token_string): + type_names.add(type_name) + + return type_names + + @staticmethod + def _parse_string(token_string: str) -> List[str]: + first_char = token_string[0] + if first_char == "'" or first_char == '"': + type_name = token_string[1:-1] + else: + type_name = token_string + + return type_name.split(".") diff --git a/vulture/core.py b/vulture/core.py index cc301b71..08b1df24 100644 --- a/vulture/core.py +++ b/vulture/core.py @@ -9,6 +9,7 @@ from typing import List from vulture import lines, noqa, utils +from vulture.annotation_parser import AnnotationParser from vulture.config import InputError, make_config from vulture.reachability import Reachability from vulture.utils import ExitCode @@ -597,6 +598,21 @@ def visit_FunctionDef(self, node): self.defined_funcs, node.name, node, ignore=_ignore_function ) + for arg in ( + node.args.args + node.args.kwonlyargs + node.args.posonlyargs + ): + self._add_constant_annotation(arg.annotation) + + if node.returns: + self._add_constant_annotation(node.returns) + + def _add_constant_annotation(self, annotation: ast.AST): + if utils.is_ast_string(annotation): + annotation: ast.Constant + annotation_parser = AnnotationParser(annotation.value) + for name in annotation_parser.parse(): + self.used_names.add(name) + def visit_Import(self, node): self._add_aliases(node)