Skip to content
This repository has been archived by the owner on Oct 7, 2023. It is now read-only.

Commit

Permalink
Type hint improvements (#18)
Browse files Browse the repository at this point in the history
* types for tests

* a few more small improvements
  • Loading branch information
JoshKarpel authored Nov 27, 2020
1 parent b761ef0 commit 0e1397e
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 27 deletions.
10 changes: 5 additions & 5 deletions dis_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
RE_JUMP = re.compile(r"to (\d+)")

INSTRUCTION_GRID_HEADERS = ["OFF", "OPERATION", "ARGS", ""]
T_INSTRUCTION_ROW = Tuple[Text, ...]
T_INSTRUCTION_ROW = Union[Tuple[Text, ...], str]

T_CLASS_OR_MODULE = Union[type, ModuleType]
T_FUNCTION_OR_CLASS_OR_MODULE = Union[FunctionType, T_CLASS_OR_MODULE]
Expand Down Expand Up @@ -119,7 +119,7 @@ class Target:
imported_from: Optional[ModuleType] = None

@cached_property
def module(self):
def module(self) -> Optional[ModuleType]:
return self.imported_from or (self.obj if self.is_module else inspect.getmodule(self.obj))

@cached_property
Expand Down Expand Up @@ -203,7 +203,7 @@ def silent_import(module_path: str) -> ModuleType:
)


def cannot_be_disassembled(target: Target):
def cannot_be_disassembled(target: Target) -> None:
msg = f"The target {target.path} = {target.obj} is a {type(target.obj).__name__}, which cannot be disassembled. Target a specific function"

possible_targets = find_child_targets(target)
Expand All @@ -224,7 +224,7 @@ def find_child_targets(target: Target) -> List[Target]:
return list(_find_child_targets(target, top_module=target.module))


def _find_child_targets(target: Target, top_module: ModuleType) -> Iterable[Target]:
def _find_child_targets(target: Target, top_module: Optional[ModuleType]) -> Iterable[Target]:
try:
children = vars(target.obj)
except TypeError: # vars() argument must have __dict__ attribute
Expand Down Expand Up @@ -282,7 +282,7 @@ def make_source_and_bytecode_display_for_function(function: FunctionType, theme:
)


def make_title(function, start_line: int) -> Text:
def make_title(function: FunctionType, start_line: int) -> Text:
source_file_path = Path(inspect.getmodule(function).__file__)
try:
source_file_path = source_file_path.relative_to(Path.cwd())
Expand Down
14 changes: 9 additions & 5 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
import sys
import traceback
from pathlib import Path
from typing import List
from typing import Callable, Iterator, List

import pytest
from click import Command
from click.testing import CliRunner, Result

import dis_cli
Expand All @@ -30,10 +31,13 @@ def runner() -> CliRunner:
return CliRunner()


def invoke_with_debug(runner: CliRunner, cli, command: List[str], **kwargs) -> Result:
T_CLI = Callable[[List[str]], Result]


def invoke_with_debug(runner: CliRunner, cli: Command, command: List[str]) -> Result:
command.append("--no-paging")
print("command:", command)
result = runner.invoke(cli=cli, args=command, **kwargs)
result = runner.invoke(cli=cli, args=command)

print("result:", result)

Expand All @@ -49,12 +53,12 @@ def invoke_with_debug(runner: CliRunner, cli, command: List[str], **kwargs) -> R


@pytest.fixture(scope="session")
def cli(runner):
def cli(runner: CliRunner) -> T_CLI:
return functools.partial(invoke_with_debug, runner, dis_cli.cli)


@pytest.fixture
def test_dir(tmp_path) -> Path:
def test_dir(tmp_path: Path) -> Iterator[Path]:
cwd = os.getcwd()
os.chdir(tmp_path)
try:
Expand Down
42 changes: 25 additions & 17 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@

from dis_cli import calculate_column_widths

from .conftest import T_CLI

def test_smoke(cli):

def test_smoke(cli: T_CLI) -> None:
assert cli(["dis.dis"]).exit_code == 0


Expand All @@ -34,26 +36,26 @@ def {METHOD_NAME}(self):


@pytest.fixture
def source_path(test_dir, filename) -> Path:
def source_path(test_dir: Path, filename: str) -> Path:
source_path = test_dir / f"{filename}.py"
source_path.write_text(SOURCE)

return source_path


def test_runs_successfully_on_func_source(cli, source_path):
def test_runs_successfully_on_func_source(cli: T_CLI, source_path: Path) -> None:
result = cli([f"{source_path.stem}.{FUNC_NAME}"])

assert result.exit_code == 0


def test_func_source_in_output(cli, source_path):
def test_func_source_in_output(cli: T_CLI, source_path: Path) -> None:
result = cli([f"{source_path.stem}.{FUNC_NAME}"])

