Skip to content

Commit

Permalink
try fix github CI - run isort/black
Browse files Browse the repository at this point in the history
  • Loading branch information
bckohan committed Jan 27, 2024
1 parent 2405ccc commit cf08161
Show file tree
Hide file tree
Showing 12 changed files with 135 additions and 132 deletions.
37 changes: 17 additions & 20 deletions django_typer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,7 @@
__copyright__ = "Copyright 2023 Brian Kohan"


__all__ = [
"TyperCommand",
"Context",
"initialize",
"command",
"group",
"get_command"
]
__all__ = ["TyperCommand", "Context", "initialize", "command", "group", "get_command"]

"""
TODO
Expand Down Expand Up @@ -93,6 +86,7 @@
# except ImportError:
# pass


def traceback_config():
"""
Fetch the rich traceback installation parameters from our settings. By default
Expand Down Expand Up @@ -332,23 +326,22 @@ def common_params(self):
return [
param
for param in _get_common_params()
if param.name not in (self.django_command.suppressed_base_arguments or [])
if param.name
not in (self.django_command.suppressed_base_arguments or [])
]
return super().common_params()


class TyperGroupWrapper(DjangoAdapterMixin, CoreTyperGroup):
def common_params(self):
if (
(
hasattr(self, "django_command") and
self.django_command._has_callback
) or getattr(self, "common_init", False)
):
hasattr(self, "django_command") and self.django_command._has_callback
) or getattr(self, "common_init", False):
return [
param
for param in _get_common_params()
if param.name not in (self.django_command.suppressed_base_arguments or [])
if param.name
not in (self.django_command.suppressed_base_arguments or [])
]
return super().common_params()

Expand Down Expand Up @@ -701,7 +694,7 @@ def handle(self, *args, **options):
"_handle": attrs.pop("handle", None),
**attrs,
"handle": handle,
"typer_app": typer_app
"typer_app": typer_app,
}

return super().__new__(mcs, name, bases, attrs)
Expand All @@ -713,9 +706,9 @@ def __init__(cls, name, bases, attrs, **kwargs):
if cls.typer_app is not None:
cls.typer_app.info.name = cls.__module__.rsplit(".", maxsplit=1)[-1]
cls.suppressed_base_arguments = {
arg.lstrip('--').replace('-', '_')
arg.lstrip("--").replace("-", "_")
for arg in cls.suppressed_base_arguments
} # per django docs - allow these to be specified by either the option or param name
} # per django docs - allow these to be specified by either the option or param name

def get_ctor(attr):
return getattr(
Expand Down Expand Up @@ -763,7 +756,11 @@ def get_ctor(attr):
cls=type(
"_AdaptedCallback",
(TyperGroupWrapper,),
{"django_command": cls, "callback_is_method": False, "common_init": True},
{
"django_command": cls,
"callback_is_method": False,
"common_init": True,
},
)
)(lambda: None)

