diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7c9a835..b335d0a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -28,6 +28,16 @@ repos: rev: v1.1.9 hooks: - id: remove-crlf + - repo: https://github.com/asottile/blacken-docs + rev: v1.8.0 + hooks: + - id: blacken-docs + additional_dependencies: [ black==20.8b1 ] + - repo: https://github.com/hadialqattan/pycln + rev: 0.0.1-alpha.3 + hooks: + - id: pycln + args: [ --config=pyproject.toml ] - repo: https://github.com/psf/black rev: 20.8b1 hooks: diff --git a/dis_cli.py b/dis_cli.py index 7c667a5..e9b0185 100644 --- a/dis_cli.py +++ b/dis_cli.py @@ -2,6 +2,7 @@ import contextlib import dis +import functools import importlib import inspect import itertools @@ -15,7 +16,7 @@ from dataclasses import dataclass from pathlib import Path from types import FunctionType, ModuleType -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from typing import Dict, Iterable, Iterator, List, Optional, Tuple, Union import click from rich.color import ANSI_COLOR_NAMES @@ -27,6 +28,11 @@ from rich.table import Table from rich.text import Text +if sys.version_info >= (3, 8): + from functools import cached_property +else: + cached_property = lambda func: property(functools.lru_cache(maxsize=None)(func)) + T_JUMP_COLOR_MAP = Dict[int, str] JUMP_COLORS = [ c for c in ANSI_COLOR_NAMES.keys() if not any(("grey" in c, "black" in c, "white" in c)) @@ -36,20 +42,29 @@ INSTRUCTION_GRID_HEADERS = ["OFF", "OPERATION", "ARGS", ""] T_INSTRUCTION_ROW = Tuple[Text, ...] +T_CLASS_OR_MODULE = Union[type, ModuleType] +T_FUNCTION_OR_CLASS_OR_MODULE = Union[FunctionType, T_CLASS_OR_MODULE] NUMBER_COLUMN_WIDTH = 4 +DEFAULT_THEME = "monokai" + @click.command() -@click.argument("target", nargs=-1) +@click.argument( + "target", + nargs=-1, +) @click.option( - "--theme", default="monokai", help="Choose the syntax highlighting theme (any Pygments theme)." + "--theme", + default=DEFAULT_THEME, + help=f"Choose the syntax highlighting theme (any Pygments theme). Default: {DEFAULT_THEME!r}.", ) @click.option( "-p/-P", "--paging/--no-paging", default=None, - help="Enable/disable displaying output using the system pager. If not passed explicitly, the pager will automatically be used if the output is taller than your terminal.", + help="Enable/disable displaying output using the system pager. Default: enabled if the output is taller than your terminal window.", ) @click.version_option() def cli( @@ -60,7 +75,8 @@ def cli( """ Display the source and bytecode of the TARGET Python functions. - If you TARGET a class, its __init__ method will be displayed. + If you TARGET a class, all of its methods will be targeted. + If you TARGET a module, all of its functions and classes (and therefore their methods) will be targeted. Any number of TARGETs may be passed; they will be displayed sequentially. """ @@ -71,7 +87,7 @@ def cli( console = Console(highlight=True, tab_size=4) - displays = list(make_source_and_bytecode_displays_for_targets(targets=target, theme=theme)) + displays = list(make_source_and_bytecode_displays_for_targets(target_paths=target, theme=theme)) parts = itertools.chain.from_iterable(display.parts for display in displays) total_height = sum(display.height for display in displays) @@ -91,51 +107,81 @@ class Display: height: int -def make_source_and_bytecode_displays_for_targets( - targets: Iterable[str], theme: str -) -> Iterable[Display]: - for func in map(find_function, targets): - yield make_source_and_bytecode_display(func, theme) - - -def find_function(target: str) -> FunctionType: - parts = target.split(".") - - if len(parts) == 1: - try: - module = silent_import(parts[0]) - raise bad_target(target, module, module) - except ModuleNotFoundError as e: - # target was not *actually* a module - raise click.ClickException(str(e)) - - # Walk backwards along the split parts and try to do the import. - # This makes the import go as deep as possible. - for split_point in range(len(parts) - 1, 0, -1): - module_path, target_path = ".".join(parts[:split_point]), ".".join(parts[split_point:]) - - try: - module = obj = silent_import(module_path) - break - except ModuleNotFoundError: - pass - - for target_path_part in target_path.split("."): - try: - obj = getattr(obj, target_path_part) - except AttributeError: - raise click.ClickException( - f"No attribute named {target_path_part!r} found on {type(obj).__name__} {obj!r}." - ) +@dataclass(frozen=True) +class Target: + obj: T_FUNCTION_OR_CLASS_OR_MODULE + path: str + imported_from: Optional[ModuleType] = None + + @cached_property + def module(self): + return self.imported_from or (self.obj if self.is_module else inspect.getmodule(self.obj)) + + @cached_property + def is_module(self) -> bool: + return inspect.ismodule(self.obj) + + @cached_property + def is_class(self) -> bool: + return inspect.isclass(self.obj) + + @cached_property + def is_class_or_module(self) -> bool: + return self.is_class or self.is_module + + @cached_property + def is_function(self) -> bool: + return inspect.isfunction(self.obj) + + def make_display(self, theme: str) -> Display: + return make_source_and_bytecode_display_for_function(self.obj, theme) + + @classmethod + def from_path(cls, path: str) -> "Target": + parts = path.split(".") + + if len(parts) == 1: + try: + module = silent_import(parts[0]) + return cls(obj=module, path=path) + except ModuleNotFoundError as e: + # target was not *actually* a module + raise click.ClickException(str(e)) + + # Walk backwards along the split parts and try to do the import. + # This makes the import go as deep as possible. + for split_point in range(len(parts) - 1, 0, -1): + module_path, obj_path = ".".join(parts[:split_point]), ".".join(parts[split_point:]) + + try: + module = obj = silent_import(module_path) + break + except ModuleNotFoundError: + pass + + for target_path_part in obj_path.split("."): + try: + obj = getattr(obj, target_path_part) + except AttributeError: + raise click.ClickException( + f"No attribute named {target_path_part!r} found on {type(obj).__name__} {obj!r}." + ) - # If the target is a class, display its __init__ method - if inspect.isclass(obj): - obj = obj.__init__ # type: ignore + return cls(obj=obj, path=path, imported_from=module) - if not inspect.isfunction(obj): - raise bad_target(target, obj, module) - return obj +def make_source_and_bytecode_displays_for_targets( + target_paths: Iterable[str], theme: str +) -> Iterator[Display]: + for path in target_paths: + target = Target.from_path(path) + + if target.is_class_or_module: + yield from (t.make_display(theme) for t in find_child_targets(target)) + elif target.is_function: + yield target.make_display(theme) + else: + cannot_be_disassembled(target) def silent_import(module_path: str) -> ModuleType: @@ -152,41 +198,53 @@ def silent_import(module_path: str) -> ModuleType: ) -def bad_target(target: str, obj: Any, module: ModuleType) -> click.ClickException: - possible_targets = find_possible_targets(module) +def cannot_be_disassembled(target: Target): + msg = f"The target {target.path} = {target.obj} is a {type(target.obj).__name__}, which cannot be disassembled. Target a specific function" - msg = f"The target {target} = {obj} is a {type(obj).__name__}, which cannot be disassembled. Target a specific function" + possible_targets = find_child_targets(target) + if len(possible_targets) == 0: + possible_targets = find_child_targets( + Target(obj=target.module, path=".".join(target.path.split(".")[:-1])) + ) if len(possible_targets) == 0: - return click.ClickException(f"{msg}.") + raise click.ClickException(f"{msg}.") else: choice = random.choice(possible_targets) - suggestion = click.style(f"{choice.__module__}.{choice.__qualname__}", bold=True) - return click.ClickException(f"{msg}, like {suggestion}") + suggestion = click.style(choice.path, bold=True) + raise click.ClickException(f"{msg}, like {suggestion}") -def find_possible_targets(obj: ModuleType) -> List[FunctionType]: - return list(_find_possible_targets(obj)) +def find_child_targets(target: Target) -> List[Target]: + return list(_find_child_targets(target, top_module=target.module)) -def _find_possible_targets( - module: ModuleType, top_module: Optional[ModuleType] = None -) -> Iterable[FunctionType]: - for obj in vars(module).values(): - if (inspect.ismodule(module) and inspect.getmodule(obj) != module) or ( - inspect.isclass(module) and inspect.getmodule(module) != top_module - ): - continue +def _find_child_targets(target: Target, top_module: ModuleType) -> Iterable[Target]: + try: + children = vars(target.obj) + except TypeError: # vars() argument must have __dict__ attribute + return - if inspect.isfunction(obj): - yield obj - elif inspect.isclass(obj): - yield from _find_possible_targets(obj, top_module=top_module or module) + for child in children.values(): + if inspect.getmodule(child) != top_module: # Do not go outside of the top module + continue + elif inspect.isclass(child): # Recurse into classes + yield from _find_child_targets( + Target(obj=child, path=f"{target.path}.{child.__name__}"), + top_module=top_module, + ) + elif inspect.isfunction(child): + yield Target(obj=child, path=f"{target.path}.{child.__name__}") -def make_source_and_bytecode_display(function: FunctionType, theme: str) -> Display: +def make_source_and_bytecode_display_for_function(function: FunctionType, theme: str) -> Display: instructions = list(dis.Bytecode(function)) - source_lines, start_line = inspect.getsourcelines(function) + + try: + source_lines, start_line = inspect.getsourcelines(function) + except OSError: # This might happen if the source code is generated + source_lines = ["NO SOURCE CODE FOUND"] + start_line = -1 jump_color_map = find_jump_colors(instructions) @@ -218,14 +276,14 @@ def make_source_and_bytecode_display(function: FunctionType, theme: str) -> Disp def make_title(function, start_line: int) -> Text: - path = Path(inspect.getmodule(function).__file__) + source_file_path = Path(inspect.getmodule(function).__file__) try: - path = path.relative_to(Path.cwd()) + source_file_path = source_file_path.relative_to(Path.cwd()) except ValueError: # path is not under the cwd pass return Text.from_markup( - f"{type(function).__name__} [bold]{function.__module__}.{function.__qualname__}[/bold] from {path}:{start_line}" + f"{type(function).__name__} [bold]{function.__module__}.{function.__qualname__}[/bold] from {source_file_path}:{start_line}" ) @@ -372,14 +430,5 @@ def make_bytecode_block( return grid -def get_own_version() -> str: # pragma: versioned - if sys.version_info < (3, 8): - import importlib_metadata - else: - import importlib.metadata as importlib_metadata - - return importlib_metadata.version("dis_cli") - - if __name__ == "__main__": sys.exit(cli(prog_name="dis")) diff --git a/examples/dis.dis.png b/examples/dis.dis.png index b2eac58..366d85b 100644 Binary files a/examples/dis.dis.png and b/examples/dis.dis.png differ diff --git a/pyproject.toml b/pyproject.toml index 13686de..d4a85d5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,3 +15,6 @@ line_length = 100 [tool.pytest.ini_options] testpaths = ["tests"] console_output_style = "count" + +[tool.pycln] +all = true diff --git a/setup.cfg b/setup.cfg index 5483222..0a6e0b0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = dis_cli -version = 0.1.6 +version = 0.2.0 description = A tool to inspect disassembled Python code on the command line. long_description = file: README.md long_description_content_type = text/markdown @@ -31,8 +31,6 @@ install_requires = click>=7 rich>=9 dataclasses>=0.6;python_version<"3.7" - importlib-metadata;python_version<"3.8" - importlib-resources;python_version<"3.7" python_requires = >=3.6 [options.entry_points] @@ -43,6 +41,8 @@ console_scripts = tests = pytest>=6 pytest-cov>=2.10 + importlib-metadata;python_version<"3.8" + importlib-resources;python_version<"3.7" [mypy] files = dis_cli.py, tests/*.py diff --git a/tests/test_cli.py b/tests/test_cli.py index e3da06c..6e74782 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,10 +1,9 @@ +import sys import textwrap from pathlib import Path import pytest -from dis_cli import get_own_version - def test_smoke(cli): assert cli(["dis.dis"]).exit_code == 0 @@ -86,7 +85,7 @@ def test_module_level_error_is_handled_gracefully(cli, test_dir, filename, extra assert "during import" in result.output -def test_targeting_a_class_redirects_to_init(cli, test_dir, filename): +def test_targeting_a_class_targets_all_of_its_methods(cli, test_dir, filename): source_path = test_dir / f"{filename}.py" source_path.write_text( textwrap.dedent( @@ -94,6 +93,9 @@ def test_targeting_a_class_redirects_to_init(cli, test_dir, filename): class Foo: def __init__(self): print("foobar") + + def method(self): + print("wizbang") """ ) ) @@ -103,6 +105,65 @@ def __init__(self): assert result.exit_code == 0 assert "foobar" in result.output + assert "wizbang" in result.output + + +@pytest.mark.skipif( + sys.version_info < (3, 7), reason="Non-native dataclasses don't behave the same" +) +def test_can_dis_dataclass(cli, test_dir, filename): + """ + Dataclasses have generated methods with no matching source that we need a special case for. + """ + source_path = test_dir / f"{filename}.py" + source_path.write_text( + textwrap.dedent( + """\ + from dataclasses import dataclass + + @dataclass + class Foo: + attr: int + """ + ) + ) + print(source_path.read_text()) + + result = cli([f"{source_path.stem}.Foo"]) + + assert result.exit_code == 0 + assert "NO SOURCE CODE FOUND" in result.output + + +def test_targeting_a_module_targets_its_members(cli, test_dir, filename): + source_path = test_dir / f"{filename}.py" + source_path.write_text( + textwrap.dedent( + """\ + import itertools + from os.path import join + + def func(): + print("hello") + + class Foo: + def __init__(self): + print("foobar") + + def method(self): + print("wizbang") + """ + ) + ) + print(source_path.read_text()) + + result = cli([f"{source_path.stem}"]) + + assert "combinations" not in result.output # doesn't see imported functions + assert "join" not in result.output # doesn't see imported functions + assert "hello" in result.output + assert "foobar" in result.output + assert "wizbang" in result.output def test_can_target_method(cli, source_path): @@ -121,21 +182,6 @@ def test_module_not_found(cli): assert target in result.output -@pytest.mark.parametrize( - "target", - [ - "click", # top-level module - "click.testing", # submodule - ], -) -def test_gracefully_cannot_disassemble_module(cli, target): - result = cli([target]) - - assert result.exit_code == 1 - assert "cannot be disassembled" in result.output - assert "module" in result.output - - @pytest.mark.parametrize( "target", [ @@ -149,9 +195,3 @@ def test_cannot_be_disassembled(cli, source_path, target): assert result.exit_code == 1 assert "cannot be disassembled" in result.output - - -def test_version(cli): - result = cli(["--version"]) - - assert get_own_version() in result.output diff --git a/tests/test_version.py b/tests/test_version.py new file mode 100644 index 0000000..c72605d --- /dev/null +++ b/tests/test_version.py @@ -0,0 +1,16 @@ +import sys + + +def get_own_version() -> str: # pragma: versioned + if sys.version_info < (3, 8): + import importlib_metadata + else: + import importlib.metadata as importlib_metadata + + return importlib_metadata.version("dis_cli") + + +def test_version(cli): + result = cli(["--version"]) + + assert get_own_version() in result.output