Skip to content

Commit

Permalink
Port from argparse to absl.flags.argparse_flags.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 548107146
  • Loading branch information
lukegb authored and copybara-github committed Jul 14, 2023
1 parent 072f365 commit 1657b45
Showing 1 changed file with 27 additions and 44 deletions.
71 changes: 27 additions & 44 deletions refex/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,10 @@
import tempfile
import textwrap
import traceback
from typing import Dict, Generic, Iterable, Optional, Text, Tuple, TypeVar, Union, IO
from typing import Callable, Dict, Generic, IO, Iterable, Optional, Text, Tuple, TypeVar, Union

from absl import app
from absl.flags import argparse_flags
import attr
import colorama
import pkg_resources
Expand Down Expand Up @@ -334,21 +335,6 @@ def __call__(self, parser, namespace, value, option_string=None):
old_sub[name] = pattern


def _absl_run_separate_argv(main_func, main_argv, absl_argv):
"""Runs main via absl.app.run(), passing different argv to main and to absl.
Args:
main_func: A function main(main_argv).
main_argv: The argv to pass to main.
absl_argv: The argv to pass to absl.
"""

def absl_main(unused_argv):
return main_func(main_argv)

app.run(absl_main, argv=absl_argv)


def run(runner: RefexRunner,
files: Iterable[Union[str, Tuple[str, str]]],
bug_report_url: Text,
Expand Down Expand Up @@ -378,17 +364,23 @@ def run(runner: RefexRunner,
pass


def run_cli(argv,
parser,
get_runner,
get_files,
bug_report_url=_BUG_REPORT_URL,
version='<unspecified>'):
def run_cli(
argv: Iterable[str],
parser_factory: Callable[[], argparse.ArgumentParser],
get_runner: Callable[
[argparse.ArgumentParser, argparse.Namespace], RefexRunner
],
get_files: Callable[
[RefexRunner, argparse.Namespace], Iterable[tuple[str, str]]
],
bug_report_url: str = _BUG_REPORT_URL,
version: str = '<unspecified>',
) -> None:
"""Creates a runner from command-line arguments, and executes it.
Args:
argv: argv
parser: An ArgumentParser.
parser_factory: a callable that generates an ArgumentParser.
get_runner: called with (parser, options) returns the runner to use.
get_files: called with (runner, options) returns the files to examine, as
[(in_file, out_file), ...] pairs.
Expand All @@ -398,21 +390,15 @@ def run_cli(argv,
version: The version number to use in bug report logs and --version
"""
with _report_bug_excepthook(bug_report_url):
_add_rewriter_arguments(parser)

# For legacy reasons, refex uses argparse. This isn't very easily
# compatible with using app.run() -- what if someone defines a flag that
# conflicts with an argparse flag?
# Nonetheless, we want to use app.run to allow interop with absl-using
# libraries. So we process argparse flags first, and then give app.run() a
# fake argv of zilch.
#
# In the future, one could imagine providing a --absl-flag= option in
# argparse to let one override absl-flag values, but for now let's just
# ignore it.

def main_for_absl(argv):

def parse_flags(argv):
parser = parser_factory()
_add_rewriter_arguments(parser)
options = _parse_options(argv[1:], parser)
return parser, options

def _main(args):
parser, options = args

def _run():
"""A wrapper function for profiler.runcall."""
Expand All @@ -428,7 +414,7 @@ def _run():
_run()

try:
_absl_run_separate_argv(main_for_absl, argv, [argv[0]])
app.run(_main, argv=list(argv), flags_parser=parse_flags)
except KeyboardInterrupt:
pass

Expand Down Expand Up @@ -673,7 +659,6 @@ def _add_rewriter_arguments(parser):
metavar='FILE',
help='Profile main() and write results to disk at FILE.')
debug_options.add_argument(
'-v',
'--verbose',
action='store_const',
const=True,
Expand Down Expand Up @@ -757,9 +742,9 @@ def _parse_options(argv, parser):
return options


def argument_parser(version):
def argument_parser():
"""Creates an :class:`argparse.ArgumentParser` for the refex CLI."""
parser = argparse.ArgumentParser(
parser = argparse_flags.ArgumentParser(
formatter_class=argparse.RawDescriptionHelpFormatter,
description=textwrap.dedent("""\
Syntactically aware search/replace."""),
Expand All @@ -782,8 +767,6 @@ def argument_parser(version):
allow_abbrev=False,
)

parser.add_argument('--version', action='version', version=version)

match_options = parser.add_argument_group(
'match arguments',
'Arguments for use when performing search-replace (when passing the '
Expand Down Expand Up @@ -949,7 +932,7 @@ def main(argv=None, bug_report_url=_BUG_REPORT_URL, version=None):
)
run_cli(
argv,
argument_parser(version=version),
argument_parser,
runner_from_options,
files_from_options,
bug_report_url=bug_report_url,
Expand Down

0 comments on commit 1657b45

Please sign in to comment.