Expand Down Expand Up @@ -869,7 +866,7 @@ class Command(TyperCommand, attach='app_label.command_name.subcommand1.subcomman
# we do not use verbosity because the base command does not do anything with it
# if users want to use a verbosity flag like the base django command adds
# they can use the type from django_typer.types.Verbosity
suppressed_base_arguments: t.Optional[t.Iterable[str]] = {'verbosity'}
suppressed_base_arguments: t.Optional[t.Iterable[str]] = {"verbosity"}

class CommandNode:
name: str
Expand Down
134 changes: 62 additions & 72 deletions django_typer/completers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,21 @@

import typing as t
from types import MethodType

from click import Context, Parameter
from click.shell_completion import CompletionItem
from django.db.models import Q, Model, Max
from django.db.models import (
IntegerField,
FloatField,
DecimalField,
from django.apps import apps
from django.db.models import ( # TODO:; GenericIPAddressField,; TimeField,; DateField,; DateTimeField,; DurationField,; FilePathField,; FileField
CharField,
DecimalField,
FloatField,
IntegerField,
Max,
Model,
Q,
TextField,
UUIDField,
# TODO:
# GenericIPAddressField,
# TimeField,
# DateField,
# DateTimeField,
# DurationField,
# FilePathField,
# FileField
)
from django.apps import apps


class ModelObjectCompleter:
Expand All @@ -40,9 +35,9 @@ class ModelObjectCompleter:
:param lookup_field: The name of the model field to use for lookup.
:param help_field: The name of the model field to use for help text or None if
no help text should be provided.
:param query: A callable that accepts the completer object instance, the click
context, the click parameter, and the incomplete string and returns a Q
object to use for filtering the queryset. The default query will use the
:param query: A callable that accepts the completer object instance, the click
context, the click parameter, and the incomplete string and returns a Q
object to use for filtering the queryset. The default query will use the
relevant class methods depending on the lookup field class. See the
query methods for details.
:param limit: The maximum number of completion items to return. If None, all
Expand All @@ -56,20 +51,18 @@ class ModelObjectCompleter:
given for the parameter on the command line.
"""

QueryBuilder = t.Callable[['ModelObjectCompleter', Context, Parameter, str], Q]
QueryBuilder = t.Callable[["ModelObjectCompleter", Context, Parameter, str], Q]

model_cls: t.Type[Model]
lookup_field: str = 'id'
lookup_field: str = "id"
help_field: t.Optional[str] = None
query: QueryBuilder
limit: t.Optional[int] = 50
case_insensitive: bool = True
distinct: bool = True

def default_query(self,
context: Context,
parameter: Parameter,
incomplete: str

def default_query(
self, context: Context, parameter: Parameter, incomplete: str
) -> Q:
"""
The default completion query builder. This method will route to the
Expand All @@ -94,14 +87,9 @@ def default_query(self,
return self.uuid_query(context, parameter, incomplete)
elif issubclass(field.__class__, (FloatField, DecimalField)):
return self.float_query(context, parameter, incomplete)
raise ValueError(f'Unsupported lookup field class: {field.__class__.__name__}')

def int_query(
self,
context: Context,
parameter: Parameter,
incomplete: str
) -> Q:
raise ValueError(f"Unsupported lookup field class: {field.__class__.__name__}")

def int_query(self, context: Context, parameter: Parameter, incomplete: str) -> Q:
"""
The default completion query builder for integer fields. This method will
return a Q object that will match any value that starts with the incomplete
Expand All @@ -116,20 +104,17 @@ def int_query(
:raises TypeError: If the incomplete string is not a valid integer.
"""
lower = int(incomplete)
upper = lower+1
max_val = self.model_cls.objects.aggregate(Max(self.lookup_field))['id__max']
qry = Q(**{f'{self.lookup_field}': lower})
while (lower:=lower*10) <= max_val:
upper = lower + 1
max_val = self.model_cls.objects.aggregate(Max(self.lookup_field))["id__max"]
qry = Q(**{f"{self.lookup_field}": lower})
while (lower := lower * 10) <= max_val:
upper *= 10
qry |= Q(**{f'{self.lookup_field}__gte': lower}) & Q(**{f'{self.lookup_field}__lt': upper})
qry |= Q(**{f"{self.lookup_field}__gte": lower}) & Q(
**{f"{self.lookup_field}__lt": upper}
)
return qry

def float_query(
self,
context: Context,
parameter: Parameter,
incomplete: str
):

def float_query(self, context: Context, parameter: Parameter, incomplete: str):
"""
The default completion query builder for float fields. This method will
return a Q object that will match any value that starts with the incomplete
Expand All @@ -143,12 +128,14 @@ def float_query(
:raises ValueError: If the incomplete string is not a valid float.
:raises TypeError: If the incomplete string is not a valid float.
"""
if '.' not in incomplete:
if "." not in incomplete:
return self.int_query(context, parameter, incomplete)
incomplete = incomplete.rstrip('0')
incomplete = incomplete.rstrip("0")
lower = float(incomplete)
upper = lower + float(f'0.{"0"*(len(incomplete)-incomplete.index(".")-2)}1')
return Q(**{f'{self.lookup_field}__gte': lower}) & Q(**{f'{self.lookup_field}__lt': upper})
return Q(**{f"{self.lookup_field}__gte": lower}) & Q(
**{f"{self.lookup_field}__lt": upper}
)

def text_query(self, context: Context, parameter: Parameter, incomplete: str) -> Q:
"""
Expand All @@ -162,16 +149,16 @@ def text_query(self, context: Context, parameter: Parameter, incomplete: str) ->
:return: A Q object to use for filtering the queryset.
"""
if self.case_insensitive:
return Q(**{f'{self.lookup_field}__istartswith': incomplete})
return Q(**{f'{self.lookup_field}__startswith': incomplete})
return Q(**{f"{self.lookup_field}__istartswith": incomplete})
return Q(**{f"{self.lookup_field}__startswith": incomplete})

def uuid_query(self, context: Context, parameter: Parameter, incomplete: str) -> Q:
"""
The default completion query builder for UUID fields. This method will
return a Q object that will match any value that starts with the incomplete
string. The incomplete string will be stripped of all non-alphanumeric
characters and padded with zeros to 32 characters. For example, if the
incomplete string is "a", the query will match
incomplete string is "a", the query will match
a0000000-0000-0000-0000-000000000000 to affffffff-ffff-ffff-ffff-ffffffffffff.
:param context: The click context.
Expand All @@ -181,25 +168,27 @@ def uuid_query(self, context: Context, parameter: Parameter, incomplete: str) ->
:raises ValueError: If the incomplete string is too long or contains invalid
UUID characters. Anything other than (0-9a-fA-F).
"""
uuid = ''
uuid = ""
for char in incomplete:
if char.isalnum():
uuid += char
if len(uuid) > 32:
raise ValueError(f'Too many UUID characters: {incomplete}')
min_uuid = uuid + '0'*(32-len(uuid))
max_uuid = uuid + 'f'*(32-len(uuid))
return Q(**{f'{self.lookup_field}__gte': min_uuid}) & Q(**{f'{self.lookup_field}__lte': max_uuid})
raise ValueError(f"Too many UUID characters: {incomplete}")
min_uuid = uuid + "0" * (32 - len(uuid))
max_uuid = uuid + "f" * (32 - len(uuid))
return Q(**{f"{self.lookup_field}__gte": min_uuid}) & Q(
**{f"{self.lookup_field}__lte": max_uuid}
)

def __init__(
self,
model_cls: t.Type[Model],
lookup_field: str = lookup_field,
help_field: t.Optional[str] = help_field,
query: QueryBuilder = default_query,
limit: t.Optional[int] = limit,
case_insensitive: bool = case_insensitive,
distinct: bool = distinct
self,
model_cls: t.Type[Model],
lookup_field: str = lookup_field,
help_field: t.Optional[str] = help_field,
query: QueryBuilder = default_query,
limit: t.Optional[int] = limit,
case_insensitive: bool = case_insensitive,
distinct: bool = distinct,
):
self.model_cls = model_cls
self.lookup_field = lookup_field
Expand All @@ -210,10 +199,7 @@ def __init__(
self.distinct = distinct

def __call__(
self,
context: Context,
parameter: Parameter,
incomplete: str
self, context: Context, parameter: Parameter, incomplete: str
) -> t.Union[t.List[CompletionItem], t.List[str]]:
"""
The completer method. This method will return a list of CompletionItem
Expand All @@ -228,7 +214,7 @@ def __call__(
:param incomplete: The incomplete string.
:return: A list of CompletionItem objects.
"""

completion_qry = Q()

if incomplete:
Expand All @@ -240,8 +226,11 @@ def __call__(
return [
CompletionItem(
value=str(getattr(obj, self.lookup_field)),
help=getattr(obj, self.help_field, None) if self.help_field else ''
) for obj in self.model_cls.objects.filter(completion_qry).distinct()[0:self.limit]
help=getattr(obj, self.help_field, None) if self.help_field else "",
)
for obj in self.model_cls.objects.filter(completion_qry).distinct()[
0 : self.limit
]
]


Expand All @@ -257,6 +246,7 @@ def complete_app_label(ctx: Context, param: Parameter, incomplete: str):
"""
present = [app.label for app in (ctx.params.get(param.name) or [])]
return [
app.label for app in apps.get_app_configs()
app.label
for app in apps.get_app_configs()
if app.label.lower().startswith(incomplete.lower()) and app.label not in present
]
12 changes: 9 additions & 3 deletions django_typer/management/commands/shellcompletion.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,13 @@ class Command(TyperCommand):
requires_migrations_checks = False

# remove unnecessary django command base parameters - these just clutter the help
suppressed_base_arguments = {'version', 'skip_checks', 'no_color', 'force_color', 'verbosity'}
suppressed_base_arguments = {
"version",
"skip_checks",
"no_color",
"force_color",
"verbosity",
}

_shell: Shells

Expand Down Expand Up @@ -377,7 +383,7 @@ def complete(
def get_completion_args(self) -> t.Tuple[t.List[str], str]:
cwords = split_arg_string(command)
if command[-1].isspace():
cwords.append('')
cwords.append("")
# allow users to not specify the manage script, but allow for it
# if they do by lopping it off - same behavior as upstream classes
try:
Expand Down Expand Up @@ -411,7 +417,7 @@ def get_completions(self, args, incomplete):
complete_var=self.COMPLETE_VAR,
).get_completion_args()

with open('test.txt', 'w') as f:
with open("test.txt", "w") as f:
f.write(f'{args}\n"{incomplete}"')

def call_fallback(fb):
Expand Down
Loading

0 comments on commit cf08161

Please sign in to comment.