From d7fd541d5d4017f59851c985e8967545602c157e Mon Sep 17 00:00:00 2001 From: Ivo Houbrechts Date: Wed, 14 Feb 2024 23:08:07 +0100 Subject: [PATCH] use pw to auto-install tools and fix type hints (#91) * use pw to auto-install tools and fix type hints * use pytest as testrunner * use ruff as formatter and linter + fix trivial lint errors * replace setuptools with flit and add build action * remove mypy type checking from pre-commit hook because it take too long * bumped version --------- Co-authored-by: ihoubr --- .github/workflows/actions.yml | 38 -- .github/workflows/build.yml | 38 ++ .gitignore | 5 +- .pre-commit-config.yaml | 31 +- CHANGELOG | 6 + README.rst | 31 +- pw | 202 ++++++++ pw.bat | 2 + pw.lock | 4 + pyproject.toml | 61 ++- setup.py | 31 -- src/pybreaker/__init__.py | 434 +++++------------ tests/__init__.py | 0 src/tests.py => tests/pybreaker_test.py | 619 ++++++++++++------------ tests/typechecks.py | 10 + tox.ini | 13 - 16 files changed, 803 insertions(+), 722 deletions(-) delete mode 100644 .github/workflows/actions.yml create mode 100644 .github/workflows/build.yml create mode 100755 pw create mode 100644 pw.bat create mode 100644 pw.lock delete mode 100755 setup.py create mode 100644 tests/__init__.py rename src/tests.py => tests/pybreaker_test.py (61%) create mode 100755 tests/typechecks.py delete mode 100644 tox.ini diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml deleted file mode 100644 index b6be8db..0000000 --- a/.github/workflows/actions.yml +++ /dev/null @@ -1,38 +0,0 @@ -name: pybreaker test matrix - -on: - push: - branches: [main] - pull_request: - -jobs: - test: - runs-on: ubuntu-latest - strategy: - fail-fast: false - matrix: - python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"] - steps: - - uses: actions/checkout@v3 - - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 - with: - python-version: ${{ matrix.python-version }} - - - name: Upgrade pip version - run: | - python -m pip install -U pip - - - name: Install wheel - run: | - python -m pip install wheel - - - name: Python versions - run: | - echo "Python ${{ matrix.python-version }}" - python --version - - - name: Run tests - run: | - python setup.py test diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml new file mode 100644 index 0000000..fed62fc --- /dev/null +++ b/.github/workflows/build.yml @@ -0,0 +1,38 @@ +name: pybreaker test matrix + +on: + push: + branches: [main] + pull_request: + workflow_dispatch: + +jobs: + build: + name: Build Python ${{ matrix.python-version }} + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: [ "3.8", "3.9", "3.10", "3.11", "3.12" ] + + steps: + - uses: actions/checkout@v3 + + - name: Cache pyprojectx + uses: actions/cache@v4 + with: + key: ${{ hashFiles('pyproject.toml') }}-${{ matrix.python-version }}-venvs + path: | + .pyprojectx + venv + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + + - name: Test and build + run: | + ./pw lint + ./pw type-check + ./pw test diff --git a/.gitignore b/.gitignore index 8fdfda9..aadaa65 100644 --- a/.gitignore +++ b/.gitignore @@ -4,13 +4,16 @@ # tox .tox +.python-version + # Build directories build/ dist/ *.egg-info/ .eggs/ +.pyprojectx/ -# Pycharm directories +# Pycharm directories .idea .ropeproject/ *.iml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 95f3968..85416b7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,16 +1,25 @@ repos: - - repo: https://github.com/psf/black - rev: 22.6.0 + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.5.0 hooks: - - id: black - - repo: https://github.com/pycqa/isort - rev: 5.10.1 + - id: trailing-whitespace + - id: end-of-file-fixer + - id: fix-byte-order-marker + + - repo: local hooks: - - id: isort - - repo: https://github.com/pre-commit/mirrors-mypy - rev: v0.971 + - id: format + name: format + entry: pw format + language: script + types: [ python ] + pass_filenames: false + + - repo: local hooks: - - id: mypy + - id: lint + name: lint + entry: pw lint + language: script + types: [ python ] pass_filenames: false - additional_dependencies: - - "types-redis" diff --git a/CHANGELOG b/CHANGELOG index 57839da..bc3d11c 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -1,6 +1,12 @@ Changelog ========= +Version 1.2.0 (February ??, 2024) + +* Fixed type hints +* Improved building and testing +* Dropped support for Python 3.7 + Version 1.1.0 (January 09, 2024) * Added calling() method to CircuitBreaker, returning a context manager (Thanks @martijnthe) diff --git a/README.rst b/README.rst index f2372fe..47f9ef3 100644 --- a/README.rst +++ b/README.rst @@ -27,7 +27,7 @@ Features Requirements ------------ -* `Python`_ 3.7+ +* `Python`_ 3.8+ Installation @@ -38,13 +38,14 @@ PyBreaker from `PyPI`_:: $ pip install pybreaker -If you are a `Git`_ user, you might want to download the current development -version:: +If you are a `Git`_ user, you might want to install the current development +version in editable mode:: $ git clone git://github.com/danielfm/pybreaker.git $ cd pybreaker - $ python setup.py test - $ python setup.py install + $ # run tests (on windows omit ./) + $ ./pw test + $ pip install -e . Usage @@ -108,7 +109,7 @@ fail with: .. note:: - If you require multiple, independent CircuitBreakers and wish to store their states in Redis, it is essential to assign a ``unique namespace`` for each + If you require multiple, independent CircuitBreakers and wish to store their states in Redis, it is essential to assign a ``unique namespace`` for each CircuitBreaker instance. This can be achieved by specifying a distinct namespace parameter in the CircuitRedisStorage constructor. for example: .. code:: python @@ -318,6 +319,24 @@ change its current state: These properties and functions might and should be exposed to the operations staff somehow as they help them to detect problems in the system. +Contributing +------------- + +Run tests:: + + $ ./pw test + +Code formatting (black and isort) and linting (mypy) :: + + $ ./pw format + $ ./pw lint + +Above commands will automatically install the necessary tools inside *.pyprojectx* +and also install pre-commit hooks. + +List available commands:: + + $ ./pw -i .. _Python: http://python.org .. _Jython: http://jython.org diff --git a/pw b/pw new file mode 100755 index 0000000..7fcc088 --- /dev/null +++ b/pw @@ -0,0 +1,202 @@ +#!/usr/bin/env python3 + +################################################################################## +# Pyprojectx wrapper script # +# https://github.com/pyprojectx/pyprojectx # +# # +# Copyright (c) 2021 Ivo Houbrechts # +# # +# Licensed under the MIT license # +################################################################################## +import argparse +import os +import subprocess +import sys +from pathlib import Path +from venv import EnvBuilder + +VERSION = "2.0.8" + +PYPROJECTX_INSTALL_DIR_ENV_VAR = "PYPROJECTX_INSTALL_DIR" +PYPROJECTX_PACKAGE_ENV_VAR = "PYPROJECTX_PACKAGE" +PYPROJECT_TOML = "pyproject.toml" +DEFAULT_INSTALL_DIR = ".pyprojectx" + +CYAN = "\033[96m" +BLUE = "\033[94m" +RED = "\033[91m" +RESET = "\033[0m" +if sys.platform.startswith("win"): + os.system("color") + + +def run(args): + try: + options = get_options(args) + pyprojectx_script = ensure_pyprojectx(options) + explicit_options = [] + if not options.toml: + explicit_options += ["--toml", str(options.toml_path)] + if not options.install_dir: + explicit_options += ["--install-dir", str(options.install_path)] + + subprocess.run([str(pyprojectx_script), *explicit_options, *args], check=True) + except subprocess.CalledProcessError as e: + raise SystemExit(e.returncode) from e + + +def get_options(args): + options = arg_parser().parse_args(args) + options.install_path = Path( + options.install_dir + or os.environ.get( + PYPROJECTX_INSTALL_DIR_ENV_VAR, + Path(__file__).with_name(DEFAULT_INSTALL_DIR), + ) + ) + options.toml_path = ( + Path(options.toml) if options.toml else Path(__file__).with_name(PYPROJECT_TOML) + ) + if os.environ.get(PYPROJECTX_PACKAGE_ENV_VAR): + options.version = "development" + options.pyprojectx_package = os.environ.get(PYPROJECTX_PACKAGE_ENV_VAR) + else: + options.version = VERSION + options.pyprojectx_package = f"pyprojectx~={VERSION}" + options.verbosity = ( + 0 if options.quiet or not options.verbosity else options.verbosity + ) + return options + + +def arg_parser(): + parser = argparse.ArgumentParser( + description="Execute commands or aliases defined in the [tool.pyprojectx] section of pyproject.toml. " + "Use the -i or --info option to see available tools and aliases.", + allow_abbrev=False, + ) + parser.add_argument("--version", action="version", version=VERSION) + parser.add_argument( + "--toml", + "-t", + action="store", + help="The toml config file. Defaults to 'pyproject.toml' in the same directory as the pw script.", + ) + parser.add_argument( + "--install-dir", + action="store", + help=f"The directory where all tools (including pyprojectx) are installed; defaults to the " + f"{PYPROJECTX_INSTALL_DIR_ENV_VAR} environment value if set, else '.pyprojectx' " + f"in the same directory as the invoked pw script", + ) + parser.add_argument( + "--force-install", + "-f", + action="store_true", + help="Force clean installation of the virtual environment used to run cmd, if any", + ) + parser.add_argument( + "--install-context", + action="store", + metavar="tool-context", + help="Install a tool context without actually running any command.", + ) + parser.add_argument( + "--verbose", + "-v", + action="count", + dest="verbosity", + help="Give more output. This option is additive and can be used up to 2 times.", + ) + parser.add_argument( + "--quiet", + "-q", + action="store_true", + help="Suppress output", + ) + parser.add_argument( + "--info", + "-i", + action="store_true", + help="Show the configuration details of a command instead of running it. " + "If no command is specified, a list with all available tools and aliases is shown.", + ) + parser.add_argument( + "--add", + action="store", + metavar="[context:],...", + help="Add one or more packages to a tool context. " + "If no context is specified, the packages are added to the main context. " + "Packages can be specified as in 'pip install', except that a ',' can't be used in the version specification.", + ) + parser.add_argument( + "--lock", + action="store_true", + help="Write all dependencies of all tool contexts to 'pw.lock' to guarantee reproducible outcomes.", + ) + parser.add_argument( + "--install-px", + action="store_true", + help="Install the px and pxg scripts in your home directory.", + ) + parser.add_argument( + "--upgrade", + action="store_true", + help="Print instructions to download the latest pyprojectx wrapper scripts.", + ) + parser.add_argument( + "command", + nargs=argparse.REMAINDER, + help="The command/alias with optional arguments to execute.", + ) + return parser + + +def ensure_pyprojectx(options): + env_builder = EnvBuilder(with_pip=True) + venv_dir = options.install_path.joinpath( + "pyprojectx", + f"{options.version}-py{sys.version_info.major}.{sys.version_info.minor}", + ) + env_context = env_builder.ensure_directories(venv_dir) + pyprojectx_script = Path(env_context.bin_path, "pyprojectx") + pyprojectx_exe = Path(env_context.bin_path, "pyprojectx.exe") + pip_cmd = [env_context.env_exe, "-m", "pip", "install"] + + if options.quiet: + out = subprocess.DEVNULL + pip_cmd.append("--quiet") + else: + out = sys.stderr + + if not pyprojectx_script.is_file() and not pyprojectx_exe.is_file(): + if not options.quiet: + print( + f"{CYAN}creating pyprojectx venv in {BLUE}{venv_dir}{RESET}", + file=sys.stderr, + ) + env_builder.create(venv_dir) + subprocess.run( + [*pip_cmd, "--upgrade", "pip"], + stdout=out, + check=True, + ) + + if not options.quiet: + print( + f"{CYAN}installing pyprojectx {BLUE}{options.version}: {options.pyprojectx_package} {RESET}", + file=sys.stderr, + ) + if options.version == "development": + if not options.quiet: + print( + f"{RED}WARNING: {options.pyprojectx_package} is installed in editable mode{RESET}", + file=sys.stderr, + ) + pip_cmd.append("-e") + subprocess.run([*pip_cmd, options.pyprojectx_package], stdout=out, check=True) + return pyprojectx_script + + +if __name__ == "__main__": + run(sys.argv[1:]) diff --git a/pw.bat b/pw.bat new file mode 100644 index 0000000..5a81b08 --- /dev/null +++ b/pw.bat @@ -0,0 +1,2 @@ +@echo off +python "%~dp0pw" %* diff --git a/pw.lock b/pw.lock new file mode 100644 index 0000000..8175553 --- /dev/null +++ b/pw.lock @@ -0,0 +1,4 @@ +[main] +requirements = ["PyYAML==6.0.1", "certifi==2024.2.2", "cfgv==3.4.0", "charset-normalizer==3.3.2", "distlib==0.3.8", "docutils==0.20.1", "filelock==3.13.1", "flit==3.9.0", "flit_core==3.9.0", "identify==2.5.33", "idna==3.6", "mypy-extensions==1.0.0", "mypy==1.8.0", "nodeenv==1.8.0", "platformdirs==4.2.0", "pre-commit==3.5.0", "px-utils==1.1.0", "requests==2.31.0", "ruff==0.2.0", "tomli==2.0.1", "tomli_w==1.0.0", "typing_extensions==4.9.0", "urllib3==2.2.0", "virtualenv==20.25.0"] +hash = "c05140614e209ba7f91f796ccdc08e52" +post-install = "pre-commit install" diff --git a/pyproject.toml b/pyproject.toml index 2f35c73..c2dd593 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,9 +1,64 @@ +[project] +name = "pybreaker" +version = "1.2.0" +description = "Python implementation of the Circuit Breaker pattern" +authors = [ + { name = "Daniel Fernandes Martins", email = "daniel.tritone@gmail.com" }, +] +readme = "README.rst" +urls = { Source = "http://github.com/danielfm/pybreaker" } +license = { file = "LICENSE" } +keywords = ["design", "pattern", "circuit", "breaker", "integration"] +classifiers = [ + "Development Status :: 5 - Production/Stable", + "Intended Audience :: Developers", + "License :: OSI Approved :: BSD License", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Topic :: Software Development :: Libraries", +] +requires-python = ">=3.8" + +[project.optional-dependencies] +test = ["pytest", "mock", "tornado", "redis", "fakeredis", "types-mock", "types-redis"] + +[build-system] +requires = ["flit_core >=3.9"] +build-backend = "flit_core.buildapi" + +[tool.flit.sdist] +exclude = [".github", "tests", ".gitignore", ".pre-commit-config.yaml", "pw*"] + +[tool.pyprojectx] +[tool.pyprojectx.main] +requirements = ["pre-commit<=3.6", "ruff", "mypy", "px-utils", "flit"] +post-install = "pre-commit install" + +[tool.pyprojectx.venv] +dir = "@PROJECT_DIR/venv" +requirements = ["pytest", "mock", "tornado", "redis", "fakeredis", "types-mock", "types-redis", "-e ."] + +[tool.pyprojectx.aliases] +install = "pw@ --install-context venv" +format = ["ruff format", "ruff check --select I --fix"] +lint = ["@install", "ruff check src"] +type-check = ["@install", "mypy --python-executable venv/bin/python --no-incremental"] +fix-ruff = "ruff check --fix" +test = { cmd = "pytest", ctx = "venv" } +clean = "pxrm .eggs venv src/pybreaker.egg-info" + [tool.mypy] disallow_untyped_defs = true ignore_missing_imports = true no_implicit_optional = true show_error_codes = true -files = ["src/pybreaker.py"] +files = ["src", "tests/typechecks.py"] -[tool.isort] -profile = "black" +[tool.ruff] +line-length = 120 +[tool.ruff.lint] +select = ["ALL"] +fixable = ["ALL"] +ignore = ["ISC001", "ANN", "FA", "FBT", "D100", "D102", "D205", "D103", "D104", "D105", "D213", "D203", "T201", "TRY003", "EM102", "COM812", "S602", "S603", "S604", "S605", "S607", "S324"] +[tool.ruff.lint.per-file-ignores] +"src/*" = ["PLR0913", "TCH002", "D401", "D404", "SLF001", "BLE001", "TRY301", "DTZ001", "DTZ003", "SIM102", "TRY300"] diff --git a/setup.py b/setup.py deleted file mode 100755 index 139c3f3..0000000 --- a/setup.py +++ /dev/null @@ -1,31 +0,0 @@ -from setuptools import setup - -setup( - name="pybreaker", - version="1.1.0", - description="Python implementation of the Circuit Breaker pattern", - long_description=open("README.rst", "r").read(), - keywords=["design", "pattern", "circuit", "breaker", "integration"], - classifiers=[ - "Development Status :: 5 - Production/Stable", - "Intended Audience :: Developers", - "License :: OSI Approved :: BSD License", - "Programming Language :: Python", - "Programming Language :: Python :: 3", - "Topic :: Software Development :: Libraries", - ], - platforms=["Any"], - license="BSD", - author="Daniel Fernandes Martins", - author_email="daniel.tritone@gmail.com", - url="http://github.com/danielfm/pybreaker", - package_data={"pybreaker": ["py.typed"]}, - package_dir={"": "src"}, - packages=["pybreaker"], - include_package_data=True, - install_requires=["typing_extensions>=3.10.0; python_version < '3.8'"], - zip_safe=False, - python_requires=">=3.7", - test_suite="tests", - tests_require=["mock", "fakeredis==2.14.1", "redis==4.5.5", "tornado"], -) diff --git a/src/pybreaker/__init__.py b/src/pybreaker/__init__.py index 3d8a199..fe68e52 100644 --- a/src/pybreaker/__init__.py +++ b/src/pybreaker/__init__.py @@ -1,10 +1,10 @@ -""" -Threadsafe pure-Python implementation of the Circuit Breaker pattern, described +"""Threadsafe pure-Python implementation of the Circuit Breaker pattern, described by Michael T. Nygard in his book 'Release It!'. For more information on this and other patterns and best practices, buy the book at https://pragprog.com/titles/mnee2/release-it-second-edition/ """ + from __future__ import annotations import calendar @@ -20,22 +20,17 @@ from typing import ( Any, Callable, + Generator, + Iterable, + Literal, NoReturn, Sequence, - Tuple, - Type, TypeVar, Union, cast, overload, - Iterable, ) -if sys.version_info >= (3, 8): - from typing import Literal -else: - from typing_extensions import Literal - try: from tornado import gen @@ -74,8 +69,7 @@ class CircuitBreaker: - """ - More abstractly, circuit breakers exists to allow one subsystem to fail + """More abstractly, circuit breakers exists to allow one subsystem to fail without destroying the entire system. This is done by wrapping dangerous operations (typically integration points) @@ -88,15 +82,13 @@ def __init__( self, fail_max: int = 5, reset_timeout: float = 60, - exclude: Sequence[Type[ExceptionType]] | None = None, + exclude: Iterable[type[ExceptionType] | Callable[[Any], bool]] | None = None, listeners: Sequence[CBListenerType] | None = None, - state_storage: "CircuitBreakerStorage" | None = None, + state_storage: CircuitBreakerStorage | None = None, name: str | None = None, throw_new_error_on_trip: bool = True, ) -> None: - """ - Creates a new circuit breaker with the given parameters. - """ + """Create a new circuit breaker with the given parameters.""" self._lock = threading.RLock() self._state_storage = state_storage or CircuitMemoryStorage(STATE_CLOSED) self._state = self._create_new_state(self.current_state) @@ -112,41 +104,29 @@ def __init__( @property def fail_counter(self) -> int: - """ - Returns the current number of consecutive failures. - """ + """Return the current number of consecutive failures.""" return self._state_storage.counter @property def fail_max(self) -> int: - """ - Returns the maximum number of failures tolerated before the circuit is - opened. - """ + """Return the maximum number of failures tolerated before the circuit is opened.""" return self._fail_max @fail_max.setter def fail_max(self, number: int) -> None: - """ - Sets the maximum `number` of failures tolerated before the circuit is - opened. - """ + """Set the maximum `number` of failures tolerated before the circuit is opened.""" self._fail_max = number @property def reset_timeout(self) -> float: - """ - Once this circuit breaker is opened, it should remain opened until the + """Once this circuit breaker is opened, it should remain opened until the timeout period, in seconds, elapses. """ return self._reset_timeout @reset_timeout.setter def reset_timeout(self, timeout: float) -> None: - """ - Sets the `timeout` period, in seconds, this circuit breaker should be - kept open. - """ + """Set the `timeout` period, in seconds, this circuit breaker should be kept open.""" self._reset_timeout = timeout def _create_new_state( @@ -155,11 +135,8 @@ def _create_new_state( prev_state: CircuitBreakerState | None = None, notify: bool = False, ) -> CBStateType: - """ - Return state object from state string, i.e., - 'closed' -> - """ - state_map: dict[str, Type[CBStateType]] = { + """Return state object from state string, i.e., 'closed' -> .""" + state_map: dict[str, type[CBStateType]] = { STATE_CLOSED: CircuitClosedState, STATE_OPEN: CircuitOpenState, STATE_HALF_OPEN: CircuitHalfOpenState, @@ -167,15 +144,13 @@ def _create_new_state( try: cls = state_map[new_state] return cls(self, prev_state=prev_state, notify=notify) - except KeyError: + except KeyError as e: msg = "Unknown state {!r}, valid states: {}" - raise ValueError(msg.format(new_state, ", ".join(state_map))) + raise ValueError(msg.format(new_state, ", ".join(state_map))) from e @property def state(self) -> CBStateType: - """ - Update (if needed) and returns the cached state object. - """ + """Update (if needed) and returns the cached state object.""" # Ensure cached state is up-to-date if self.current_state != self._state.name: # If cached state is out-of-date, that means that it was likely @@ -187,60 +162,47 @@ def state(self) -> CBStateType: @state.setter def state(self, state_str: str) -> None: - """ - Set cached state and notify listeners of newly cached state. - """ + """Set cached state and notify listeners of newly cached state.""" with self._lock: - self._state = self._create_new_state( - state_str, prev_state=self._state, notify=True - ) + self._state = self._create_new_state(state_str, prev_state=self._state, notify=True) @property def current_state(self) -> str: - """ - Returns a string that identifies the state of the circuit breaker as + """Return a string that identifies the state of the circuit breaker as reported by the _state_storage. i.e., 'closed', 'open', 'half-open'. """ return self._state_storage.state @property - def excluded_exceptions(self) -> Tuple[Type[ExceptionType], ...]: - """ - Returns the list of excluded exceptions, e.g., exceptions that should + def excluded_exceptions( + self, + ) -> tuple[type[ExceptionType] | Callable[[Any], bool], ...]: + """Return the list of excluded exceptions, e.g., exceptions that should not be considered system errors by this circuit breaker. """ return tuple(self._excluded_exceptions) - def add_excluded_exception(self, exception: Type[ExceptionType]) -> None: - """ - Adds an exception to the list of excluded exceptions. - """ + def add_excluded_exception(self, exception: type[ExceptionType]) -> None: + """Add an exception to the list of excluded exceptions.""" with self._lock: self._excluded_exceptions.append(exception) - def add_excluded_exceptions(self, *exceptions: Type[ExceptionType]) -> None: - """ - Adds exceptions to the list of excluded exceptions. - """ + def add_excluded_exceptions(self, *exceptions: type[ExceptionType]) -> None: + """Add exceptions to the list of excluded exceptions.""" for exc in exceptions: self.add_excluded_exception(exc) - def remove_excluded_exception(self, exception: Type[ExceptionType]) -> None: - """ - Removes an exception from the list of excluded exceptions. - """ + def remove_excluded_exception(self, exception: type[ExceptionType]) -> None: + """Remove an exception from the list of excluded exceptions.""" with self._lock: self._excluded_exceptions.remove(exception) def _inc_counter(self) -> None: - """ - Increments the counter of failed calls. - """ + """Increment the counter of failed calls.""" self._state_storage.increment_counter() def is_system_error(self, exception: ExceptionType) -> bool: - """ - Returns whether the exception `exception` is considered a signal of + """Return whether the exception `exception` is considered a signal of system malfunction. Business exceptions should not cause this circuit breaker to open. """ @@ -255,29 +217,27 @@ def is_system_error(self, exception: ExceptionType) -> bool: return True def call(self, func: Callable[..., T], *args: Any, **kwargs: Any) -> T: - """ - Calls `func` with the given `args` and `kwargs` according to the rules + """Call `func` with the given `args` and `kwargs` according to the rules implemented by the current state of this circuit breaker. """ with self._lock: return self.state.call(func, *args, **kwargs) @contextlib.contextmanager - def calling(self) -> Iterable[None]: - """ - Returns a context manager, enabling the circuit breaker to be used with a + def calling(self) -> Any: + """Return a context manager, enabling the circuit breaker to be used with a `with` statement. The block of code inside the `with` statement will be executed according to the rules implemented by the current state of this circuit breaker. """ - def _wrapper() -> None: + + def _wrapper() -> Generator: yield yield from self.call(_wrapper) def call_async(self, func, *args, **kwargs): # type: ignore[no-untyped-def] - """ - Calls async `func` with the given `args` and `kwargs` according to the rules + """Call async `func` with the given `args` and `kwargs` according to the rules implemented by the current state of this circuit breaker. Return a closure to prevent import errors when using without tornado present @@ -292,10 +252,7 @@ def wrapped(): # type: ignore[no-untyped-def] return wrapped() def open(self) -> bool: - """ - Opens the circuit, e.g., the following calls will immediately fail - until timeout elapses. - """ + """Open the circuit, e.g., the following calls will immediately fail until timeout elapses.""" with self._lock: self._state_storage.opened_at = datetime.utcnow() self.state = self._state_storage.state = STATE_OPEN # type: ignore[assignment] @@ -303,8 +260,7 @@ def open(self) -> bool: return self._throw_new_error_on_trip def half_open(self) -> None: - """ - Half-opens the circuit, e.g. lets the following call pass through and + """Half-open the circuit, e.g. lets the following call pass through and opens the circuit if the call fails (or closes the circuit if the call succeeds). """ @@ -312,15 +268,12 @@ def half_open(self) -> None: self.state = self._state_storage.state = STATE_HALF_OPEN # type: ignore[assignment] def close(self) -> None: - """ - Closes the circuit, e.g. lets the following calls execute as usual. - """ + """Close the circuit, e.g. lets the following calls execute as usual.""" with self._lock: self.state = self._state_storage.state = STATE_CLOSED # type: ignore[assignment] def __call__(self, *call_args: Any, **call_kwargs: bool) -> Callable: - """ - Returns a wrapper that calls the function `func` according to the rules + """Return a wrapper that calls the function `func` according to the rules implemented by the current state of this circuit breaker. Optionally takes the keyword argument `__pybreaker_call_coroutine`, @@ -329,7 +282,8 @@ def __call__(self, *call_args: Any, **call_kwargs: bool) -> Callable: call_async = call_kwargs.pop("__pybreaker_call_async", False) if call_async and not HAS_TORNADO_SUPPORT: - raise ImportError("No module named tornado") + message = "No module named tornado" + raise ImportError(message) def _outer_wrapper(func): # type: ignore[no-untyped-def] @wraps(func) @@ -345,123 +299,86 @@ def _inner_wrapper(*args, **kwargs): # type: ignore[no-untyped-def] return _outer_wrapper @property - def listeners(self) -> Tuple[CBListenerType, ...]: - """ - Returns the registered listeners as a tuple. - """ + def listeners(self) -> tuple[CBListenerType, ...]: + """Return the registered listeners as a tuple.""" return tuple(self._listeners) # type: ignore[arg-type] def add_listener(self, listener: CBListenerType) -> None: - """ - Registers a listener for this circuit breaker. - """ + """Register a listener for this circuit breaker.""" with self._lock: self._listeners.append(listener) # type: ignore[arg-type] def add_listeners(self, *listeners: CBListenerType) -> None: - """ - Registers listeners for this circuit breaker. - """ + """Register listeners for this circuit breaker.""" for listener in listeners: self.add_listener(listener) def remove_listener(self, listener: CBListenerType) -> None: - """ - Unregisters a listener of this circuit breaker. - """ + """Unregister a listener of this circuit breaker.""" with self._lock: self._listeners.remove(listener) # type: ignore[arg-type] @property def name(self) -> str | None: - """ - Returns the name of this circuit breaker. Useful for logging. - """ + """Return the name of this circuit breaker. Useful for logging.""" return self._name @name.setter def name(self, name: str) -> None: - """ - Set the name of this circuit breaker. - """ + """Set the name of this circuit breaker.""" self._name = name class CircuitBreakerStorage: - """ - Defines the underlying storage for a circuit breaker - the underlying + """Define the underlying storage for a circuit breaker - the underlying implementation should be in a subclass that overrides the method this class defines. """ def __init__(self, name: str) -> None: - """ - Creates a new instance identified by `name`. - """ + """Create a new instance identified by `name`.""" self._name = name @property def name(self) -> str: - """ - Returns a human friendly name that identifies this state. - """ + """Return a human friendly name that identifies this state.""" return self._name @property @abstractmethod def state(self) -> str: - """ - Override this method to retrieve the current circuit breaker state. - """ + """Override this method to retrieve the current circuit breaker state.""" @state.setter def state(self, state: str) -> None: - """ - Override this method to set the current circuit breaker state. - """ + """Override this method to set the current circuit breaker state.""" def increment_counter(self) -> None: - """ - Override this method to increase the failure counter by one. - """ + """Override this method to increase the failure counter by one.""" def reset_counter(self) -> None: - """ - Override this method to set the failure counter to zero. - """ + """Override this method to set the failure counter to zero.""" @property @abstractmethod def counter(self) -> int: - """ - Override this method to retrieve the current value of the failure counter. - """ + """Override this method to retrieve the current value of the failure counter.""" @property @abstractmethod def opened_at(self) -> datetime | None: - """ - Override this method to retrieve the most recent value of when the - circuit was opened. - """ + """Override this method to retrieve the most recent value of when the circuit was opened.""" @opened_at.setter def opened_at(self, datetime: datetime) -> None: - """ - Override this method to set the most recent value of when the circuit - was opened. - """ + """Override this method to set the most recent value of when the circuit was opened.""" class CircuitMemoryStorage(CircuitBreakerStorage): - """ - Implements a `CircuitBreakerStorage` in local memory. - """ + """Implement a `CircuitBreakerStorage` in local memory.""" def __init__(self, state: str) -> None: - """ - Creates a new instance with the given `state`. - """ + """Create a new instance with the given `state`.""" super().__init__("memory") self._fail_counter = 0 self._opened_at: datetime | None = None @@ -469,57 +386,40 @@ def __init__(self, state: str) -> None: @property def state(self) -> str: - """ - Returns the current circuit breaker state. - """ + """Return the current circuit breaker state.""" return self._state @state.setter def state(self, state: str) -> None: - """ - Set the current circuit breaker state to `state`. - """ + """Set the current circuit breaker state to `state`.""" self._state = state def increment_counter(self) -> None: - """ - Increases the failure counter by one. - """ + """Increase the failure counter by one.""" self._fail_counter += 1 def reset_counter(self) -> None: - """ - Sets the failure counter to zero. - """ + """Set the failure counter to zero.""" self._fail_counter = 0 @property def counter(self) -> int: - """ - Returns the current value of the failure counter. - """ + """Return the current value of the failure counter.""" return self._fail_counter @property def opened_at(self) -> datetime | None: - """ - Returns the most recent value of when the circuit was opened. - """ + """Return the most recent value of when the circuit was opened.""" return self._opened_at @opened_at.setter def opened_at(self, datetime: datetime) -> None: - """ - Sets the most recent value of when the circuit was opened to - `datetime`. - """ + """Set the most recent value of when the circuit was opened to `datetime`.""" self._opened_at = datetime class CircuitRedisStorage(CircuitBreakerStorage): - """ - Implements a `CircuitBreakerStorage` using redis. - """ + """Implement a `CircuitBreakerStorage` using redis.""" BASE_NAMESPACE = "pybreaker" @@ -533,18 +433,15 @@ def __init__( fallback_circuit_state: str = STATE_CLOSED, cluster_mode: bool = False, ): - """ - Creates a new instance with the given `state` and `redis` object. The + """Create a new instance with the given `state` and `redis` object. The redis object should be similar to pyredis' StrictRedis class. If there are any connection issues with redis, the `fallback_circuit_state` is used to determine the state of the circuit. """ - # Module does not exist, so this feature is not available if not HAS_REDIS_SUPPORT: - raise ImportError( - "CircuitRedisStorage can only be used if the required dependencies exist" - ) + message = "CircuitRedisStorage can only be used if the required dependencies exist" + raise ImportError(message) super().__init__("redis") @@ -562,8 +459,7 @@ def _initialize_redis_state(self, state: str) -> None: @property def state(self) -> str: - """ - Returns the current circuit breaker state. + """Return the current circuit breaker state. If the circuit breaker state on Redis is missing, re-initialize it with the fallback circuit state and reset the fail counter. @@ -571,9 +467,7 @@ def state(self) -> str: try: state_bytes: bytes | None = self._redis.get(self._namespace("state")) except RedisError: - self.logger.error( - "RedisError: falling back to default circuit state", exc_info=True - ) + self.logger.exception("RedisError: falling back to default circuit state") return self._fallback_circuit_state state = self._fallback_circuit_state @@ -588,65 +482,52 @@ def state(self) -> str: @state.setter def state(self, state: str) -> None: - """ - Set the current circuit breaker state to `state`. - """ + """Set the current circuit breaker state to `state`.""" try: self._redis.set(self._namespace("state"), str(state)) except RedisError: - self.logger.error("RedisError", exc_info=True) + self.logger.exception("RedisError") def increment_counter(self) -> None: - """ - Increases the failure counter by one. - """ + """Increase the failure counter by one.""" try: self._redis.incr(self._namespace("fail_counter")) except RedisError: - self.logger.error("RedisError", exc_info=True) + self.logger.exception("RedisError") def reset_counter(self) -> None: - """ - Sets the failure counter to zero. - """ + """Set the failure counter to zero.""" try: self._redis.set(self._namespace("fail_counter"), 0) except RedisError: - self.logger.error("RedisError", exc_info=True) + self.logger.exception("RedisError") @property def counter(self) -> int: - """ - Returns the current value of the failure counter. - """ + """Return the current value of the failure counter.""" try: value = self._redis.get(self._namespace("fail_counter")) if value: return int(value) - else: - return 0 + return 0 except RedisError: - self.logger.error("RedisError: Assuming no errors", exc_info=True) + self.logger.exception("RedisError: Assuming no errors") return 0 @property def opened_at(self) -> datetime | None: - """ - Returns a datetime object of the most recent value of when the circuit - was opened. - """ + """Returns a datetime object of the most recent value of when the circuit was opened.""" try: timestamp = self._redis.get(self._namespace("opened_at")) if timestamp: return datetime(*time.gmtime(int(timestamp))[:6]) except RedisError: - self.logger.error("RedisError", exc_info=True) + self.logger.exception("RedisError") return None @opened_at.setter def opened_at(self, now: datetime) -> None: - """ - Atomically sets the most recent value of when the circuit was opened + """Atomically set the most recent value of when the circuit was opened to `now`. Stored in redis as a simple integer of unix epoch time. To avoid timezone issues between different systems, the passed in datetime should be in UTC. @@ -655,7 +536,6 @@ def opened_at(self, now: datetime) -> None: key = self._namespace("opened_at") if self._cluster_mode: - current_value = self._redis.get(key) next_value = int(calendar.timegm(now.timetuple())) @@ -674,7 +554,7 @@ def set_if_greater(pipe: Pipeline[bytes]) -> None: self._redis.transaction(set_if_greater, key) except RedisError: - self.logger.error("RedisError", exc_info=True) + self.logger.exception("RedisError") def _namespace(self, key: str) -> str: name_parts = [self.BASE_NAMESPACE, key] @@ -685,30 +565,16 @@ def _namespace(self, key: str) -> str: class CircuitBreakerListener: - """ - Listener class used to plug code to a ``CircuitBreaker`` instance when - certain events happen. - """ + """Listener class used to plug code to a ``CircuitBreaker`` instance when certain events happen.""" - def before_call( - self, cb: CircuitBreaker, func: Callable[..., T], *args: Any, **kwargs: Any - ) -> None: - """ - This callback function is called before the circuit breaker `cb` calls - `fn`. - """ + def before_call(self, cb: CircuitBreaker, func: Callable[..., T], *args: Any, **kwargs: Any) -> None: + """This callback function is called before the circuit breaker `cb` calls `fn`.""" def failure(self, cb: CircuitBreaker, exc: BaseException) -> None: - """ - This callback function is called when a function called by the circuit - breaker `cb` fails. - """ + """This callback function is called when a function called by the circuit breaker `cb` fails.""" def success(self, cb: CircuitBreaker) -> None: - """ - This callback function is called when a function called by the circuit - breaker `cb` succeeds. - """ + """This callback function is called when a function called by the circuit breaker `cb` succeeds.""" def state_change( self, @@ -716,36 +582,24 @@ def state_change( old_state: CircuitBreakerState | None, new_state: CircuitBreakerState, ) -> None: - """ - This callback function is called when the state of the circuit breaker - `cb` state changes. - """ + """This callback function is called when the state of the circuit breaker `cb` state changes.""" class CircuitBreakerState: - """ - Implements the behavior needed by all circuit breaker states. - """ + """Implement the behavior needed by all circuit breaker states.""" def __init__(self, cb: CircuitBreaker, name: str) -> None: - """ - Creates a new instance associated with the circuit breaker `cb` and - identified by `name`. - """ + """Create a new instance associated with the circuit breaker `cb` and identified by `name`.""" self._breaker: CircuitBreaker = cb self._name: str = name @property def name(self) -> str: - """ - Returns a human friendly name that identifies this state. - """ + """Return a human friendly name that identifies this state.""" return self._name @overload - def _handle_error( - self, exc: BaseException, reraise: Literal[True] = ... - ) -> NoReturn: + def _handle_error(self, exc: BaseException, reraise: Literal[True] = ...) -> NoReturn: ... @overload @@ -753,9 +607,7 @@ def _handle_error(self, exc: BaseException, reraise: Literal[False] = ...) -> No ... def _handle_error(self, exc: BaseException, reraise: bool = True) -> None: - """ - Handles a failed call to the guarded operation. - """ + """Handle a failed call to the guarded operation.""" if self._breaker.is_system_error(exc): self._breaker._inc_counter() for listener in self._breaker.listeners: @@ -768,17 +620,14 @@ def _handle_error(self, exc: BaseException, reraise: bool = True) -> None: raise exc def _handle_success(self) -> None: - """ - Handles a successful call to the guarded operation. - """ + """Handle a successful call to the guarded operation.""" self._breaker._state_storage.reset_counter() self.on_success() for listener in self._breaker.listeners: listener.success(self._breaker) def call(self, func: Callable[..., T], *args: Any, **kwargs: Any) -> T: - """ - Calls `func` with the given `args` and `kwargs`, and updates the + """Calls `func` with the given `args` and `kwargs`, and updates the circuit breaker state according to the result. """ ret = None @@ -799,8 +648,7 @@ def call(self, func: Callable[..., T], *args: Any, **kwargs: Any) -> T: return ret def call_async(self, func, *args: Any, **kwargs: Any): # type: ignore[no-untyped-def] - """ - Calls async `func` with the given `args` and `kwargs`, and updates the + """Call async `func` with the given `args` and `kwargs`, and updates the circuit breaker state according to the result. Return a closure to prevent import errors when using without tornado present @@ -840,27 +688,17 @@ def generator_call(self, wrapped_generator): # type: ignore[no-untyped-def] wrapped_generator.throw(e) def before_call(self, func: Callable[..., Any], *args: Any, **kwargs: Any) -> None: - """ - Override this method to be notified before a call to the guarded - operation is attempted. - """ + """Override this method to be notified before a call to the guarded operation is attempted.""" def on_success(self) -> None: - """ - Override this method to be notified when a call to the guarded - operation succeeds. - """ + """Override this method to be notified when a call to the guarded operation succeeds.""" def on_failure(self, exc: BaseException) -> None: - """ - Override this method to be notified when a call to the guarded - operation fails. - """ + """Override this method to be notified when a call to the guarded operation fails.""" class CircuitClosedState(CircuitBreakerState): - """ - In the normal "closed" state, the circuit breaker executes operations as + """In the normal "closed" state, the circuit breaker executes operations as usual. If the call succeeds, nothing happens. If it fails, however, the circuit breaker makes a note of the failure. @@ -874,9 +712,7 @@ def __init__( prev_state: CircuitBreakerState | None = None, notify: bool = False, ) -> None: - """ - Moves the given circuit breaker `cb` to the "closed" state. - """ + """Move the given circuit breaker `cb` to the "closed" state.""" super().__init__(cb, STATE_CLOSED) if notify: # We only reset the counter if notify is True, otherwise the CircuitBreaker @@ -889,23 +725,18 @@ def __init__( listener.state_change(self._breaker, prev_state, self) def on_failure(self, exc: BaseException) -> None: - """ - Moves the circuit breaker to the "open" state once the failures - threshold is reached. - """ + """Move the circuit breaker to the "open" state once the failures threshold is reached.""" if self._breaker._state_storage.counter >= self._breaker.fail_max: throw_new_error = self._breaker.open() if throw_new_error: error_msg = "Failures threshold reached, circuit breaker opened" raise CircuitBreakerError(error_msg).with_traceback(sys.exc_info()[2]) - else: - raise exc + raise exc class CircuitOpenState(CircuitBreakerState): - """ - When the circuit is "open", calls to the circuit breaker fail immediately, + """When the circuit is "open", calls to the circuit breaker fail immediately, without any attempt to execute the real operation. This is indicated by the ``CircuitBreakerError`` exception. @@ -919,17 +750,14 @@ def __init__( prev_state: CircuitBreakerState | None = None, notify: bool = False, ) -> None: - """ - Moves the given circuit breaker `cb` to the "open" state. - """ + """Move the given circuit breaker `cb` to the "open" state.""" super().__init__(cb, STATE_OPEN) if notify: for listener in self._breaker.listeners: listener.state_change(self._breaker, prev_state, self) def before_call(self, func: Callable[..., T], *args: Any, **kwargs: Any) -> T: - """ - After the timeout elapses, move the circuit breaker to the "half-open" + """After the timeout elapses, move the circuit breaker to the "half-open" state; otherwise, raises ``CircuitBreakerError`` without any attempt to execute the real operation. """ @@ -938,22 +766,18 @@ def before_call(self, func: Callable[..., T], *args: Any, **kwargs: Any) -> T: if opened_at and datetime.utcnow() < opened_at + timeout: error_msg = "Timeout not elapsed yet, circuit breaker still open" raise CircuitBreakerError(error_msg) - else: - self._breaker.half_open() - return self._breaker.call(func, *args, **kwargs) + self._breaker.half_open() + return self._breaker.call(func, *args, **kwargs) def call(self, func: Callable[..., T], *args: Any, **kwargs: Any) -> T: + """Delegate the call to before_call, if the time out is not elapsed it will throw an exception, otherwise we get + the results from the call performed after the state is switch to half-open. """ - Delegate the call to before_call, if the time out is not elapsed it will throw an exception, otherwise we get - the results from the call performed after the state is switch to half-open - """ - return self.before_call(func, *args, **kwargs) class CircuitHalfOpenState(CircuitBreakerState): - """ - In the "half-open" state, the next call to the circuit breaker is allowed + """In the "half-open" state, the next call to the circuit breaker is allowed to execute the dangerous operation. Should the call succeed, the circuit breaker resets and returns to the "closed" state. If this trial call fails, however, the circuit breaker returns to the "open" state until another @@ -966,35 +790,27 @@ def __init__( prev_state: CircuitBreakerState | None, notify: bool = False, ) -> None: - """ - Moves the given circuit breaker `cb` to the "half-open" state. - """ + """Move the given circuit breaker `cb` to the "half-open" state.""" super().__init__(cb, STATE_HALF_OPEN) if notify: for listener in self._breaker._listeners: listener.state_change(self._breaker, prev_state, self) def on_failure(self, exc: BaseException) -> NoReturn: - """ - Opens the circuit breaker. - """ + """Opens the circuit breaker.""" throw_new_error = self._breaker.open() if throw_new_error: error_msg = "Trial call failed, circuit breaker opened" raise CircuitBreakerError(error_msg).with_traceback(sys.exc_info()[2]) - else: - raise exc + raise exc def on_success(self) -> None: - """ - Closes the circuit breaker. - """ + """Closes the circuit breaker.""" self._breaker.close() class CircuitBreakerError(Exception): - """ - When calls to a service fails because the circuit is open, this error is + """When calls to a service fails because the circuit is open, this error is raised to allow the caller to handle this type of exception differently. """ diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/tests.py b/tests/pybreaker_test.py similarity index 61% rename from src/tests.py rename to tests/pybreaker_test.py index 6bcfac5..918d55f 100755 --- a/src/tests.py +++ b/tests/pybreaker_test.py @@ -2,19 +2,21 @@ from contextlib import contextmanager from datetime import datetime from time import sleep +from unittest import mock -import mock -from tornado import gen, testing - +import pytest from pybreaker import * +from tornado import gen, testing -class CircuitBreakerStorageBasedTestCase(object): - """ - Mix in to test against different storage backings. Depends on +class CircuitBreakerStorageBasedTestCase: + """Mix in to test against different storage backings. Depends on `self.breaker` and `self.breaker_kwargs`. """ + def __init__(self): + self.breaker_kwargs = None + def test_successful_call(self): """CircuitBreaker: it should keep the circuit closed after a successful call. @@ -23,9 +25,9 @@ def test_successful_call(self): def func(): return True - self.assertTrue(self.breaker.call(func)) - self.assertEqual(0, self.breaker.fail_counter) - self.assertEqual("closed", self.breaker.current_state) + assert self.breaker.call(func) + assert self.breaker.fail_counter == 0 + assert self.breaker.current_state == "closed" def test_one_failed_call(self): """CircuitBreaker: it should keep the circuit closed after a few @@ -33,11 +35,12 @@ def test_one_failed_call(self): """ def func(): - raise NotImplementedError() + raise NotImplementedError - self.assertRaises(NotImplementedError, self.breaker.call, func) - self.assertEqual(1, self.breaker.fail_counter) - self.assertEqual("closed", self.breaker.current_state) + with pytest.raises(NotImplementedError): + self.breaker.call(func) + assert self.breaker.fail_counter == 1 + assert self.breaker.current_state == "closed" def test_one_successful_call_after_failed_call(self): """CircuitBreaker: it should keep the circuit closed after few mixed @@ -48,145 +51,158 @@ def suc(): return True def err(): - raise NotImplementedError() + raise NotImplementedError - self.assertRaises(NotImplementedError, self.breaker.call, err) - self.assertEqual(1, self.breaker.fail_counter) + with pytest.raises(NotImplementedError): + self.breaker.call(err) + assert self.breaker.fail_counter == 1 - self.assertTrue(self.breaker.call(suc)) - self.assertEqual(0, self.breaker.fail_counter) - self.assertEqual("closed", self.breaker.current_state) + assert self.breaker.call(suc) + assert self.breaker.fail_counter == 0 + assert self.breaker.current_state == "closed" def test_several_failed_calls_setting_absent(self): """CircuitBreaker: it should open the circuit after many failures.""" self.breaker = CircuitBreaker(fail_max=3, **self.breaker_kwargs) def func(): - raise NotImplementedError() + raise NotImplementedError - self.assertRaises(NotImplementedError, self.breaker.call, func) - self.assertRaises(NotImplementedError, self.breaker.call, func) + with pytest.raises(NotImplementedError): + self.breaker.call(func) + with pytest.raises(NotImplementedError): + self.breaker.call(func) # Circuit should open - self.assertRaises(CircuitBreakerError, self.breaker.call, func) - self.assertEqual(3, self.breaker.fail_counter) - self.assertEqual("open", self.breaker.current_state) + with pytest.raises(CircuitBreakerError): + self.breaker.call(func) + assert self.breaker.fail_counter == 3 + assert self.breaker.current_state == "open" def test_throw_new_error_on_trip_false(self): """CircuitBreaker: it should throw the original exception""" - self.breaker = CircuitBreaker( - fail_max=3, **self.breaker_kwargs, throw_new_error_on_trip=False - ) + self.breaker = CircuitBreaker(fail_max=3, **self.breaker_kwargs, throw_new_error_on_trip=False) def func(): - raise NotImplementedError() + raise NotImplementedError - self.assertRaises(NotImplementedError, self.breaker.call, func) - self.assertRaises(NotImplementedError, self.breaker.call, func) - self.assertRaises(NotImplementedError, self.breaker.call, func) + with pytest.raises(NotImplementedError): + self.breaker.call(func) + with pytest.raises(NotImplementedError): + self.breaker.call(func) + with pytest.raises(NotImplementedError): + self.breaker.call(func) # Circuit should be open - self.assertEqual(3, self.breaker.fail_counter) - self.assertEqual("open", self.breaker.current_state) + assert self.breaker.fail_counter == 3 + assert self.breaker.current_state == "open" # Circuit should still be open and break - self.assertRaises(CircuitBreakerError, self.breaker.call, func) - self.assertEqual(3, self.breaker.fail_counter) - self.assertEqual("open", self.breaker.current_state) + with pytest.raises(CircuitBreakerError): + self.breaker.call(func) + assert self.breaker.fail_counter == 3 + assert self.breaker.current_state == "open" def test_throw_new_error_on_trip_true(self): """CircuitBreaker: it should throw a CircuitBreakerError exception""" - self.breaker = CircuitBreaker( - fail_max=3, **self.breaker_kwargs, throw_new_error_on_trip=True - ) + self.breaker = CircuitBreaker(fail_max=3, **self.breaker_kwargs, throw_new_error_on_trip=True) def func(): - raise NotImplementedError() + raise NotImplementedError - self.assertRaises(NotImplementedError, self.breaker.call, func) - self.assertRaises(NotImplementedError, self.breaker.call, func) + with pytest.raises(NotImplementedError): + self.breaker.call(func) + with pytest.raises(NotImplementedError): + self.breaker.call(func) # Circuit should open - self.assertRaises(CircuitBreakerError, self.breaker.call, func) - self.assertEqual(3, self.breaker.fail_counter) - self.assertEqual("open", self.breaker.current_state) + with pytest.raises(CircuitBreakerError): + self.breaker.call(func) + assert self.breaker.fail_counter == 3 + assert self.breaker.current_state == "open" def test_traceback_in_circuitbreaker_error(self): """CircuitBreaker: it should open the circuit after many failures.""" self.breaker = CircuitBreaker(fail_max=3, **self.breaker_kwargs) def func(): - raise NotImplementedError() + raise NotImplementedError - self.assertRaises(NotImplementedError, self.breaker.call, func) - self.assertRaises(NotImplementedError, self.breaker.call, func) + with pytest.raises(NotImplementedError): + self.breaker.call(func) + with pytest.raises(NotImplementedError): + self.breaker.call(func) # Circuit should open try: self.breaker.call(func) - self.fail("CircuitBreakerError should throw") - except CircuitBreakerError as e: + pytest.fail("CircuitBreakerError should throw") + except CircuitBreakerError: import traceback - self.assertIn("NotImplementedError", traceback.format_exc()) - self.assertEqual(3, self.breaker.fail_counter) - self.assertEqual("open", self.breaker.current_state) + assert "NotImplementedError" in traceback.format_exc() + assert self.breaker.fail_counter == 3 + assert self.breaker.current_state == "open" def test_failed_call_after_timeout(self): """CircuitBreaker: it should half-open the circuit after timeout.""" - self.breaker = CircuitBreaker( - fail_max=3, reset_timeout=0.5, **self.breaker_kwargs - ) + self.breaker = CircuitBreaker(fail_max=3, reset_timeout=0.5, **self.breaker_kwargs) def func(): - raise NotImplementedError() + raise NotImplementedError - self.assertRaises(NotImplementedError, self.breaker.call, func) - self.assertRaises(NotImplementedError, self.breaker.call, func) - self.assertEqual("closed", self.breaker.current_state) + with pytest.raises(NotImplementedError): + self.breaker.call(func) + with pytest.raises(NotImplementedError): + self.breaker.call(func) + assert self.breaker.current_state == "closed" # Circuit should open - self.assertRaises(CircuitBreakerError, self.breaker.call, func) - self.assertEqual(3, self.breaker.fail_counter) + with pytest.raises(CircuitBreakerError): + self.breaker.call(func) + assert self.breaker.fail_counter == 3 # Wait for timeout sleep(0.6) # Circuit should open again - self.assertRaises(CircuitBreakerError, self.breaker.call, func) - self.assertEqual(4, self.breaker.fail_counter) - self.assertEqual("open", self.breaker.current_state) + with pytest.raises(CircuitBreakerError): + self.breaker.call(func) + assert self.breaker.fail_counter == 4 + assert self.breaker.current_state == "open" def test_successful_after_timeout(self): """CircuitBreaker: it should close the circuit when a call succeeds after timeout. The successful function should only be called once. """ - self.breaker = CircuitBreaker( - fail_max=3, reset_timeout=1, **self.breaker_kwargs - ) + self.breaker = CircuitBreaker(fail_max=3, reset_timeout=1, **self.breaker_kwargs) suc = mock.MagicMock(return_value=True) def err(): - raise NotImplementedError() + raise NotImplementedError - self.assertRaises(NotImplementedError, self.breaker.call, err) - self.assertRaises(NotImplementedError, self.breaker.call, err) - self.assertEqual("closed", self.breaker.current_state) + with pytest.raises(NotImplementedError): + self.breaker.call(err) + with pytest.raises(NotImplementedError): + self.breaker.call(err) + assert self.breaker.current_state == "closed" # Circuit should open - self.assertRaises(CircuitBreakerError, self.breaker.call, err) - self.assertRaises(CircuitBreakerError, self.breaker.call, suc) - self.assertEqual(3, self.breaker.fail_counter) + with pytest.raises(CircuitBreakerError): + self.breaker.call(err) + with pytest.raises(CircuitBreakerError): + self.breaker.call(suc) + assert self.breaker.fail_counter == 3 # Wait for timeout, at least a second since redis rounds to a second sleep(2) # Circuit should close again - self.assertTrue(self.breaker.call(suc)) - self.assertEqual(0, self.breaker.fail_counter) - self.assertEqual("closed", self.breaker.current_state) - self.assertEqual(1, suc.call_count) + assert self.breaker.call(suc) + assert self.breaker.fail_counter == 0 + assert self.breaker.current_state == "closed" + assert suc.call_count == 1 def test_failed_call_when_halfopen(self): """CircuitBreaker: it should open the circuit when a call fails in @@ -194,16 +210,17 @@ def test_failed_call_when_halfopen(self): """ def fun(): - raise NotImplementedError() + raise NotImplementedError self.breaker.half_open() - self.assertEqual(0, self.breaker.fail_counter) - self.assertEqual("half-open", self.breaker.current_state) + assert self.breaker.fail_counter == 0 + assert self.breaker.current_state == "half-open" # Circuit should open - self.assertRaises(CircuitBreakerError, self.breaker.call, fun) - self.assertEqual(1, self.breaker.fail_counter) - self.assertEqual("open", self.breaker.current_state) + with pytest.raises(CircuitBreakerError): + self.breaker.call(fun) + assert self.breaker.fail_counter == 1 + assert self.breaker.current_state == "open" def test_successful_call_when_halfopen(self): """CircuitBreaker: it should close the circuit when a call succeeds in @@ -214,34 +231,38 @@ def fun(): return True self.breaker.half_open() - self.assertEqual(0, self.breaker.fail_counter) - self.assertEqual("half-open", self.breaker.current_state) + assert self.breaker.fail_counter == 0 + assert self.breaker.current_state == "half-open" # Circuit should open - self.assertTrue(self.breaker.call(fun)) - self.assertEqual(0, self.breaker.fail_counter) - self.assertEqual("closed", self.breaker.current_state) + assert self.breaker.call(fun) + assert self.breaker.fail_counter == 0 + assert self.breaker.current_state == "closed" def test_close(self): """CircuitBreaker: it should allow the circuit to be closed manually.""" self.breaker = CircuitBreaker(fail_max=3, **self.breaker_kwargs) def func(): - raise NotImplementedError() + raise NotImplementedError - self.assertRaises(NotImplementedError, self.breaker.call, func) - self.assertRaises(NotImplementedError, self.breaker.call, func) + with pytest.raises(NotImplementedError): + self.breaker.call(func) + with pytest.raises(NotImplementedError): + self.breaker.call(func) # Circuit should open - self.assertRaises(CircuitBreakerError, self.breaker.call, func) - self.assertRaises(CircuitBreakerError, self.breaker.call, func) - self.assertEqual(3, self.breaker.fail_counter) - self.assertEqual("open", self.breaker.current_state) + with pytest.raises(CircuitBreakerError): + self.breaker.call(func) + with pytest.raises(CircuitBreakerError): + self.breaker.call(func) + assert self.breaker.fail_counter == 3 + assert self.breaker.current_state == "open" # Circuit should close again self.breaker.close() - self.assertEqual(0, self.breaker.fail_counter) - self.assertEqual("closed", self.breaker.current_state) + assert self.breaker.fail_counter == 0 + assert self.breaker.current_state == "closed" def test_transition_events(self): """CircuitBreaker: it should call the appropriate functions on every @@ -262,20 +283,18 @@ def state_change(self, cb, old_state, new_state): listener = Listener() self.breaker = CircuitBreaker(listeners=(listener,), **self.breaker_kwargs) - self.assertEqual("closed", self.breaker.current_state) + assert self.breaker.current_state == "closed" self.breaker.open() - self.assertEqual("open", self.breaker.current_state) + assert self.breaker.current_state == "open" self.breaker.half_open() - self.assertEqual("half-open", self.breaker.current_state) + assert self.breaker.current_state == "half-open" self.breaker.close() - self.assertEqual("closed", self.breaker.current_state) + assert self.breaker.current_state == "closed" - self.assertEqual( - "closed->open,open->half-open,half-open->closed,", listener.out - ) + assert listener.out == "closed->open,open->half-open,half-open->closed," def test_call_events(self): """CircuitBreaker: it should call the appropriate functions on every @@ -287,7 +306,7 @@ def suc(): return True def err(): - raise NotImplementedError() + raise NotImplementedError class Listener(CircuitBreakerListener): def __init__(self): @@ -309,9 +328,10 @@ def failure(self, cb, exc): listener = Listener() self.breaker = CircuitBreaker(listeners=(listener,), **self.breaker_kwargs) - self.assertTrue(self.breaker.call(suc)) - self.assertRaises(NotImplementedError, self.breaker.call, err) - self.assertEqual("-success-failure", listener.out) + assert self.breaker.call(suc) + with pytest.raises(NotImplementedError): + self.breaker.call(err) + assert listener.out == "-success-failure" def test_generator(self): """CircuitBreaker: it should inspect generator values.""" @@ -331,11 +351,13 @@ def err(value): e = err(True) next(e) - self.assertRaises(NotImplementedError, e.send, True) - self.assertEqual(1, self.breaker.fail_counter) - self.assertTrue(next(s)) - self.assertRaises((StopIteration, RuntimeError), lambda: next(s)) - self.assertEqual(0, self.breaker.fail_counter) + with pytest.raises(NotImplementedError): + e.send(True) + assert self.breaker.fail_counter == 1 + assert next(s) + with pytest.raises((StopIteration, RuntimeError)): + next(s) + assert self.breaker.fail_counter == 0 def test_contextmanager(self): """CircuitBreaker: it should catch in a with statement""" @@ -346,90 +368,88 @@ class Foo: def wrapper(self): try: yield - except NotImplementedError as e: - raise ValueError() + except NotImplementedError: + raise ValueError def foo(self): with self.wrapper(): - raise NotImplementedError() + raise NotImplementedError try: Foo().foo() except ValueError as e: - self.assertTrue(isinstance(e, ValueError)) + assert isinstance(e, ValueError) -class CircuitBreakerConfigurationTestCase(object): - """ - Tests for the CircuitBreaker class. - """ +class CircuitBreakerConfigurationTestCase: + """Tests for the CircuitBreaker class.""" def test_default_state(self): """CircuitBreaker: it should get initial state from state_storage.""" for state in (STATE_OPEN, STATE_CLOSED, STATE_HALF_OPEN): storage = CircuitMemoryStorage(state) breaker = CircuitBreaker(state_storage=storage) - self.assertEqual(breaker.state.name, state) + assert breaker.state.name == state def test_default_params(self): """CircuitBreaker: it should define smart defaults.""" - self.assertEqual(0, self.breaker.fail_counter) - self.assertEqual(60, self.breaker.reset_timeout) - self.assertEqual(5, self.breaker.fail_max) - self.assertEqual("closed", self.breaker.current_state) - self.assertEqual((), self.breaker.excluded_exceptions) - self.assertEqual((), self.breaker.listeners) - self.assertEqual("memory", self.breaker._state_storage.name) + assert self.breaker.fail_counter == 0 + assert self.breaker.reset_timeout == 60 + assert self.breaker.fail_max == 5 + assert self.breaker.current_state == "closed" + assert self.breaker.excluded_exceptions == () + assert self.breaker.listeners == () + assert self.breaker._state_storage.name == "memory" def test_new_with_custom_reset_timeout(self): """CircuitBreaker: it should support a custom reset timeout value.""" self.breaker = CircuitBreaker(reset_timeout=30) - self.assertEqual(0, self.breaker.fail_counter) - self.assertEqual(30, self.breaker.reset_timeout) - self.assertEqual(5, self.breaker.fail_max) - self.assertEqual((), self.breaker.excluded_exceptions) - self.assertEqual((), self.breaker.listeners) - self.assertEqual("memory", self.breaker._state_storage.name) + assert self.breaker.fail_counter == 0 + assert self.breaker.reset_timeout == 30 + assert self.breaker.fail_max == 5 + assert self.breaker.excluded_exceptions == () + assert self.breaker.listeners == () + assert self.breaker._state_storage.name == "memory" def test_new_with_custom_fail_max(self): """CircuitBreaker: it should support a custom maximum number of failures. """ self.breaker = CircuitBreaker(fail_max=10) - self.assertEqual(0, self.breaker.fail_counter) - self.assertEqual(60, self.breaker.reset_timeout) - self.assertEqual(10, self.breaker.fail_max) - self.assertEqual((), self.breaker.excluded_exceptions) - self.assertEqual((), self.breaker.listeners) - self.assertEqual("memory", self.breaker._state_storage.name) + assert self.breaker.fail_counter == 0 + assert self.breaker.reset_timeout == 60 + assert self.breaker.fail_max == 10 + assert self.breaker.excluded_exceptions == () + assert self.breaker.listeners == () + assert self.breaker._state_storage.name == "memory" def test_new_with_custom_excluded_exceptions(self): """CircuitBreaker: it should support a custom list of excluded exceptions. """ self.breaker = CircuitBreaker(exclude=[Exception]) - self.assertEqual(0, self.breaker.fail_counter) - self.assertEqual(60, self.breaker.reset_timeout) - self.assertEqual(5, self.breaker.fail_max) - self.assertEqual((Exception,), self.breaker.excluded_exceptions) - self.assertEqual((), self.breaker.listeners) - self.assertEqual("memory", self.breaker._state_storage.name) + assert self.breaker.fail_counter == 0 + assert self.breaker.reset_timeout == 60 + assert self.breaker.fail_max == 5 + assert (Exception,) == self.breaker.excluded_exceptions + assert self.breaker.listeners == () + assert self.breaker._state_storage.name == "memory" def test_fail_max_setter(self): """CircuitBreaker: it should allow the user to set a new value for 'fail_max'. """ - self.assertEqual(5, self.breaker.fail_max) + assert self.breaker.fail_max == 5 self.breaker.fail_max = 10 - self.assertEqual(10, self.breaker.fail_max) + assert self.breaker.fail_max == 10 def test_reset_timeout_setter(self): """CircuitBreaker: it should allow the user to set a new value for 'reset_timeout'. """ - self.assertEqual(60, self.breaker.reset_timeout) + assert self.breaker.reset_timeout == 60 self.breaker.reset_timeout = 30 - self.assertEqual(30, self.breaker.reset_timeout) + assert self.breaker.reset_timeout == 30 def test_call_with_no_args(self): """CircuitBreaker: it should be able to invoke functions with no-args.""" @@ -437,7 +457,7 @@ def test_call_with_no_args(self): def func(): return True - self.assertTrue(self.breaker.call(func)) + assert self.breaker.call(func) def test_call_with_args(self): """CircuitBreaker: it should be able to invoke functions with args.""" @@ -445,7 +465,7 @@ def test_call_with_args(self): def func(arg1, arg2): return [arg1, arg2] - self.assertEqual([42, "abc"], self.breaker.call(func, 42, "abc")) + assert [42, "abc"], self.breaker.call(func, 42 == "abc") def test_call_with_kwargs(self): """CircuitBreaker: it should be able to invoke functions with kwargs.""" @@ -453,7 +473,7 @@ def test_call_with_kwargs(self): def func(**kwargs): return kwargs - self.assertEqual({"a": 1, "b": 2}, self.breaker.call(func, a=1, b=2)) + assert {"a": 1, "b": 2}, self.breaker.call(func, a=1, b=2) @testing.gen_test def test_call_async_with_no_args(self): @@ -464,7 +484,7 @@ def func(): return True ret = yield self.breaker.call(func) - self.assertTrue(ret) + assert ret @testing.gen_test def test_call_async_with_args(self): @@ -475,7 +495,7 @@ def func(arg1, arg2): return [arg1, arg2] ret = yield self.breaker.call(func, 42, "abc") - self.assertEqual([42, "abc"], ret) + assert [42, "abc"] == ret @testing.gen_test def test_call_async_with_kwargs(self): @@ -486,21 +506,21 @@ def func(**kwargs): return kwargs ret = yield self.breaker.call(func, a=1, b=2) - self.assertEqual({"a": 1, "b": 2}, ret) + assert {"a": 1, "b": 2} == ret def test_add_listener(self): """CircuitBreaker: it should allow the user to add a listener at a later time. """ - self.assertEqual((), self.breaker.listeners) + assert self.breaker.listeners == () first = CircuitBreakerListener() self.breaker.add_listener(first) - self.assertEqual((first,), self.breaker.listeners) + assert (first,) == self.breaker.listeners second = CircuitBreakerListener() self.breaker.add_listener(second) - self.assertEqual((first, second), self.breaker.listeners) + assert (first, second) == self.breaker.listeners def test_add_listeners(self): """CircuitBreaker: it should allow the user to add listeners at a @@ -508,44 +528,48 @@ def test_add_listeners(self): """ first, second = CircuitBreakerListener(), CircuitBreakerListener() self.breaker.add_listeners(first, second) - self.assertEqual((first, second), self.breaker.listeners) + assert (first, second) == self.breaker.listeners def test_remove_listener(self): """CircuitBreaker: it should allow the user to remove a listener.""" first = CircuitBreakerListener() self.breaker.add_listener(first) - self.assertEqual((first,), self.breaker.listeners) + assert (first,) == self.breaker.listeners self.breaker.remove_listener(first) - self.assertEqual((), self.breaker.listeners) + assert self.breaker.listeners == () def test_excluded_exceptions(self): """CircuitBreaker: it should ignore specific exceptions.""" self.breaker = CircuitBreaker(exclude=[LookupError]) def err_1(): - raise NotImplementedError() + raise NotImplementedError def err_2(): - raise LookupError() + raise LookupError def err_3(): - raise KeyError() + raise KeyError - self.assertRaises(NotImplementedError, self.breaker.call, err_1) - self.assertEqual(1, self.breaker.fail_counter) + with pytest.raises(NotImplementedError): + self.breaker.call(err_1) + assert self.breaker.fail_counter == 1 # LookupError is not considered a system error - self.assertRaises(LookupError, self.breaker.call, err_2) - self.assertEqual(0, self.breaker.fail_counter) + with pytest.raises(LookupError): + self.breaker.call(err_2) + assert self.breaker.fail_counter == 0 - self.assertRaises(NotImplementedError, self.breaker.call, err_1) - self.assertEqual(1, self.breaker.fail_counter) + with pytest.raises(NotImplementedError): + self.breaker.call(err_1) + assert self.breaker.fail_counter == 1 # Should consider subclasses as well (KeyError is a subclass of # LookupError) - self.assertRaises(KeyError, self.breaker.call, err_3) - self.assertEqual(0, self.breaker.fail_counter) + with pytest.raises(KeyError): + self.breaker.call(err_3) + assert self.breaker.fail_counter == 0 def test_excluded_callable_exceptions(self): """CircuitBreaker: it should ignore specific exceptions that return true from a filtering callable.""" @@ -564,16 +588,19 @@ def err_2(): raise TestException("good") def err_3(): - raise NotImplementedError() + raise NotImplementedError - self.assertRaises(TestException, self.breaker.call, err_1) - self.assertEqual(1, self.breaker.fail_counter) + with pytest.raises(TestException): + self.breaker.call(err_1) + assert self.breaker.fail_counter == 1 - self.assertRaises(TestException, self.breaker.call, err_2) - self.assertEqual(0, self.breaker.fail_counter) + with pytest.raises(TestException): + self.breaker.call(err_2) + assert self.breaker.fail_counter == 0 - self.assertRaises(NotImplementedError, self.breaker.call, err_3) - self.assertEqual(1, self.breaker.fail_counter) + with pytest.raises(NotImplementedError): + self.breaker.call(err_3) + assert self.breaker.fail_counter == 1 def test_excluded_callable_and_types_exceptions(self): """CircuitBreaker: it should allow a mix of exclusions that includes both filter functions and types.""" @@ -592,55 +619,55 @@ def err_2(): raise TestException("good") def err_3(): - raise NotImplementedError() + raise NotImplementedError def err_4(): - raise LookupError() + raise LookupError - self.assertRaises(TestException, self.breaker.call, err_1) - self.assertEqual(1, self.breaker.fail_counter) + with pytest.raises(TestException): + self.breaker.call(err_1) + assert self.breaker.fail_counter == 1 - self.assertRaises(TestException, self.breaker.call, err_2) - self.assertEqual(0, self.breaker.fail_counter) + with pytest.raises(TestException): + self.breaker.call(err_2) + assert self.breaker.fail_counter == 0 - self.assertRaises(NotImplementedError, self.breaker.call, err_3) - self.assertEqual(1, self.breaker.fail_counter) + with pytest.raises(NotImplementedError): + self.breaker.call(err_3) + assert self.breaker.fail_counter == 1 - self.assertRaises(LookupError, self.breaker.call, err_4) - self.assertEqual(0, self.breaker.fail_counter) + with pytest.raises(LookupError): + self.breaker.call(err_4) + assert self.breaker.fail_counter == 0 def test_add_excluded_exception(self): """CircuitBreaker: it should allow the user to exclude an exception at a later time. """ - self.assertEqual((), self.breaker.excluded_exceptions) + assert self.breaker.excluded_exceptions == () self.breaker.add_excluded_exception(NotImplementedError) - self.assertEqual((NotImplementedError,), self.breaker.excluded_exceptions) + assert (NotImplementedError,) == self.breaker.excluded_exceptions self.breaker.add_excluded_exception(Exception) - self.assertEqual( - (NotImplementedError, Exception), self.breaker.excluded_exceptions - ) + assert (NotImplementedError, Exception) == self.breaker.excluded_exceptions def test_add_excluded_exceptions(self): """CircuitBreaker: it should allow the user to exclude exceptions at a later time. """ self.breaker.add_excluded_exceptions(NotImplementedError, Exception) - self.assertEqual( - (NotImplementedError, Exception), self.breaker.excluded_exceptions - ) + assert (NotImplementedError, Exception) == self.breaker.excluded_exceptions def test_remove_excluded_exception(self): """CircuitBreaker: it should allow the user to remove an excluded exception. """ self.breaker.add_excluded_exception(NotImplementedError) - self.assertEqual((NotImplementedError,), self.breaker.excluded_exceptions) + assert (NotImplementedError,) == self.breaker.excluded_exceptions self.breaker.remove_excluded_exception(NotImplementedError) - self.assertEqual((), self.breaker.excluded_exceptions) + assert self.breaker.excluded_exceptions == () def test_decorator(self): """CircuitBreaker: it should be a decorator.""" @@ -653,18 +680,19 @@ def suc(value): @self.breaker def err(value): "Docstring" - raise NotImplementedError() + raise NotImplementedError - self.assertEqual("Docstring", suc.__doc__) - self.assertEqual("Docstring", err.__doc__) - self.assertEqual("suc", suc.__name__) - self.assertEqual("err", err.__name__) + assert suc.__doc__ == "Docstring" + assert err.__doc__ == "Docstring" + assert suc.__name__ == "suc" + assert err.__name__ == "err" - self.assertRaises(NotImplementedError, err, True) - self.assertEqual(1, self.breaker.fail_counter) + with pytest.raises(NotImplementedError): + err(True) + assert self.breaker.fail_counter == 1 - self.assertTrue(suc(True)) - self.assertEqual(0, self.breaker.fail_counter) + assert suc(True) + assert self.breaker.fail_counter == 0 @testing.gen_test def test_decorator_call_future(self): @@ -680,25 +708,25 @@ def suc(value): @gen.coroutine def err(value): "Docstring" - raise NotImplementedError() + raise NotImplementedError - self.assertEqual("Docstring", suc.__doc__) - self.assertEqual("Docstring", err.__doc__) - self.assertEqual("suc", suc.__name__) - self.assertEqual("err", err.__name__) + assert suc.__doc__ == "Docstring" + assert err.__doc__ == "Docstring" + assert suc.__name__ == "suc" + assert err.__name__ == "err" - with self.assertRaises(NotImplementedError): + with pytest.raises(NotImplementedError): yield err(True) - self.assertEqual(1, self.breaker.fail_counter) + assert self.breaker.fail_counter == 1 ret = yield suc(True) - self.assertTrue(ret) - self.assertEqual(0, self.breaker.fail_counter) + assert ret + assert self.breaker.fail_counter == 0 @mock.patch("pybreaker.HAS_TORNADO_SUPPORT", False) def test_no_tornado_raises(self): - with self.assertRaises(ImportError): + with pytest.raises(ImportError): def func(): return True @@ -711,11 +739,11 @@ def test_name(self): """ name = "test_breaker" self.breaker = CircuitBreaker(name=name) - self.assertEqual(self.breaker.name, name) + assert self.breaker.name == name name = "breaker_test" self.breaker.name = name - self.assertEqual(self.breaker.name, name) + assert self.breaker.name == name class CircuitBreakerTestCase( @@ -723,9 +751,7 @@ class CircuitBreakerTestCase( CircuitBreakerStorageBasedTestCase, CircuitBreakerConfigurationTestCase, ): - """ - Tests for the CircuitBreaker class. - """ + """Tests for the CircuitBreaker class.""" def setUp(self): super(CircuitBreakerTestCase, self).setUp() @@ -733,7 +759,7 @@ def setUp(self): self.breaker = CircuitBreaker() def test_create_new_state__bad_state(self): - with self.assertRaises(ValueError): + with pytest.raises(ValueError): self.breaker._create_new_state("foo") @mock.patch("pybreaker.CircuitOpenState") @@ -756,8 +782,8 @@ def test_failure_count_not_reset_during_creation(self): storage.increment_counter() breaker = CircuitBreaker(state_storage=storage) - self.assertEqual(breaker.state.name, state) - self.assertEqual(breaker.fail_counter, 1) + assert breaker.state.name == state + assert breaker.fail_counter == 1 def test_state_opened_at_not_reset_during_creation(self): for state in (STATE_OPEN, STATE_CLOSED, STATE_HALF_OPEN): @@ -766,8 +792,8 @@ def test_state_opened_at_not_reset_during_creation(self): storage.opened_at = now breaker = CircuitBreaker(state_storage=storage) - self.assertEqual(breaker.state.name, state) - self.assertEqual(storage.opened_at, now) + assert breaker.state.name == state + assert storage.opened_at == now import logging @@ -776,18 +802,12 @@ def test_state_opened_at_not_reset_during_creation(self): from redis.exceptions import RedisError -class CircuitBreakerRedisTestCase( - unittest.TestCase, CircuitBreakerStorageBasedTestCase -): - """ - Tests for the CircuitBreaker class. - """ +class CircuitBreakerRedisTestCase(unittest.TestCase, CircuitBreakerStorageBasedTestCase): + """Tests for the CircuitBreaker class.""" def setUp(self): self.redis = fakeredis.FakeStrictRedis() - self.breaker_kwargs = { - "state_storage": CircuitRedisStorage("closed", self.redis) - } + self.breaker_kwargs = {"state_storage": CircuitRedisStorage("closed", self.redis)} self.breaker = CircuitBreaker(**self.breaker_kwargs) def tearDown(self): @@ -795,29 +815,24 @@ def tearDown(self): def test_namespace(self): self.redis.flushall() - self.breaker_kwargs = { - "state_storage": CircuitRedisStorage( - "closed", self.redis, namespace="my_app" - ) - } + self.breaker_kwargs = {"state_storage": CircuitRedisStorage("closed", self.redis, namespace="my_app")} self.breaker = CircuitBreaker(**self.breaker_kwargs) def func(): - raise NotImplementedError() + raise NotImplementedError - self.assertRaises(NotImplementedError, self.breaker.call, func) + with pytest.raises(NotImplementedError): + self.breaker.call(func) keys = self.redis.keys() - self.assertEqual(2, len(keys)) - self.assertTrue(keys[0].decode("utf-8").startswith("my_app")) - self.assertTrue(keys[1].decode("utf-8").startswith("my_app")) + assert len(keys) == 2 + assert keys[0].decode("utf-8").startswith("my_app") + assert keys[1].decode("utf-8").startswith("my_app") def test_fallback_state(self): logger = logging.getLogger("pybreaker") logger.setLevel(logging.FATAL) self.breaker_kwargs = { - "state_storage": CircuitRedisStorage( - "closed", self.redis, fallback_circuit_state="open" - ) + "state_storage": CircuitRedisStorage("closed", self.redis, fallback_circuit_state="open") } self.breaker = CircuitBreaker(**self.breaker_kwargs) @@ -826,36 +841,33 @@ def func(k): with mock.patch.object(self.redis, "get", new=func): state = self.breaker.state - self.assertEqual("open", state.name) + assert state.name == "open" def test_missing_state(self): """CircuitBreakerRedis: If state on Redis is missing, it should set the fallback circuit state and reset the fail counter to 0. """ self.breaker_kwargs = { - "state_storage": CircuitRedisStorage( - "closed", self.redis, fallback_circuit_state="open" - ) + "state_storage": CircuitRedisStorage("closed", self.redis, fallback_circuit_state="open") } self.breaker = CircuitBreaker(**self.breaker_kwargs) def func(): - raise NotImplementedError() + raise NotImplementedError - self.assertRaises(NotImplementedError, self.breaker.call, func) - self.assertEqual(1, self.breaker.fail_counter) + with pytest.raises(NotImplementedError): + self.breaker.call(func) + assert self.breaker.fail_counter == 1 with mock.patch.object(self.redis, "get", new=lambda k: None): state = self.breaker.state - self.assertEqual("open", state.name) - self.assertEqual(0, self.breaker.fail_counter) + assert state.name == "open" + assert self.breaker.fail_counter == 0 def test_cluster_mode(self): self.redis.flushall() - storage = CircuitRedisStorage( - STATE_OPEN, self.redis, namespace="my_app", cluster_mode=True - ) + storage = CircuitRedisStorage(STATE_OPEN, self.redis, namespace="my_app", cluster_mode=True) breaker_kwargs = {"state_storage": storage} now = datetime.now() @@ -865,8 +877,8 @@ def test_cluster_mode(self): opened_at = storage.opened_at.strftime("%Y-%m-%d-%H:%M:%S") breaker = CircuitBreaker(**breaker_kwargs) - self.assertEqual(breaker.state.name, STATE_OPEN) - self.assertEqual(opened_at, now_str) + assert breaker.state.name == STATE_OPEN + assert opened_at == now_str import threading @@ -874,38 +886,33 @@ def test_cluster_mode(self): class CircuitBreakerThreadsTestCase(unittest.TestCase): - """ - Tests to reproduce common synchronization errors on CircuitBreaker class. - """ + """Tests to reproduce common synchronization errors on CircuitBreaker class.""" def setUp(self): self.breaker = CircuitBreaker(fail_max=3000, reset_timeout=1) def _start_threads(self, target, n): - """ - Starts `n` threads that calls `target` and waits for them to finish. - """ + """Starts `n` threads that calls `target` and waits for them to finish.""" threads = [threading.Thread(target=target) for i in range(n)] [t.start() for t in threads] [t.join() for t in threads] def _mock_function(self, obj, func): - """ - Replaces a bounded function in `self.breaker` by another. - """ + """Replaces a bounded function in `self.breaker` by another.""" setattr(obj, func.__name__, MethodType(func, self.breaker)) def test_fail_thread_safety(self): """CircuitBreaker: it should compute a failed call atomically to avoid race conditions. """ + # Create a specific exception to avoid masking other errors class SpecificException(Exception): pass @self.breaker def err(): - raise SpecificException() + raise SpecificException def trigger_error(): for n in range(500): @@ -921,7 +928,7 @@ def _inc_counter(self): self._mock_function(self.breaker, _inc_counter) self._start_threads(trigger_error, 3) - self.assertEqual(1500, self.breaker.fail_counter) + assert self.breaker.fail_counter == 1500 def test_success_thread_safety(self): """CircuitBreaker: it should compute a successful call atomically @@ -946,7 +953,7 @@ def success(self, cb): self.breaker.add_listener(SuccessListener()) self._start_threads(trigger_success, 3) - self.assertEqual(1500, self.breaker._success_counter) + assert self.breaker._success_counter == 1500 def test_half_open_thread_safety(self): """CircuitBreaker: it should allow only one trial call when the @@ -959,7 +966,7 @@ def test_half_open_thread_safety(self): @self.breaker def err(): - raise Exception() + raise Exception def trigger_failure(): try: @@ -982,7 +989,7 @@ def state_change(self, cb, old_state, new_state): self.breaker.add_listener(state_listener) self._start_threads(trigger_failure, 5) - self.assertEqual(1, state_listener._count) + assert state_listener._count == 1 def test_fail_max_thread_safety(self): """CircuitBreaker: it should not allow more failed calls than @@ -991,7 +998,7 @@ def test_fail_max_thread_safety(self): @self.breaker def err(): - raise Exception() + raise Exception def trigger_error(): for i in range(2000): @@ -1006,12 +1013,11 @@ def before_call(self, cb, func, *args, **kwargs): self.breaker.add_listener(SleepListener()) self._start_threads(trigger_error, 3) - self.assertEqual(self.breaker.fail_max, self.breaker.fail_counter) + assert self.breaker.fail_max == self.breaker.fail_counter class CircuitBreakerRedisConcurrencyTestCase(unittest.TestCase): - """ - Tests to reproduce common concurrency between different machines + """Tests to reproduce common concurrency between different machines connecting to redis. This is simulated locally using threads. """ @@ -1028,30 +1034,27 @@ def tearDown(self): self.redis.flushall() def _start_threads(self, target, n): - """ - Starts `n` threads that calls `target` and waits for them to finish. - """ + """Starts `n` threads that calls `target` and waits for them to finish.""" threads = [threading.Thread(target=target) for i in range(n)] [t.start() for t in threads] [t.join() for t in threads] def _mock_function(self, obj, func): - """ - Replaces a bounded function in `self.breaker` by another. - """ + """Replaces a bounded function in `self.breaker` by another.""" setattr(obj, func.__name__, MethodType(func, self.breaker)) def test_fail_thread_safety(self): """CircuitBreaker: it should compute a failed call atomically to avoid race conditions. """ + # Create a specific exception to avoid masking other errors class SpecificException(Exception): pass @self.breaker def err(): - raise SpecificException() + raise SpecificException def trigger_error(): for n in range(500): @@ -1066,7 +1069,7 @@ def _inc_counter(self): self._mock_function(self.breaker, _inc_counter) self._start_threads(trigger_error, 3) - self.assertEqual(1500, self.breaker.fail_counter) + assert self.breaker.fail_counter == 1500 def test_success_thread_safety(self): """CircuitBreaker: it should compute a successful call atomically @@ -1091,7 +1094,7 @@ def success(self, cb): self.breaker.add_listener(SuccessListener()) self._start_threads(trigger_success, 3) - self.assertEqual(1500, self.breaker._success_counter) + assert self.breaker._success_counter == 1500 def test_half_open_thread_safety(self): """CircuitBreaker: it should allow only one trial call when the @@ -1104,7 +1107,7 @@ def test_half_open_thread_safety(self): @self.breaker def err(): - raise Exception() + raise Exception def trigger_failure(): try: @@ -1127,7 +1130,7 @@ def state_change(self, cb, old_state, new_state): self.breaker.add_listener(state_listener) self._start_threads(trigger_failure, 5) - self.assertEqual(1, state_listener._count) + assert state_listener._count == 1 def test_fail_max_thread_safety(self): """CircuitBreaker: it should not allow more failed calls than 'fail_max' @@ -1140,7 +1143,7 @@ def test_fail_max_thread_safety(self): @self.breaker def err(): - raise Exception() + raise Exception def trigger_error(): for i in range(2000): @@ -1156,18 +1159,14 @@ def before_call(self, cb, func, *args, **kwargs): self.breaker.add_listener(SleepListener()) num_threads = 3 self._start_threads(trigger_error, num_threads) - self.assertTrue(self.breaker.fail_counter < self.breaker.fail_max + num_threads) + assert self.breaker.fail_counter < self.breaker.fail_max + num_threads class CircuitBreakerContextManagerTestCase(unittest.TestCase): - """ - Tests for the CircuitBreaker class, when used as a context manager. - """ + """Tests for the CircuitBreaker class, when used as a context manager.""" def test_calling(self): - """ - Test that the CircuitBreaker calling() API returns a context manager and works as expected. - """ + """Test that the CircuitBreaker calling() API returns a context manager and works as expected.""" class TestError(Exception): pass @@ -1185,8 +1184,8 @@ def _do_succeed(): self.assertRaises(TestError, _do_raise) self.assertRaises(CircuitBreakerError, _do_raise) - self.assertEqual(2, breaker.fail_counter) - self.assertEqual("open", breaker.current_state) + assert breaker.fail_counter == 2 + assert breaker.current_state == "open" # Still fails while circuit breaker is open: self.assertRaises(CircuitBreakerError, _do_succeed) @@ -1196,8 +1195,8 @@ def _do_succeed(): _do_succeed() mock_fn.assert_called_once() - self.assertEqual(0, breaker.fail_counter) - self.assertEqual("closed", breaker.current_state) + assert breaker.fail_counter == 0 + assert breaker.current_state == "closed" if __name__ == "__main__": diff --git a/tests/typechecks.py b/tests/typechecks.py new file mode 100755 index 0000000..1c226d5 --- /dev/null +++ b/tests/typechecks.py @@ -0,0 +1,10 @@ +from pybreaker import STATE_CLOSED, CircuitBreaker, CircuitMemoryStorage + +# issue #90: incorrect typing for exclude argument +# this should not give errors in mypy +CircuitBreaker( + fail_max=1, + reset_timeout=1000, + exclude=[lambda e: isinstance(e, RuntimeError)], + state_storage=CircuitMemoryStorage(STATE_CLOSED), +) diff --git a/tox.ini b/tox.ini deleted file mode 100644 index a1d2005..0000000 --- a/tox.ini +++ /dev/null @@ -1,13 +0,0 @@ -# tox (https://tox.readthedocs.io/) is a tool for running tests -# in multiple virtualenvs. This configuration file will run the -# test suite on all supported python versions. To use it, "pip install tox" -# and then run "tox" from this directory. - -[tox] -envlist = py36, py37, py38, py39, py310, py311 - -[testenv] -deps = - -commands = - python setup.py test