Skip to content

Commit

Permalink
Fix typing
Browse files Browse the repository at this point in the history
  • Loading branch information
seanpmorgan committed Jan 5, 2024
1 parent 19e04c1 commit f505f82
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 19 deletions.
4 changes: 2 additions & 2 deletions modelscan/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def cli() -> None:
)
@cli.command(
help="[Default] Scan a model file or diretory for ability to execute suspicious actions. "
)
) # type: ignore
@click.pass_context
def scan(
ctx: click.Context,
Expand Down Expand Up @@ -132,7 +132,7 @@ def scan(
return 0


@cli.command("create-settings-file", help="Create a modelscan settings file")
@cli.command("create-settings-file", help="Create a modelscan settings file") # type: ignore
@click.option(
"-f", "--force", is_flag=True, help="Overwrite existing settings file if it exists."
)
Expand Down
2 changes: 1 addition & 1 deletion modelscan/issues.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __eq__(self, other: Any) -> bool:
and str(self.details.source) == str(other.details.source) # type: ignore[attr-defined]
)

def __repr__(self):
def __repr__(self) -> str:
return str(self.severity) + str(self.details)

def __hash__(self) -> int:
Expand Down
36 changes: 20 additions & 16 deletions modelscan/tools/cli_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import click
from click import Command, Context, HelpFormatter
from typing import List, Optional, Tuple, Any, Union


class DefaultGroup(click.Group):
Expand All @@ -10,72 +12,74 @@ class DefaultGroup(click.Group):
"""

def __init__(self, *args, **kwargs):
def __init__(self, *args: object, **kwargs) -> None: # type: ignore
# To resolve as the default command.
if not kwargs.get("ignore_unknown_options", True):
raise ValueError("Default group accepts unknown options")
self.ignore_unknown_options = True
self.default_cmd_name = kwargs.pop("default", None)
self.default_if_no_args = kwargs.pop("default_if_no_args", False)
super(DefaultGroup, self).__init__(*args, **kwargs)
super(DefaultGroup, self).__init__(*args, **kwargs) # type: ignore

def set_default_command(self, command):
def set_default_command(self, command: Command) -> None:
"""Sets a command function as the default command."""
cmd_name = command.name
self.add_command(command)
self.default_cmd_name = cmd_name

def parse_args(self, ctx, args):
def parse_args(self, ctx: Context, args: Any) -> List[str]:
if not args and self.default_if_no_args:
args.insert(0, self.default_cmd_name)
return super(DefaultGroup, self).parse_args(ctx, args)

def get_command(self, ctx, cmd_name):
def get_command(self, ctx: Context, cmd_name: str) -> Optional[Command]:
if cmd_name not in self.commands:
# No command name matched.
ctx.arg0 = cmd_name
ctx.arg0 = cmd_name # type: ignore
cmd_name = self.default_cmd_name
return super(DefaultGroup, self).get_command(ctx, cmd_name)

def resolve_command(self, ctx, args):
def resolve_command(
self, ctx: Context, args: Any
) -> Tuple[Optional[str], Optional[Command], List[str]]:
base = super(DefaultGroup, self)
cmd_name, cmd, args = base.resolve_command(ctx, args)
cmd_name, cmd, args = base.resolve_command(ctx, args) # type: ignore
if hasattr(ctx, "arg0"):
args.insert(0, ctx.arg0)
cmd_name = cmd.name
return cmd_name, cmd, args

def format_commands(self, ctx, formatter):
def format_commands(self, ctx: Context, formatter: HelpFormatter) -> None:
formatter = DefaultCommandFormatter(self, formatter, mark="*")
return super(DefaultGroup, self).format_commands(ctx, formatter)

def command(self, *args, **kwargs):
def command(self, *args: Any, **kwargs: Any) -> Union[Any, Command]:
default = kwargs.pop("default", False)
decorator = super(DefaultGroup, self).command(*args, **kwargs)
if not default:
return decorator

def _decorator(f):
def _decorator(f: Command) -> Union[Any, Command]:
cmd = decorator(f)
self.set_default_command(cmd)
return cmd

return _decorator


class DefaultCommandFormatter(object):
class DefaultCommandFormatter(HelpFormatter):
"""Wraps a formatter to mark a default command."""

def __init__(self, group, formatter, mark="*"):
def __init__(self, group: DefaultGroup, formatter: HelpFormatter, mark: str = "*"):
self.group = group
self.formatter = formatter
self.mark = mark

def __getattr__(self, attr):
def __getattr__(self, attr): # type: ignore
return getattr(self.formatter, attr)

def write_dl(self, rows, *args, **kwargs):
rows_ = []
def write_dl(self, rows, *args, **kwargs): # type: ignore
rows_ = [] # type: ignore
for cmd_name, help in rows:
if cmd_name == self.group.default_cmd_name:
rows_.insert(0, (cmd_name + self.mark, help))
Expand Down

0 comments on commit f505f82

Please sign in to comment.