assert f"def {FUNC_NAME}():" in result.output


def test_handle_missing_target_gracefully(cli, source_path):
def test_handle_missing_target_gracefully(cli: T_CLI, source_path: Path) -> None:
result = cli([f"{source_path.stem}.{FUNC_NAME}osidjafoa"])

assert result.exit_code == 1
Expand All @@ -64,7 +66,9 @@ def test_handle_missing_target_gracefully(cli, source_path):
@pytest.mark.parametrize(
"extra_source", ["print('hi')", "import sys\nprint('hi', file=sys.stderr)"]
)
def test_module_level_output_is_not_shown(cli, test_dir, filename, extra_source):
def test_module_level_output_is_not_shown(
cli: T_CLI, test_dir: Path, filename: str, extra_source: str
) -> None:
source_path = test_dir / f"{filename}.py"
source_path.write_text(f"{extra_source}\n{SOURCE}")
print(source_path.read_text())
Expand All @@ -76,7 +80,9 @@ def test_module_level_output_is_not_shown(cli, test_dir, filename, extra_source)


@pytest.mark.parametrize("extra_source", ["raise Exception", "syntax error"])
def test_module_level_error_is_handled_gracefully(cli, test_dir, filename, extra_source):
def test_module_level_error_is_handled_gracefully(
cli: T_CLI, test_dir: Path, filename: str, extra_source: str
) -> None:
source_path = test_dir / f"{filename}.py"
source_path.write_text(f"{extra_source}\n{SOURCE}")
print(source_path.read_text())
Expand All @@ -87,7 +93,9 @@ def test_module_level_error_is_handled_gracefully(cli, test_dir, filename, extra
assert "during import" in result.output


def test_targeting_a_class_targets_all_of_its_methods(cli, test_dir, filename):
def test_targeting_a_class_targets_all_of_its_methods(
cli: T_CLI, test_dir: Path, filename: str
) -> None:
source_path = test_dir / f"{filename}.py"
source_path.write_text(
textwrap.dedent(
Expand All @@ -113,7 +121,7 @@ def method(self):
@pytest.mark.skipif(
sys.version_info < (3, 7), reason="Non-native dataclasses don't behave the same"
)
def test_can_dis_dataclass(cli, test_dir, filename):
def test_can_dis_dataclass(cli: T_CLI, test_dir: Path, filename: str) -> None:
"""
Dataclasses have generated methods with no matching source that we need a special case for.
"""
Expand All @@ -137,7 +145,7 @@ class Foo:
assert "NO SOURCE CODE FOUND" in result.output


def test_targeting_a_module_targets_its_members(cli, test_dir, filename):
def test_targeting_a_module_targets_its_members(cli: T_CLI, test_dir: Path, filename: str) -> None:
source_path = test_dir / f"{filename}.py"
source_path.write_text(
textwrap.dedent(
Expand Down Expand Up @@ -168,14 +176,14 @@ def method(self):
assert "wizbang" in result.output


def test_can_target_method(cli, source_path):
def test_can_target_method(cli: T_CLI, source_path: Path) -> None:
result = cli([f"{source_path.stem}.{CLASS_NAME}.{METHOD_NAME}"])

assert result.exit_code == 0
assert METHOD_NAME in result.output


def test_module_not_found(cli):
def test_module_not_found(cli: T_CLI) -> None:
target = "fidsjofoiasjoifdj"
result = cli([target])

Expand All @@ -185,15 +193,15 @@ def test_module_not_found(cli):


@pytest.mark.parametrize(
"target",
"target_path",
[
f"{CONST_NAME}",
f"{CLASS_NAME}.{ATTR_NAME}",
f"{NONE_NAME}",
],
)
def test_cannot_be_disassembled(cli, source_path, target):
result = cli([f"{source_path.stem}.{target}"])
def test_cannot_be_disassembled(cli: T_CLI, source_path: Path, target_path: str) -> None:
result = cli([f"{source_path.stem}.{target_path}"])

assert result.exit_code == 1
assert "cannot be disassembled" in result.output
Expand All @@ -214,14 +222,14 @@ def test_cannot_be_disassembled(cli, source_path, target):
(80, 4, 0.75, (52, 17)),
],
)
def test_column_width(terminal_width, line_num_width, ratio, expected):
def test_column_width(terminal_width, line_num_width, ratio, expected) -> None:
assert (
calculate_column_widths(line_num_width, ratio=ratio, terminal_width=terminal_width)
== expected
)


def test_no_targets_prints_help(cli):
def test_no_targets_prints_help(cli) -> None:
result = cli([])

assert result.exit_code == 0
Expand Down

0 comments on commit 0e1397e

Please sign in to comment.