diff --git a/bibtexautocomplete/core/main.py b/bibtexautocomplete/core/main.py index e3e6c29..ba7fdf8 100644 --- a/bibtexautocomplete/core/main.py +++ b/bibtexautocomplete/core/main.py @@ -2,7 +2,7 @@ from pathlib import Path from sys import stdout from tempfile import mkstemp -from typing import Any, Callable, Container, List, NoReturn, Optional, Set +from typing import Any, Callable, Container, List, Optional, Set from bibtexparser.bibdatabase import UndefinedString @@ -42,20 +42,23 @@ pass -def conflict(parser: MyParser, prefix: str, option1: str, option2: str) -> NoReturn: - parser.error( - "{StBold}Conflicting options:\n{Reset}" - + " Specified both " - + prefix - + "{FgYellow}" - + option1 - + "{Reset} and a {FgYellow}" - + option2 - + "{Reset} option." - ) +def conflict(parser: MyParser, prefix: str, option1: str, option2: str) -> int: + try: + parser.error( + "{StBold}Conflicting options:\n{Reset}" + + " Specified both " + + prefix + + "{FgYellow}" + + option1 + + "{Reset} and a {FgYellow}" + + option2 + + "{Reset} option." + ) + except ValueError: + return 2 -def main(argv: Optional[List[str]] = None) -> None: +def main(argv: Optional[List[str]] = None) -> int: """The main function of bibtexautocomplete Takes an argv like List as argument, if none, uses sys.argv @@ -63,10 +66,13 @@ def main(argv: Optional[List[str]] = None) -> None: parser = make_parser() if parser_autocomplete is not None: parser_autocomplete(parser) - if argv is None: - args = parser.parse_args() - else: - args = parser.parse_args(argv) + try: + if argv is None: + args = parser.parse_args() + else: + args = parser.parse_args(argv) + except ValueError: + return 2 ANSICodes.use_ansi = stdout.isatty() and not args.no_color @@ -85,14 +91,14 @@ def main(argv: Optional[List[str]] = None) -> None: PREFIX=FIELD_PREFIX, ) ) - return + return 0 if args.version: print( "{NAME} version {VERSION} ({VERSION_DATE})".format( NAME=SCRIPT_NAME, VERSION=VERSION_STR, VERSION_DATE=VERSION_DATE ) ) - return + return 0 if args.silent: args.verbose = -args.silent @@ -111,7 +117,7 @@ def main(argv: Optional[List[str]] = None) -> None: lookups = OnlyExclude[str].from_nonempty(args.only_query, args.dont_query).filter(LOOKUPS, lambda x: x.name) if args.only_query != [] and args.dont_query != []: - conflict(parser, "a ", "-q/--only-query", "-Q/--dont-query") + return conflict(parser, "a ", "-q/--only-query", "-Q/--dont-query") if args.only_query != []: # remove duplicate from list args.only_query, dups = list_unduplicate(args.only_query) @@ -122,11 +128,11 @@ def main(argv: Optional[List[str]] = None) -> None: fields = OnlyExclude[FieldType].from_nonempty(args.only_complete, args.dont_complete) if args.only_complete != [] and args.dont_complete != []: - conflict(parser, "a ", "-c/--only-complete", "-C/--dont-complete") + return conflict(parser, "a ", "-c/--only-complete", "-C/--dont-complete") entries = OnlyExclude[str].from_nonempty(args.only_entry, args.exclude_entry) if args.only_entry != [] and args.exclude_entry != []: - conflict(parser, "a ", "-e/--only-entry", "-E/--exclude-entry") + return conflict(parser, "a ", "-e/--only-entry", "-E/--exclude-entry") if args.protect_all_uppercase: fields_to_protect_uppercase: Container[str] = FieldNamesSet @@ -135,11 +141,11 @@ def main(argv: Optional[List[str]] = None) -> None: fields_to_protect_proto.default = False fields_to_protect_uppercase = fields_to_protect_proto if args.protect_all_uppercase and args.protect_uppercase != []: - conflict(parser, "", "--fpa/--protect-all-uppercase", "--fp/--protect-uppercase") + return conflict(parser, "", "--fpa/--protect-all-uppercase", "--fp/--protect-uppercase") if args.protect_all_uppercase and args.dont_protect_uppercase != []: - conflict(parser, "", "--fpa/--protect-all-uppercase", "--FP/--dont-protect-uppercase") + return conflict(parser, "", "--fpa/--protect-all-uppercase", "--FP/--dont-protect-uppercase") if args.protect_uppercase != [] and args.dont_protect_uppercase != []: - conflict(parser, "a ", "--fp/--protect-uppercase", "--FP/--dont-protect-uppercase") + return conflict(parser, "a ", "--fp/--protect-uppercase", "--FP/--dont-protect-uppercase") if args.force_overwrite: fields_to_overwrite: Set[FieldType] = FieldNamesSet @@ -148,20 +154,23 @@ def main(argv: Optional[List[str]] = None) -> None: overwrite.default = False fields_to_overwrite = set(overwrite.filter(FieldNamesSet, lambda x: x)) if args.force_overwrite and args.overwrite != []: - conflict(parser, "", "-f/--force-overwrite", "-w/--overwrite") + return conflict(parser, "", "-f/--force-overwrite", "-w/--overwrite") if args.force_overwrite and args.dont_overwrite != []: - conflict(parser, "", "-f/--force-overwrite", "-W/--dont-overwrite") + return conflict(parser, "", "-f/--force-overwrite", "-W/--dont-overwrite") if args.overwrite != [] and args.dont_overwrite != []: - conflict(parser, "a ", "-w/--overwrite", "-W/--dont-overwrite") + return conflict(parser, "a ", "-w/--overwrite", "-W/--dont-overwrite") if args.diff and args.inplace: - parser.error( - "Cannot use {FgYellow}-D/--diff{Reset} flag and {FgYellow}-i/--inplace{Reset} flag " - "simultaneously, as there\n" - " is a big risk of deleting data.\n" - " If that is truly what you want to do, specify the output file explictly\n" - " with {FgYellow}-o / --output {FgGreen}{Reset}." - ) + try: + parser.error( + "Cannot use {FgYellow}-D/--diff{Reset} flag and {FgYellow}-i/--inplace{Reset} flag " + "simultaneously, as there\n" + " is a big risk of deleting data.\n" + " If that is truly what you want to do, specify the output file explictly\n" + " with {FgYellow}-o / --output {FgGreen}{Reset}." + ) + except ValueError: + return 2 try: completer = BibtexAutocomplete( @@ -200,7 +209,7 @@ def main(argv: Optional[List[str]] = None) -> None: logger.warn("Interrupted") if completer.position == 0: logger.info("No entries were completed") - return None + return 5 _, tempfile = mkstemp(suffix=".btac.bib", prefix="btac-interrupt-", text=True) logger.header("Dumping data") with open(tempfile, "w") as file: @@ -226,12 +235,14 @@ def main(argv: Optional[List[str]] = None) -> None: if i == completer.position: logger.info("Only completed entries up to and including '{}'.\n".format(entry.get("ID", ""))) break_next = True - + return 5 except KeyboardInterrupt: logger.warn("Interrupted x2") + return 7 except ValueError: - exit(2) + return 2 except UndefinedString: - exit(1) + return 1 except (IOError, UnicodeDecodeError): - exit(1) + return 1 + return 0 diff --git a/bibtexautocomplete/core/parser.py b/bibtexautocomplete/core/parser.py index f9d0d1c..1fdd57b 100644 --- a/bibtexautocomplete/core/parser.py +++ b/bibtexautocomplete/core/parser.py @@ -10,6 +10,7 @@ from typing import IO, Iterable, List, NoReturn, Optional, TypeVar from ..bibtex.constants import FieldNamesSet +from ..utils.ansi import ANSICodes from ..utils.constants import BTAC_FILENAME, CONNECTION_TIMEOUT, SCRIPT_NAME from ..utils.logger import logger from .apis import LOOKUP_NAMES @@ -116,7 +117,7 @@ def print_usage(self, file: Optional[IO[str]] = None) -> None: def error(self, message: str) -> NoReturn: logger.critical(message + "\n", error="Invalid command line", NAME=SCRIPT_NAME) self.print_usage(stderr) - raise ValueError(message) + raise ValueError(message.format(**ANSICodes.EmptyCodes)) def make_parser() -> MyParser: diff --git a/tests/test_6_main.py b/tests/test_6_main.py index 7a8ebe0..418d4f4 100644 --- a/tests/test_6_main.py +++ b/tests/test_6_main.py @@ -546,7 +546,7 @@ def get_value(self, res: SafeJSON) -> BibtexEntry: @pytest.mark.parametrize(("argv", "files_to_compare"), tests) def test_main(argv: List[str], files_to_compare: List[Tuple[str, str]]) -> None: - main(argv) + assert main(argv) == 0 FakeLookup.count = 0 day = datetime.today().strftime("%Y-%m-%d") for expected, generated in files_to_compare: @@ -637,10 +637,7 @@ def test_main(argv: List[str], files_to_compare: List[Tuple[str, str]]) -> None: @pytest.mark.parametrize(("argv", "exit_code"), exit_tests) def test_main_exit(argv: List[str], exit_code: int) -> None: - with pytest.raises(SystemExit) as test_exit: - main(argv) - assert test_exit.type is SystemExit - assert test_exit.value.code == exit_code + assert main(argv) == exit_code def test_promote() -> None: