diff --git a/dis_cli.py b/dis_cli.py index b86bd74..c6b8ed4 100644 --- a/dis_cli.py +++ b/dis_cli.py @@ -41,7 +41,7 @@ RE_JUMP = re.compile(r"to (\d+)") INSTRUCTION_GRID_HEADERS = ["OFF", "OPERATION", "ARGS", ""] -T_INSTRUCTION_ROW = Tuple[Text, ...] +T_INSTRUCTION_ROW = Union[Tuple[Text, ...], str] T_CLASS_OR_MODULE = Union[type, ModuleType] T_FUNCTION_OR_CLASS_OR_MODULE = Union[FunctionType, T_CLASS_OR_MODULE] @@ -119,7 +119,7 @@ class Target: imported_from: Optional[ModuleType] = None @cached_property - def module(self): + def module(self) -> Optional[ModuleType]: return self.imported_from or (self.obj if self.is_module else inspect.getmodule(self.obj)) @cached_property @@ -203,7 +203,7 @@ def silent_import(module_path: str) -> ModuleType: ) -def cannot_be_disassembled(target: Target): +def cannot_be_disassembled(target: Target) -> None: msg = f"The target {target.path} = {target.obj} is a {type(target.obj).__name__}, which cannot be disassembled. Target a specific function" possible_targets = find_child_targets(target) @@ -224,7 +224,7 @@ def find_child_targets(target: Target) -> List[Target]: return list(_find_child_targets(target, top_module=target.module)) -def _find_child_targets(target: Target, top_module: ModuleType) -> Iterable[Target]: +def _find_child_targets(target: Target, top_module: Optional[ModuleType]) -> Iterable[Target]: try: children = vars(target.obj) except TypeError: # vars() argument must have __dict__ attribute @@ -282,7 +282,7 @@ def make_source_and_bytecode_display_for_function(function: FunctionType, theme: ) -def make_title(function, start_line: int) -> Text: +def make_title(function: FunctionType, start_line: int) -> Text: source_file_path = Path(inspect.getmodule(function).__file__) try: source_file_path = source_file_path.relative_to(Path.cwd()) diff --git a/tests/conftest.py b/tests/conftest.py index c939806..4bc0288 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,9 +5,10 @@ import sys import traceback from pathlib import Path -from typing import List +from typing import Callable, Iterator, List import pytest +from click import Command from click.testing import CliRunner, Result import dis_cli @@ -30,10 +31,13 @@ def runner() -> CliRunner: return CliRunner() -def invoke_with_debug(runner: CliRunner, cli, command: List[str], **kwargs) -> Result: +T_CLI = Callable[[List[str]], Result] + + +def invoke_with_debug(runner: CliRunner, cli: Command, command: List[str]) -> Result: command.append("--no-paging") print("command:", command) - result = runner.invoke(cli=cli, args=command, **kwargs) + result = runner.invoke(cli=cli, args=command) print("result:", result) @@ -49,12 +53,12 @@ def invoke_with_debug(runner: CliRunner, cli, command: List[str], **kwargs) -> R @pytest.fixture(scope="session") -def cli(runner): +def cli(runner: CliRunner) -> T_CLI: return functools.partial(invoke_with_debug, runner, dis_cli.cli) @pytest.fixture -def test_dir(tmp_path) -> Path: +def test_dir(tmp_path: Path) -> Iterator[Path]: cwd = os.getcwd() os.chdir(tmp_path) try: diff --git a/tests/test_cli.py b/tests/test_cli.py index 6118862..ff87832 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -6,8 +6,10 @@ from dis_cli import calculate_column_widths +from .conftest import T_CLI -def test_smoke(cli): + +def test_smoke(cli: T_CLI) -> None: assert cli(["dis.dis"]).exit_code == 0 @@ -34,26 +36,26 @@ def {METHOD_NAME}(self): @pytest.fixture -def source_path(test_dir, filename) -> Path: +def source_path(test_dir: Path, filename: str) -> Path: source_path = test_dir / f"{filename}.py" source_path.write_text(SOURCE) return source_path -def test_runs_successfully_on_func_source(cli, source_path): +def test_runs_successfully_on_func_source(cli: T_CLI, source_path: Path) -> None: result = cli([f"{source_path.stem}.{FUNC_NAME}"]) assert result.exit_code == 0 -def test_func_source_in_output(cli, source_path): +def test_func_source_in_output(cli: T_CLI, source_path: Path) -> None: result = cli([f"{source_path.stem}.{FUNC_NAME}"]) assert f"def {FUNC_NAME}():" in result.output -def test_handle_missing_target_gracefully(cli, source_path): +def test_handle_missing_target_gracefully(cli: T_CLI, source_path: Path) -> None: result = cli([f"{source_path.stem}.{FUNC_NAME}osidjafoa"]) assert result.exit_code == 1 @@ -64,7 +66,9 @@ def test_handle_missing_target_gracefully(cli, source_path): @pytest.mark.parametrize( "extra_source", ["print('hi')", "import sys\nprint('hi', file=sys.stderr)"] ) -def test_module_level_output_is_not_shown(cli, test_dir, filename, extra_source): +def test_module_level_output_is_not_shown( + cli: T_CLI, test_dir: Path, filename: str, extra_source: str +) -> None: source_path = test_dir / f"{filename}.py" source_path.write_text(f"{extra_source}\n{SOURCE}") print(source_path.read_text()) @@ -76,7 +80,9 @@ def test_module_level_output_is_not_shown(cli, test_dir, filename, extra_source) @pytest.mark.parametrize("extra_source", ["raise Exception", "syntax error"]) -def test_module_level_error_is_handled_gracefully(cli, test_dir, filename, extra_source): +def test_module_level_error_is_handled_gracefully( + cli: T_CLI, test_dir: Path, filename: str, extra_source: str +) -> None: source_path = test_dir / f"{filename}.py" source_path.write_text(f"{extra_source}\n{SOURCE}") print(source_path.read_text()) @@ -87,7 +93,9 @@ def test_module_level_error_is_handled_gracefully(cli, test_dir, filename, extra assert "during import" in result.output -def test_targeting_a_class_targets_all_of_its_methods(cli, test_dir, filename): +def test_targeting_a_class_targets_all_of_its_methods( + cli: T_CLI, test_dir: Path, filename: str +) -> None: source_path = test_dir / f"{filename}.py" source_path.write_text( textwrap.dedent( @@ -113,7 +121,7 @@ def method(self): @pytest.mark.skipif( sys.version_info < (3, 7), reason="Non-native dataclasses don't behave the same" ) -def test_can_dis_dataclass(cli, test_dir, filename): +def test_can_dis_dataclass(cli: T_CLI, test_dir: Path, filename: str) -> None: """ Dataclasses have generated methods with no matching source that we need a special case for. """ @@ -137,7 +145,7 @@ class Foo: assert "NO SOURCE CODE FOUND" in result.output -def test_targeting_a_module_targets_its_members(cli, test_dir, filename): +def test_targeting_a_module_targets_its_members(cli: T_CLI, test_dir: Path, filename: str) -> None: source_path = test_dir / f"{filename}.py" source_path.write_text( textwrap.dedent( @@ -168,14 +176,14 @@ def method(self): assert "wizbang" in result.output -def test_can_target_method(cli, source_path): +def test_can_target_method(cli: T_CLI, source_path: Path) -> None: result = cli([f"{source_path.stem}.{CLASS_NAME}.{METHOD_NAME}"]) assert result.exit_code == 0 assert METHOD_NAME in result.output -def test_module_not_found(cli): +def test_module_not_found(cli: T_CLI) -> None: target = "fidsjofoiasjoifdj" result = cli([target]) @@ -185,15 +193,15 @@ def test_module_not_found(cli): @pytest.mark.parametrize( - "target", + "target_path", [ f"{CONST_NAME}", f"{CLASS_NAME}.{ATTR_NAME}", f"{NONE_NAME}", ], ) -def test_cannot_be_disassembled(cli, source_path, target): - result = cli([f"{source_path.stem}.{target}"]) +def test_cannot_be_disassembled(cli: T_CLI, source_path: Path, target_path: str) -> None: + result = cli([f"{source_path.stem}.{target_path}"]) assert result.exit_code == 1 assert "cannot be disassembled" in result.output @@ -214,14 +222,14 @@ def test_cannot_be_disassembled(cli, source_path, target): (80, 4, 0.75, (52, 17)), ], ) -def test_column_width(terminal_width, line_num_width, ratio, expected): +def test_column_width(terminal_width, line_num_width, ratio, expected) -> None: assert ( calculate_column_widths(line_num_width, ratio=ratio, terminal_width=terminal_width) == expected ) -def test_no_targets_prints_help(cli): +def test_no_targets_prints_help(cli) -> None: result = cli([]) assert result.exit_code == 0