From 48491b4a4a36a604201608225e776e63cbf51d46 Mon Sep 17 00:00:00 2001 From: surge119 Date: Mon, 29 Jul 2024 16:57:37 -0700 Subject: [PATCH] Update DeprecatedABCImport Rule (#474) * Updated rule to catch case where import is wrapped in try block * Satisfied typechecker * Added comments * Updated to use ParentNodeProvider * Updated to walk up the node instead of checking a deterministic level * Fixed typo --- src/fixit/rules/deprecated_abc_import.py | 62 +++++++++++++++++++++++- 1 file changed, 61 insertions(+), 1 deletion(-) diff --git a/src/fixit/rules/deprecated_abc_import.py b/src/fixit/rules/deprecated_abc_import.py index 14257d5f..03a3607c 100644 --- a/src/fixit/rules/deprecated_abc_import.py +++ b/src/fixit/rules/deprecated_abc_import.py @@ -6,9 +6,10 @@ from typing import List, Optional, Union import libcst as cst - import libcst.matchers as m +from libcst.metadata import ParentNodeProvider + from fixit import Invalid, LintRule, Valid @@ -54,6 +55,7 @@ class DeprecatedABCImport(LintRule): MESSAGE = "ABCs must be imported from collections.abc" PYTHON_VERSION = ">= 3.3" + METADATA_DEPENDENCIES = (ParentNodeProvider,) VALID = [ Valid("from collections.abc import Container"), @@ -71,6 +73,47 @@ def test(self): pass """ ), + Valid( + """ + try: + from collections.abc import Mapping + except ImportError: + from collections import Mapping + """ + ), + Valid( + """ + try: + from collections.abc import Mapping, Container + except ImportError: + from collections import Mapping, Container + """ + ), + Valid( + """ + try: + from collections.abc import Mapping, Container + except ImportError: + def fallback_import(): + from collections import Mapping, Container + """ + ), + Valid( + """ + try: + from collections.abc import Mapping, Container + except Exception: + exit() + """ + ), + Valid( + """ + try: + from collections import defaultdict + except Exception: + exit() + """ + ), ] INVALID = [ Invalid( @@ -122,10 +165,27 @@ def __init__(self) -> None: # The original imports self.imports_names: List[str] = [] + def is_except_block(self, node: cst.CSTNode) -> bool: + """ + Check if the node is in an except block - if it is, we know to ignore it, as it + may be a fallback import + """ + parent = self.get_metadata(ParentNodeProvider, node) + while not isinstance(parent, cst.Module): + if isinstance(parent, cst.ExceptHandler): + return True + + parent = self.get_metadata(ParentNodeProvider, parent) + + return False + def visit_ImportFrom(self, node: cst.ImportFrom) -> None: """ This catches the `from collections import ` cases """ + if self.is_except_block(node): + return + # Get imports in this statement import_names = ( [name.name.value for name in node.names]