diff --git a/src/fixit/rules/deprecated_abc_import.py b/src/fixit/rules/deprecated_abc_import.py index 14257d5f..03a3607c 100644 --- a/src/fixit/rules/deprecated_abc_import.py +++ b/src/fixit/rules/deprecated_abc_import.py @@ -6,9 +6,10 @@ from typing import List, Optional, Union import libcst as cst - import libcst.matchers as m +from libcst.metadata import ParentNodeProvider + from fixit import Invalid, LintRule, Valid @@ -54,6 +55,7 @@ class DeprecatedABCImport(LintRule): MESSAGE = "ABCs must be imported from collections.abc" PYTHON_VERSION = ">= 3.3" + METADATA_DEPENDENCIES = (ParentNodeProvider,) VALID = [ Valid("from collections.abc import Container"), @@ -71,6 +73,47 @@ def test(self): pass """ ), + Valid( + """ + try: + from collections.abc import Mapping + except ImportError: + from collections import Mapping + """ + ), + Valid( + """ + try: + from collections.abc import Mapping, Container + except ImportError: + from collections import Mapping, Container + """ + ), + Valid( + """ + try: + from collections.abc import Mapping, Container + except ImportError: + def fallback_import(): + from collections import Mapping, Container + """ + ), + Valid( + """ + try: + from collections.abc import Mapping, Container + except Exception: + exit() + """ + ), + Valid( + """ + try: + from collections import defaultdict + except Exception: + exit() + """ + ), ] INVALID = [ Invalid( @@ -122,10 +165,27 @@ def __init__(self) -> None: # The original imports self.imports_names: List[str] = [] + def is_except_block(self, node: cst.CSTNode) -> bool: + """ + Check if the node is in an except block - if it is, we know to ignore it, as it + may be a fallback import + """ + parent = self.get_metadata(ParentNodeProvider, node) + while not isinstance(parent, cst.Module): + if isinstance(parent, cst.ExceptHandler): + return True + + parent = self.get_metadata(ParentNodeProvider, parent) + + return False + def visit_ImportFrom(self, node: cst.ImportFrom) -> None: """ This catches the `from collections import ` cases """ + if self.is_except_block(node): + return + # Get imports in this statement import_names = ( [name.name.value for name in node.names]