From b79685335d7ee64a1b00e3a9b50473efeb828d2f Mon Sep 17 00:00:00 2001 From: Anonymous Googler Date: Thu, 15 Sep 2022 09:56:27 -0700 Subject: [PATCH] Add visibility of refex to tool for refactoring jupyter notebooks. PiperOrigin-RevId: 474587955 --- refex/cli.py | 134 ++++++++++++++++++++++++++++++++------------------- 1 file changed, 85 insertions(+), 49 deletions(-) diff --git a/refex/cli.py b/refex/cli.py index c2bf8a5..2e86e4c 100755 --- a/refex/cli.py +++ b/refex/cli.py @@ -11,9 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -:mod:`refex.cli` -================ +""":mod:`refex.cli` ================ Command-line interface to Refex, and extension points to that interface. @@ -26,6 +24,7 @@ from __future__ import division from __future__ import print_function +import abc import argparse import atexit import collections @@ -40,18 +39,17 @@ import tempfile import textwrap import traceback -from typing import Dict, Iterable, List, Optional, Text, Tuple, Union +from typing import Dict, Generic, Iterable, Optional, Text, Tuple, TypeVar, Union, IO from absl import app import attr import colorama import pkg_resources -import six - from refex import formatting from refex import search from refex.fix import find_fixer from refex.python import syntactic_template +import six _IGNORABLE_ERRNO = frozenset([ errno.ENOENT, # file was removed after we went looking @@ -68,6 +66,41 @@ def _shorten_path(path): # given but iteration is to done. 
_DEFAULT_ITERATION_COUNT = 10 +MetaT = TypeVar('MetaT') + + +@attr.s +class Content(Generic[MetaT]): + data: str = attr.ib() + metadata: MetaT = attr.ib(default=None) + + +class Codec(abc.ABC): + """File codec base.""" + + @abc.abstractmethod + def read(self, f: IO[bytes]) -> Content: + pass + + @abc.abstractmethod + def write(self, f: IO[bytes], content: Content) -> None: + pass + + +@attr.s +class UnicodeCodec(Codec): + """Standard unicode encoded content.""" + + encoding = attr.ib(default='utf-8') + + def read(self, f: IO[bytes]) -> Content: + with io.TextIOWrapper(f, self.encoding) as f: + return Content(data=f.read()) + + def write(self, f: IO[bytes], content: Content) -> None: + with io.TextIOWrapper(f, self.encoding) as f: + f.write(content.data) + @attr.s class RefexRunner(object): @@ -97,8 +130,9 @@ class RefexRunner(object): show_files = attr.ib(default=True) verbose = attr.ib(default=False) max_iterations = attr.ib(default=_DEFAULT_ITERATION_COUNT) + codec = attr.ib(default=UnicodeCodec()) - def read(self, path: str) -> Optional[Text]: + def read(self, path: str) -> Optional[Content]: """Reads in a file and return the resulting content as unicode. Since this is only called from the loop within :meth:`rewrite_files`, @@ -108,11 +142,11 @@ def read(self, path: str) -> Optional[Text]: path: The path to the file. Returns: - An optional unicode string of the file content. + The Content read by the codec, or None if the file was skipped. 
""" try: - with io.open(path, 'r', encoding='utf-8') as d: - return d.read() + with io.open(path, 'rb') as d: + return self.codec.read(d) except UnicodeDecodeError as e: print('skipped %s: UnicodeDecodeError: %s' % (path, e), file=sys.stderr) return None @@ -142,11 +176,12 @@ def get_matches(self, contents, path): print('skipped %s: %s' % (path, e), file=sys.stderr) return [] - def write(self, path, content, matches): + def write(self, path, result, matches): if not self.dry_run: try: - with io.open(path, 'w', encoding='utf-8') as f: - f.write(formatting.apply_substitutions(content, matches)) + with io.open(path, 'wb') as f: + result.data = formatting.apply_substitutions(result.data, matches) + self.codec.write(f, result) except IOError as e: print('skipped %s: IOError: %s' % (path, e), file=sys.stderr) @@ -175,7 +210,7 @@ def log_changes(self, content, matches, name, renderer): if part: sys.stdout.write(part) sys.stdout.flush() - return has_changes + return has_any_changes def rewrite_files(self, path_pairs): """Main access point for rewriting. 
@@ -193,13 +228,13 @@ def rewrite_files(self, path_pairs): has_changes = False for read, write in path_pairs: display_name = _shorten_path(write) - content = self.read(read) - if content is not None: + result = self.read(read) + if result is not None: try: - matches = self.get_matches(content, display_name) + matches = self.get_matches(result.data, display_name) except Exception as e: # pylint: disable=broad-except failures[read] = { - 'content': content, + 'content': result.data, 'traceback': traceback.format_exc() } print( @@ -208,8 +243,9 @@ def rewrite_files(self, path_pairs): file=sys.stderr) else: has_changes |= ( - self.log_changes(content, matches, display_name, self.renderer)) - self.write(write, content, matches) + self.log_changes(result.data, matches, display_name, + self.renderer)) + self.write(write, result, matches) if has_changes and self.dry_run: # If there were changes that the user might have wanted to apply, but they # were in dry run mode, print a note for them. @@ -219,7 +255,6 @@ def rewrite_files(self, path_pairs): _BUG_REPORT_URL = 'https://github.com/ssbr/refex/issues/new/choose' - # It was at this point, dear reader, that this programmer wondered if using # argparse was a mistake after all. # @@ -358,13 +393,12 @@ def run_cli(argv, Args: argv: argv parser: An ArgumentParser. - get_runner: called with (parser, options) - returns the runner to use. - get_files: called with (runner, options) - returns the files to examine, as [(in_file, out_file), ...] pairs. - bug_report_url: An URL to present to the user to report bugs. - As the error dump includes source code, corporate organizations may - wish to override this with an internal bug report link for triage. + get_runner: called with (parser, options); returns the runner to use. + get_files: called with (runner, options); returns the files to examine, as + [(in_file, out_file), ...] pairs. + bug_report_url: A URL to present to the user to report bugs. 
As the error + dump includes source code, corporate organizations may wish to override + this with an internal bug report link for triage. version: The version number to use in bug report logs and --version """ with _report_bug_excepthook(bug_report_url): @@ -547,15 +581,17 @@ def _add_rewriter_arguments(parser): help='Expand passed file paths recursively.') parser.add_argument('--norecursive', action='store_false', dest='recursive') - parser.add_argument('--excludefile', - type=re.compile, - metavar='REGEX', - help='Filenames to exclude (regular expression).') - parser.add_argument('--includefile', - type=re.compile, - metavar='REGEX', - help='Filenames that must match to include' - ' (regular expression).') + parser.add_argument( + '--excludefile', + type=re.compile, + metavar='REGEX', + help='Filenames to exclude (regular expression).') + parser.add_argument( + '--includefile', + type=re.compile, + metavar='REGEX', + help='Filenames that must match to include' + ' (regular expression).') parser.add_argument( '--also', type=search.default_compile_regex, @@ -619,7 +655,8 @@ def _add_rewriter_arguments(parser): action='store_true', dest='print_filename', help='Print the filename in output' - ' (true by default, but disabled by --no-filename).',) + ' (true by default, but disabled by --no-filename).', + ) dry_run_arguments = parser.add_mutually_exclusive_group() dry_run_arguments.add_argument( '--dry-run', @@ -627,12 +664,13 @@ def _add_rewriter_arguments(parser): const=False, dest='in_place', help="Don't write anything to disk. 
(The default)") - dry_run_arguments.add_argument('--in-place', - '-i', - action='store_const', - const=True, - dest='in_place', - help='Write changes back to disk.') + dry_run_arguments.add_argument( + '--in-place', + '-i', + action='store_const', + const=True, + dest='in_place', + help='Write changes back to disk.') debug_options.add_argument( '--profile-to', @@ -708,8 +746,8 @@ def _parse_options(argv, parser): options, args = _parse_args_leftovers(parser, argv) options.files = [] if options.pattern_or_file is not None: - if (len(options.search_replace) == 1 - and options.search_replace[0].match is None): + if (len(options.search_replace) == 1 and + options.search_replace[0].match is None): options.search_replace[0].match = options.pattern_or_file else: options.files.append(options.pattern_or_file) @@ -831,9 +869,7 @@ def argument_parser(version): ) parser.set_defaults( - rewriter=None, - **{search_replace_dest: [_SearchReplaceArgument()]} - ) + rewriter=None, **{search_replace_dest: [_SearchReplaceArgument()]}) return parser