Skip to content

Commit

Permalink
Add forward reference annotation detection
Browse files Browse the repository at this point in the history
  • Loading branch information
Joko013 committed Dec 15, 2024
1 parent 4ecc149 commit 60b286d
Show file tree
Hide file tree
Showing 4 changed files with 134 additions and 0 deletions.
22 changes: 22 additions & 0 deletions tests/test_annotations_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import typing

import pytest

from vulture.annotation_parser import AnnotationParser


@pytest.mark.parametrize(
"input_annotation, expected_types",
[
("Foo", {"Foo"}),
("foo.Foo", {"foo", "Foo"}),
("List['Foo', 'Bar']", {"List", "Foo", "Bar"}),
('List["Foo", "Bar"]', {"List", "Foo", "Bar"}),
("List['foo.Foo', 'foo.Bar']", {"List", "foo", "Foo", "Bar"}),
],
)
def test_different_nested_annotations(
input_annotation: str, expected_types: typing.Set[str]
):
parser = AnnotationParser(input_annotation)
assert parser.parse() == expected_types
60 changes: 60 additions & 0 deletions tests/test_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,3 +309,63 @@ def test_ignore_init_py_files(v):
)
check(v.unused_imports, [])
check(v.unused_vars, ["unused_var"])


def test_plain_forward_reference_args(v):
v.scan(
"""\
from foo import Foo, Bar, Baz
def bar(a: "Foo", /, b: "Bar", *, c: "Baz"):
...
"""
)
check(v.unused_imports, [])


def test_plain_forward_reference_return_type(v):
v.scan(
"""\
from foo import Foo
def bar() -> "Foo":
...
"""
)
check(v.unused_imports, [])


def test_plain_forward_reference_with_module(v):
v.scan(
"""\
import foo
def bar() -> "foo.Foo":
...
"""
)
check(v.unused_imports, [])


def test_nested_forward_reference_outer_double_quotes(v):
v.scan(
"""\
from foo import Foo
def bar() -> "List['Foo']":
...
"""
)
check(v.unused_imports, [])


def test_nested_forward_reference_outer_single_quotes(v):
v.scan(
"""\
from foo import Foo
def bar() -> 'List["Foo"]':
...
"""
)
check(v.unused_imports, [])
36 changes: 36 additions & 0 deletions vulture/annotation_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import tokenize
from io import StringIO
from typing import List, Set


class AnnotationParser:
def __init__(self, annotation_string: str):
self._annotation = annotation_string

def parse(self) -> Set[str]:
type_names = set()
token_generator = tokenize.generate_tokens(
StringIO(self._annotation).readline
)

for token in token_generator:
token_type = token.type
token_string = token.string

if token_type == tokenize.NAME:
type_names.add(token_string)
elif token_type == tokenize.STRING:
for type_name in self._parse_string(token_string):
type_names.add(type_name)

return type_names

@staticmethod
def _parse_string(token_string: str) -> List[str]:
first_char = token_string[0]
if first_char == "'" or first_char == '"':
type_name = token_string[1:-1]
else:
type_name = token_string

return type_name.split(".")
16 changes: 16 additions & 0 deletions vulture/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import List

from vulture import lines, noqa, utils
from vulture.annotation_parser import AnnotationParser
from vulture.config import InputError, make_config
from vulture.reachability import Reachability
from vulture.utils import ExitCode
Expand Down Expand Up @@ -597,6 +598,21 @@ def visit_FunctionDef(self, node):
self.defined_funcs, node.name, node, ignore=_ignore_function
)

for arg in (
node.args.args + node.args.kwonlyargs + node.args.posonlyargs
):
self._add_constant_annotation(arg.annotation)

if node.returns:
self._add_constant_annotation(node.returns)

def _add_constant_annotation(self, annotation: ast.AST):
if utils.is_ast_string(annotation):
annotation: ast.Constant
annotation_parser = AnnotationParser(annotation.value)
for name in annotation_parser.parse():
self.used_names.add(name)

def visit_Import(self, node):
self._add_aliases(node)

Expand Down

0 comments on commit 60b286d

Please sign in to comment.