diff --git a/src/scripts/merge_tools/resolve-import-conflicts b/src/scripts/merge_tools/resolve-import-conflicts index feaee8b83c..aa78cd5231 100755 --- a/src/scripts/merge_tools/resolve-import-conflicts +++ b/src/scripts/merge_tools/resolve-import-conflicts @@ -19,7 +19,7 @@ SCRIPTDIR="$(cd "$(dirname "$0")" && pwd -P)" status=0 for file in "${files[@]}" ; do - if ! "${SCRIPTDIR}"/resolve-import-conflicts-in-file.py "$file" ; then + if ! "${SCRIPTDIR}"/resolve-import-conflicts-in-file.py --file "$file" ; then status=1 fi done diff --git a/src/scripts/merge_tools/resolve-import-conflicts-in-file.py b/src/scripts/merge_tools/resolve-import-conflicts-in-file.py index 36f2750419..319e4fb35f 100755 --- a/src/scripts/merge_tools/resolve-import-conflicts-in-file.py +++ b/src/scripts/merge_tools/resolve-import-conflicts-in-file.py @@ -8,23 +8,12 @@ # TODO: merge both scripts into one. -import os import shutil -import sys +import argparse +from pathlib import Path import tempfile from typing import List, Union, Tuple -if len(sys.argv) != 2: - print( - "resolve-import-conflicts-in-file: Provide exactly one command-line argument." - ) - sys.exit(1) - -filename = sys.argv[1] -with open(filename) as file: - lines = file.readlines() - - def all_import_lines(lines: List[str]) -> bool: """Return true if every line is a Java import line.""" return all(line.startswith("import ") for line in lines) @@ -55,7 +44,7 @@ def merge(base, parent1: List[str], parent2: List[str]) -> Union[List[str], None def looking_at_conflict( # pylint: disable=too-many-return-statements - start_index: int, lines: List[str] + file: Path, start_index: int, lines: List[str] ) -> Union[None, Tuple[List[str], List[str], List[str], int]]: """Tests whether the text starting at line `start_index` is the beginning of a conflict. If not, returns None. @@ -63,7 +52,7 @@ def looking_at_conflict( # pylint: disable=too-many-return-statements where the first 3 elements of the tuple are lists of lines. Args: start_index: an index into `lines`. - lines: all the lines of the file with name `filename`. + lines: all the lines of the file with name `file`. """ if not lines[start_index].startswith("<<<<<<<"): @@ -87,13 +76,13 @@ def looking_at_conflict( # pylint: disable=too-many-return-statements "Starting at line " + str(start_index) + ", did not find ||||||| or ======= in " - + filename + + str(file) ) return None if lines[index].startswith("|||||||"): index = index + 1 if index == num_lines: - print("File ends with |||||||: " + filename) + print("File ends with |||||||: " + str(file)) return None while not lines[index].startswith("======="): base.append(lines[index]) @@ -103,13 +92,13 @@ def looking_at_conflict( # pylint: disable=too-many-return-statements "Starting at line " + str(start_index) + ", did not find ======= in " - + filename + + str(file) ) return None assert lines[index].startswith("=======") index = index + 1 # skip over "=======" line if index == num_lines: - print("File ends with =======: " + filename) + print("File ends with =======: " + str(file)) return None while not lines[index].startswith(">>>>>>>"): parent2.append(lines[index]) @@ -119,7 +108,7 @@ def looking_at_conflict( # pylint: disable=too-many-return-statements "Starting at line " + str(start_index) + ", did not find >>>>>>> in " - + filename + + str(file) ) return None index = index + 1 @@ -127,27 +116,32 @@ def looking_at_conflict( # pylint: disable=too-many-return-statements return (base, parent1, parent2, index - start_index) -## Main starts here. - -with tempfile.NamedTemporaryFile(mode="w", delete=False) as tmp: - file_len = len(lines) - i = 0 - while i < file_len: - conflict = looking_at_conflict(i, lines) - if conflict is None: - tmp.write(lines[i]) - i = i + 1 - else: - (base, parent1, parent2, num_lines) = conflict - merged = merge(base, parent1, parent2) - if merged is None: +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--file", type=Path) + args = parser.parse_args() + + with open(args.file) as file: + lines = file.readlines() + + with tempfile.NamedTemporaryFile(mode="w", delete=True) as tmp: + file_len = len(lines) + i = 0 + while i < file_len: + conflict = looking_at_conflict(Path(args.file), i, lines) + if conflict is None: tmp.write(lines[i]) i = i + 1 else: - for line in merged: - tmp.write(line) - i = i + num_lines - - tmp.close() - shutil.copy(tmp.name, filename) - os.unlink(tmp.name) + (base, parent1, parent2, num_lines) = conflict + merged = merge(base, parent1, parent2) + if merged is None: + tmp.write(lines[i]) + i = i + 1 + else: + for line in merged: + tmp.write(line) + i = i + num_lines + + # Copying the file back to the original location + shutil.copyfile(tmp.name, args.file)