Skip to content

Commit

Permalink
fix: runner opts processing (#477)
Browse files Browse the repository at this point in the history
  • Loading branch information
ocervell authored Nov 11, 2024
1 parent 396f68a commit d788e9d
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 58 deletions.
3 changes: 0 additions & 3 deletions secator/celery.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,23 +203,20 @@ def break_task(task, task_opts, targets, results=[], chunk_size=1):

@app.task(bind=True)
def run_task(self, args=[], kwargs={}):
debug(f'Received task with args {args} and kwargs {kwargs}', sub="celery.run", verbose=True)
kwargs['context']['celery_id'] = self.request.id
task = Task(*args, **kwargs)
task.run()


@app.task(bind=True)
def run_workflow(self, args=[], kwargs={}):
debug(f'Received workflow with args {args} and kwargs {kwargs}', sub="celery.run", verbose=True)
kwargs['context']['celery_id'] = self.request.id
workflow = Workflow(*args, **kwargs)
workflow.run()


@app.task(bind=True)
def run_scan(self, args=[], kwargs={}):
debug(f'Received scan with args {args} and kwargs {kwargs}', sub="celery.run", verbose=True)
if 'context' not in kwargs:
kwargs['context'] = {}
kwargs['context']['celery_id'] = self.request.id
Expand Down
3 changes: 3 additions & 0 deletions secator/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,9 @@ def func(ctx, **opts):
_get_rich_console().print('[bold red]Missing `redis` addon: please run `secator install addons redis`[/].')
sys.exit(1)

from secator.utils import debug
debug('Run options', obj=opts, sub='cli')

# Set run options
opts.update({
'print_cmd': True,
Expand Down
51 changes: 31 additions & 20 deletions secator/runners/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from secator.output_types import Error, Target, Stat
from secator.runners import Runner
from secator.template import TemplateLoader
from secator.utils import debug


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -621,8 +622,7 @@ def _process_opts(
opt_value_map={},
opt_prefix='-',
command_name=None):
"""Process a dict of options using a config, option key map / value map
and option character like '-' or '--'.
"""Process a dict of options using a config, option key map / value map and option character like '-' or '--'.
Args:
opts (dict): Command options as input on the CLI.
Expand All @@ -634,6 +634,7 @@ def _process_opts(
"""
opts_str = ''
for opt_name, opt_conf in opts_conf.items():
debug('before get_opt_value', obj={'name': opt_name, 'conf': opt_conf}, obj_after=False, sub='command.options', verbose=True) # noqa: E501

# Get opt value
default_val = opt_conf.get('default')
Expand All @@ -644,25 +645,30 @@ def _process_opts(
opt_prefix=command_name,
default=default_val)

debug('after get_opt_value', obj={'name': opt_name, 'value': opt_val, 'conf': opt_conf}, obj_after=False, sub='command.options', verbose=True) # noqa: E501

# Skip option if value is falsy
if opt_val in [None, False, []]:
# logger.debug(f'Option {opt_name} was passed but is falsy. Skipping.')
debug('skipped (falsy)', obj={'name': opt_name, 'value': opt_val}, obj_after=False, sub='command.options', verbose=True) # noqa: E501
continue

# Convert opt value to expected command opt value
mapped_opt_val = opt_value_map.get(opt_name)
if callable(mapped_opt_val):
opt_val = mapped_opt_val(opt_val)
elif mapped_opt_val:
opt_val = mapped_opt_val
if mapped_opt_val:
if callable(mapped_opt_val):
opt_val = mapped_opt_val(opt_val)
else:
opt_val = mapped_opt_val

# Convert opt name to expected command opt name
mapped_opt_name = opt_key_map.get(opt_name)
if mapped_opt_name == OPT_NOT_SUPPORTED:
# logger.debug(f'Option {opt_name} was passed but is unsupported. Skipping.')
continue
elif mapped_opt_name is not None:
opt_name = mapped_opt_name
if mapped_opt_name is not None:
if mapped_opt_name == OPT_NOT_SUPPORTED:
debug('skipped (unsupported)', obj={'name': opt_name, 'value': opt_val}, sub='command.options', verbose=True) # noqa: E501
continue
else:
opt_name = mapped_opt_name
debug('mapped key / value', obj={'name': opt_name, 'value': opt_val}, obj_after=False, sub='command.options', verbose=True) # noqa: E501

# Avoid shell injections and detect opt prefix
opt_name = str(opt_name).split(' ')[0] # avoid cmd injection
Expand All @@ -682,6 +688,7 @@ def _process_opts(
if shlex_quote:
opt_val = shlex.quote(str(opt_val))
opts_str += f' {opt_val}'
debug('final', obj={'name': opt_name, 'value': opt_val}, sub='command.options', obj_after=False, verbose=True)

return opts_str.strip()

Expand Down Expand Up @@ -715,17 +722,21 @@ def _get_opt_default(opt_name, opts_conf):
@staticmethod
def _get_opt_value(opts, opt_name, opts_conf={}, opt_prefix='', default=None):
default = default or Command._get_opt_default(opt_name, opts_conf)
aliases = [
opts.get(f'{opt_prefix}_{opt_name}'),
opts.get(f'{opt_prefix}.{opt_name}'),
opts.get(opt_name),
opt_names = [
f'{opt_prefix}.{opt_name}',
f'{opt_prefix}_{opt_name}',
opt_name,
]
alias = [conf.get('short') for _, conf in opts_conf.items() if conf.get('short') in opts]
opt_values = [opts.get(o) for o in opt_names]
alias = [conf.get('short') for _, conf in opts_conf.items() if conf.get('short') in opts and _ == opt_name]
if alias:
aliases.append(opts.get(alias[0]))
if OPT_NOT_SUPPORTED in aliases:
opt_values.append(opts.get(alias[0]))
if OPT_NOT_SUPPORTED in opt_values:
debug('skipped (unsupported)', obj={'name': opt_name}, obj_after=False, sub='command.options', verbose=True)
return None
return next((v for v in aliases if v is not None), default)
value = next((v for v in opt_values if v is not None), default)
debug('got opt value', obj={'name': opt_name, 'value': value, 'aliases': opt_names, 'values': opt_values}, obj_after=False, sub='command.options', verbose=True) # noqa: E501
return value

def _build_cmd(self):
"""Build command string."""
Expand Down
70 changes: 35 additions & 35 deletions secator/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,43 +353,43 @@ def rich_to_ansi(text):
return capture.get()


def debug(msg, sub='', id='', obj=None, lazy=None, obj_after=True, obj_breaklines=False, verbose=False):
"""Print debug log if DEBUG >= level."""
debug_comp_empty = DEBUG_COMPONENT == [""] or not DEBUG_COMPONENT
if debug_comp_empty:
return
def format_object(obj, obj_breaklines=False):
"""Format the debug object for printing."""
sep = '\n ' if obj_breaklines else ', '
if isinstance(obj, dict):
return sep.join(f'[dim cyan]{k}[/] [dim yellow]->[/] [dim green]{v}[/]' for k, v in obj.items() if v is not None) # noqa: E501
elif isinstance(obj, list):
return f'[dim green]{sep.join(obj)}[/]'
return ''

if sub and verbose and not any(sub == s for s in DEBUG_COMPONENT):
sub = f'debug.{sub}'

if not any(sub.startswith(s) for s in DEBUG_COMPONENT):
return

if lazy:
msg = lazy(msg)

s = ''
if sub:
s += f'[dim yellow4]{sub:13s}[/] '
obj_str = ''
if obj:
sep = ', '
if obj_breaklines:
obj_str += '\n '
sep = '\n '
if isinstance(obj, dict):
obj_str += sep.join(f'[dim blue]{k}[/] [dim yellow]->[/] [dim green]{v}[/]' for k, v in obj.items() if v is not None)
elif isinstance(obj, list):
obj_str += f'[dim green]{sep.join(obj)}[/]'
if obj_str and not obj_after:
s = f'{s} {obj_str} '
s += f'[dim yellow]{msg}[/] '
if obj_str and obj_after:
s = f'{s}: {obj_str}'
if id:
s += f' [italic dim gray11]\[{id}][/] '
s = rich_to_ansi(f'[dim red]🐛 {s}[/]')
print(s)
def debug(msg, sub='', id='', obj=None, lazy=None, obj_after=True, obj_breaklines=False, verbose=False):
"""Print debug log if DEBUG >= level."""
if not DEBUG_COMPONENT or DEBUG_COMPONENT == [""]:
return

if sub:
if verbose and sub not in DEBUG_COMPONENT:
sub = f'debug.{sub}'
if not any(sub.startswith(s) for s in DEBUG_COMPONENT):
return

if lazy:
msg = lazy(msg)

formatted_msg = f'[dim yellow4]{sub:13s}[/] ' if sub else ''
obj_str = format_object(obj, obj_breaklines) if obj else ''

# Constructing the message string based on object position
if obj_str and not obj_after:
formatted_msg += f'{obj_str} '
formatted_msg += f'[dim yellow]{msg}[/]'
if obj_str and obj_after:
formatted_msg += f': {obj_str}'
if id:
formatted_msg += f' [italic dim gray11]\[{id}][/]'

console.print(f'[dim red]🐛 {formatted_msg}[/]', style='red')


def escape_mongodb_url(url):
Expand Down

0 comments on commit d788e9d

Please sign in to comment.