diff --git a/CHANGELOG.md b/CHANGELOG.md index c4e1c1632..8784af87a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,8 @@ and adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## Unreleased +- Add new boolean configuration `allow-local-imports` to allow for local imports + ## [2.7.0] - 2024-12-14 ### Enhancements diff --git a/docs/checkers/index.md b/docs/checkers/index.md index 9ca60662d..c33b23bc5 100644 --- a/docs/checkers/index.md +++ b/docs/checkers/index.md @@ -1193,6 +1193,13 @@ allowed-import-modules = random extra-imports = math, tkinter ``` +In addition, you can specify if you want to allow for local imports through `allow-local-imports` option: + +```python +import python_ta +python_ta.check_all(..., config={'allow-local-imports': True}) +``` + (E0401)= ### Import error (E0401) diff --git a/docs/usage/configuration.md b/docs/usage/configuration.md index 79e6af169..470382b7d 100644 --- a/docs/usage/configuration.md +++ b/docs/usage/configuration.md @@ -129,3 +129,15 @@ python_ta.check_all(..., config={'extra-imports': ["math", "tkinter"]}) [FORBIDDEN IMPORT] extra-imports = math, tkinter ``` + +In addition, you can specify `allow-local-imports` to allow local imports. + +```python +import python_ta +python_ta.check_all(..., config={'allow-local-imports': True}) +``` + +```toml +[FORBIDDEN IMPORT] +allow-local-imports = yes +``` diff --git a/examples/custom_checkers/e9999_forbidden_import.py b/examples/custom_checkers/e9999_forbidden_import.py index 72aa4ba78..07602b06d 100644 --- a/examples/custom_checkers/e9999_forbidden_import.py +++ b/examples/custom_checkers/e9999_forbidden_import.py @@ -1,5 +1,6 @@ import copy # Error on this line from sys import path # Error on this line import python_ta # No error +import e9999_forbidden_import_local # Error on this line __import__('math') # Error on this line diff --git a/examples/custom_checkers/e9999_forbidden_import_local.py b/examples/custom_checkers/e9999_forbidden_import_local.py new file mode 100644 index 000000000..e69de29bb diff --git a/python_ta/checkers/forbidden_import_checker.py b/python_ta/checkers/forbidden_import_checker.py index 493a1455a..4c96f12a7 100644 --- a/python_ta/checkers/forbidden_import_checker.py +++ b/python_ta/checkers/forbidden_import_checker.py @@ -1,5 +1,6 @@ """Checker or use of forbidden imports. """ +import os from astroid import nodes from pylint.checkers import BaseChecker @@ -38,19 +39,31 @@ class ForbiddenImportChecker(BaseChecker): "help": "Extra allowed modules to be imported.", }, ), + ( + "allow-local-imports", + { + "default": False, + "type": "yn", + "metavar": "", + "help": "Allow local modules to be imported.", + }, + ), ) @only_required_for_messages("forbidden-import") def visit_import(self, node: nodes.Import) -> None: """visit an Import node""" + local_files = self.get_allowed_local_files() + temp = [ name for name in node.names if name[0] not in self.linter.config.allowed_import_modules and name[0] not in self.linter.config.extra_imports + and name[0] not in local_files ] - if temp != []: + if temp: self.add_message( "forbidden-import", node=node, @@ -63,6 +76,7 @@ def visit_importfrom(self, node: nodes.ImportFrom) -> None: if ( node.modname not in self.linter.config.allowed_import_modules and node.modname not in self.linter.config.extra_imports + and node.modname not in self.get_allowed_local_files() ): self.add_message("forbidden-import", node=node, args=(node.modname, node.lineno)) @@ -77,10 +91,30 @@ def visit_call(self, node: nodes.Call) -> None: if ( node.args[0].value not in self.linter.config.allowed_import_modules and node.args[0].value not in self.linter.config.extra_imports + and node.args[0].value not in self.get_allowed_local_files() ): args = (node.args[0].value, node.lineno) self.add_message("forbidden-import", node=node, args=args) + def get_allowed_local_files(self) -> list: + """ + Returns the list of the local files given by self.linter.current_file + + Returns empty list if current_file is not defined + Returns empty list if local imports are not allowed + """ + if self.linter.current_file is None: + return [] + + if not self.linter.config.allow_local_imports: + return [] + + return [ + f[:-3] + for f in os.listdir(os.path.dirname(self.linter.current_file)) + if f.endswith(".py") + ] + def register(linter: PyLinter) -> None: """Required method to auto register this checker""" diff --git a/python_ta/config/.pylintrc b/python_ta/config/.pylintrc index d467a1aa9..2f8388014 100644 --- a/python_ta/config/.pylintrc +++ b/python_ta/config/.pylintrc @@ -54,6 +54,8 @@ ignore-long-lines = ^\s*((# )??)|(>>>.*)$ allowed-import-modules = dataclasses, doctest, unittest, hypothesis, pytest, python_ta, python_ta.contracts, timeit, typing, __future__ extra-imports = +allow-local-imports = no + [FORBIDDEN IO] # Comma-separated names of functions that are allowed to contain IO actions diff --git a/tests/test.pylintrc b/tests/test.pylintrc index 2f9b5ae14..1f77d33ae 100644 --- a/tests/test.pylintrc +++ b/tests/test.pylintrc @@ -45,6 +45,8 @@ ignore-long-lines = ^\s*((# )??)|(>>>.*)$ allowed-import-modules = dataclasses, doctest, unittest, hypothesis, pytest, python_ta, python_ta.contracts, timeit, typing, __future__ extra-imports = +allow-local-imports = no + [FORBIDDEN IO] # Comma-separated names of functions that are allowed to contain IO actions diff --git a/tests/test_custom_checkers/test_e9999_local_import/imported_module.py b/tests/test_custom_checkers/test_e9999_local_import/imported_module.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_custom_checkers/test_forbidden_import_checker.py b/tests/test_custom_checkers/test_forbidden_import_checker.py new file mode 100644 index 000000000..34903d4e0 --- /dev/null +++ b/tests/test_custom_checkers/test_forbidden_import_checker.py @@ -0,0 +1,120 @@ +import os + +import astroid +import pylint.testutils + +from python_ta.checkers.forbidden_import_checker import ForbiddenImportChecker + + +class TestForbiddenImportChecker(pylint.testutils.CheckerTestCase): + CHECKER_CLASS = ForbiddenImportChecker + CONFIG = {"allowed_import_modules": ["python_ta"], "extra_imports": ["datetime"]} + + def test_forbidden_import_statement(self) -> None: + """Tests for `import XX` statements""" + src = """ + import copy + """ + + mod = astroid.parse(src) + + node, *_ = mod.nodes_of_class(astroid.nodes.Import) + + with self.assertAddsMessages( + pylint.testutils.MessageTest( + msg_id="forbidden-import", node=node, line=1, args=("copy", 2) + ), + ignore_position=True, + ): + self.checker.visit_import(node) + + def test_forbidden_import_from(self) -> None: + """Tests for `from XX import XX` statements""" + src = """ + from sys import path + """ + + mod = astroid.parse(src) + + node, *_ = mod.nodes_of_class(astroid.nodes.ImportFrom) + + with self.assertAddsMessages( + pylint.testutils.MessageTest( + msg_id="forbidden-import", node=node, line=1, args=("sys", 2) + ), + ignore_position=True, + ): + self.checker.visit_importfrom(node) + + def test_allowed_import_statement(self) -> None: + """Tests for `import XX` statements""" + src = """ + import python_ta + """ + + mod = astroid.parse(src) + + node, *_ = mod.nodes_of_class(astroid.nodes.Import) + + with self.assertNoMessages(): + self.checker.visit_import(node) + + def test_extra_import_statement(self) -> None: + src = """ + import datetime + """ + + mod = astroid.parse(src) + + node, *_ = mod.nodes_of_class(astroid.nodes.Import) + + with self.assertNoMessages(): + self.checker.visit_import(node) + + def test_forbidden_dunder_import(self) -> None: + src = """ + __import__('math') + """ + mod = astroid.parse(src) + + node, *_ = mod.nodes_of_class(astroid.nodes.Call) + + with self.assertAddsMessages( + pylint.testutils.MessageTest( + msg_id="forbidden-import", node=node, line=1, args=("math", 2) + ), + ignore_position=True, + ): + self.checker.visit_call(node) + + @pylint.testutils.set_config(allow_local_imports=True) + def test_allowed_local_import(self) -> None: + src = """ + import imported_module + """ + + self.linter.current_file = os.path.abspath(__file__ + "/../test_e9999_local_import/main.py") + + mod = astroid.parse(src) + node, *_ = mod.nodes_of_class(astroid.nodes.Import) + + with self.assertNoMessages(): + self.checker.visit_import(node) + + def test_disallowed_local_import(self) -> None: + src = """ + import imported_module + """ + + self.linter.current_file = os.path.abspath(__file__ + "/../test_e9999_local_import/main.py") + + mod = astroid.parse(src) + node, *_ = mod.nodes_of_class(astroid.nodes.Import) + + with self.assertAddsMessages( + pylint.testutils.MessageTest( + msg_id="forbidden-import", node=node, line=1, args=("imported_module", 2) + ), + ignore_position=True, + ): + self.checker.visit_import(node)